Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@
TimeSeriesChunkExecutor,
split_job_kwargs,
fix_job_kwargs,
get_inner_pool,
thread_budget,
)
from .recording_tools import (
write_binary_recording,
Expand Down
106 changes: 106 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,92 @@ def get_traces(
traces = traces.astype("float32", copy=False) * gains + offsets
return traces

def get_traces_multi_thread(
    self,
    segment_index: int | None = None,
    start_frame: int | None = None,
    end_frame: int | None = None,
    channel_ids: list | np.ndarray | tuple | None = None,
    order: Literal["C", "F"] | None = None,
    return_in_uV: bool = False,
    max_threads: int | None = None,
) -> np.ndarray:
    """Variant of ``get_traces`` whose segment kernel may spend up to
    ``max_threads`` threads of intra-call parallelism.

    Only segments that override
    ``BaseRecordingSegment.get_traces_multi_thread`` (kernels that gain from
    channel/time fan-out, e.g. filtering or common-reference) actually use
    the budget; every other segment falls through to the serial
    ``get_traces`` path.

    Parameters
    ----------
    max_threads : int or None, default: None
        Thread budget for this single call. ``None`` resolves
        ``max_threads_per_worker`` from the global job_kwargs; a value
        ``<= 1`` delegates to plain ``get_traces``.

    .. note::
        Resolving ``None`` from the global job_kwargs is only reliable in
        the parent process. Workers started via ``spawn`` / ``forkserver``
        (the defaults on macOS / Windows) do not inherit the parent's
        ``set_global_job_kwargs(...)`` state, so chunk callbacks running
        inside a ``TimeSeriesChunkExecutor`` worker must pass
        ``max_threads`` explicitly.

    See ``get_traces`` for the remaining parameters.
    """
    if max_threads is None:
        from .globals import get_global_job_kwargs

        budget = get_global_job_kwargs().get("max_threads_per_worker", 1)
        max_threads = int(budget or 1)

    # Serial fallback: no budget to spend, so reuse the canonical path.
    if max_threads <= 1:
        return self.get_traces(
            segment_index=segment_index,
            start_frame=start_frame,
            end_frame=end_frame,
            channel_ids=channel_ids,
            order=order,
            return_in_uV=return_in_uV,
        )

    segment_index = self._check_segment_index(segment_index)
    channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True)
    segment = self.segments[segment_index]
    num_samples = segment.get_num_samples()
    first_frame = 0 if start_frame is None else int(start_frame)
    last_frame = num_samples if end_frame is None else int(min(end_frame, num_samples))
    traces = segment.get_traces_multi_thread(
        start_frame=first_frame,
        end_frame=last_frame,
        channel_indices=channel_indices,
        max_threads=max_threads,
    )

    if order is not None:
        assert order in ["C", "F"]
        traces = np.asanyarray(traces, order=order)

    if return_in_uV:
        if self.has_scaleable_traces():
            gains = self.get_property("gain_to_uV")[channel_indices].astype("float32", copy=False)
            offsets = self.get_property("offset_to_uV")[channel_indices].astype("float32", copy=False)
            traces = traces.astype("float32", copy=False) * gains + offsets
        elif self._dtype.kind == "f":
            # Already floating point: nothing to scale, return as-is.
            pass
        else:
            raise ValueError(
                "This recording does not support return_in_uV=True (need gain_to_uV and offset_"
                "to_uV properties)"
            )
    return traces

def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray:
"""
General retrieval function for time_series objects
Expand Down Expand Up @@ -673,6 +759,26 @@ def get_traces(
# must be implemented in subclass
raise NotImplementedError

def get_traces_multi_thread(
self,
start_frame: int | None = None,
end_frame: int | None = None,
channel_indices: list | np.ndarray | tuple | None = None,
max_threads: int = 1,
) -> np.ndarray:
"""Default: serial fall-through to ``get_traces``.

Override on segments whose kernels benefit from intra-call
parallelism (channel-block fan-out, time-block fan-out, numba
prange). See ``core/job_tools.py:get_inner_pool`` and
``thread_budget`` for the building blocks.
"""
return self.get_traces(
start_frame=start_frame,
end_frame=end_frame,
channel_indices=channel_indices,
)

def get_data(
self, start_frame: int, end_frame: int, indices: list | np.ndarray | tuple | None = None
) -> np.ndarray:
Expand Down
136 changes: 136 additions & 0 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from tqdm.auto import tqdm

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from contextlib import ExitStack, contextmanager
import multiprocessing
import threading
import weakref
from threadpoolctl import threadpool_limits

from spikeinterface.core.core_tools import convert_string_to_bytes, convert_bytes_to_str, convert_seconds_to_str
Expand Down Expand Up @@ -759,3 +761,137 @@ def get_poolexecutor(n_jobs):
return MockPoolExecutor
else:
return ProcessPoolExecutor


# ---------------------------------------------------------------------------
# Intra-call thread fan-out utilities (used by ``get_traces_multi_thread``)
#
# These let a single ``get_traces`` call internally spend a thread budget
# (``max_threads_per_worker`` from job_kwargs) without exposing per-class
# init kwargs. Each segment that benefits from intra-call parallelism
# overrides ``BaseRecordingSegment.get_traces_multi_thread`` and picks
# the mechanism it actually needs:
#
# - explicit Python-thread fan-out → ``get_inner_pool``
# - BLAS / OpenMP cap (matmuls) → ``thread_budget(blas=True)``
# - numba ``prange`` parallelism → ``thread_budget(numba=True)``
#
# All three compose, but most segments use only one.

# Module-global per-caller-thread pool registry. Keyed by
# ``Thread → {max_threads → ThreadPoolExecutor}`` so that the same calling
# thread reusing the same budget gets the same pool across calls and across
# segments (a chained pipeline reuses one pool per (Thread, max_threads)
# pair, not one per segment).
#
# Identity-stable: never re-bound, only ``.clear()``ed in the post-fork
# guard, so callers that imported ``_inner_pools`` keep a valid reference.
_inner_pools: "weakref.WeakKeyDictionary[threading.Thread, dict]" = weakref.WeakKeyDictionary()
_inner_pools_lock = threading.Lock()
_inner_pools_pid: int = os.getpid()


def _shutdown_inner_pools(sized_dict):
"""Finalizer for a thread's pool dict: shut down all its pools.

``wait=False`` to avoid blocking the finalizer thread. In-flight tasks
would be cancelled, but the owning thread submits + joins synchronously,
so no such tasks exist when it actually exits.
"""
for pool in sized_dict.values():
pool.shutdown(wait=False)


def get_inner_pool(max_threads: int) -> ThreadPoolExecutor | None:
"""Per-caller-thread ``ThreadPoolExecutor`` of size ``max_threads``.

Same calling thread + same ``max_threads`` returns the same pool —
across calls, across segments. Different calling threads get distinct
pools so concurrent outer workers never queue on a shared inner pool
(the pathology that otherwise dominates when CRE ``n_jobs`` exceeds the
inner pool size).

Returns ``None`` for ``max_threads <= 1`` so callers can keep a single
serial-fallback branch.

Pools are owned by the calling ``Thread`` (via ``WeakKeyDictionary``),
so when the thread is garbage-collected its pools are shut down
automatically.

A pid guard clears the registry after ``os.fork()``: in a forked child
the parent's ``ThreadPoolExecutor``s reference Thread objects whose OS
threads were not copied across, so submitting to them would deadlock.
Pickled (spawn / forkserver) workers come up with their own module-load
state and never see this.
"""
if max_threads <= 1:
return None

global _inner_pools_pid
pid = os.getpid()
if _inner_pools_pid != pid:
with _inner_pools_lock:
if _inner_pools_pid != pid:
_inner_pools.clear()
_inner_pools_pid = pid

thread = threading.current_thread()
sized = _inner_pools.get(thread)
if sized is None:
with _inner_pools_lock:
sized = _inner_pools.get(thread)
if sized is None:
sized = {}
_inner_pools[thread] = sized
weakref.finalize(thread, _shutdown_inner_pools, sized)
pool = sized.get(max_threads)
if pool is None:
with _inner_pools_lock:
pool = sized.get(max_threads)
if pool is None:
pool = ThreadPoolExecutor(max_workers=max_threads)
sized[max_threads] = pool
return pool


@contextmanager
def thread_budget(max_threads: int, *, blas: bool = False, numba: bool = False):
    """Cap underlying thread runtimes for the duration of the context.

    Caller picks which mechanisms apply — the rest are left alone. Compose
    with ``get_inner_pool`` for explicit Python-thread fan-out (a separate
    mechanism that doesn't need a context manager).

    Parameters
    ----------
    max_threads : int
        Per-mechanism thread cap. Values ``<= 1`` are clamped to 1, i.e.
        the context is still entered but every selected mechanism is capped
        to a single thread.
    blas : bool, default False
        Apply ``threadpool_limits(limits=...)`` — caps the C-level
        thread pools used by BLAS (OpenBLAS / MKL / BLIS) and OpenMP
        (libgomp / libomp).
    numba : bool, default False
        Apply ``numba.set_num_threads(...)`` for the duration of the
        scope. Restored on exit. Only meaningful for ``@njit(parallel=True)``
        kernels using ``prange``; harmless otherwise.

    Notes
    -----
    threadpoolctl can sometimes reach numba's threading layer (when numba
    is configured to use OpenMP), but this is unreliable across
    ``NUMBA_THREADING_LAYER`` choices. Use ``numba=True`` explicitly when
    a segment actually contains a numba parallel kernel — don't rely on
    ``blas=True`` to reach it.
    """
    # Clamp once so both mechanisms agree: previously only the numba branch
    # clamped, and ``threadpool_limits(limits=0)`` (or a negative value)
    # violates the documented "<= 1 caps to 1" contract.
    capped = max(1, max_threads)
    with ExitStack() as stack:
        if blas:
            stack.enter_context(threadpool_limits(limits=capped))
        if numba:
            import numba as _nb

            prev = _nb.get_num_threads()
            _nb.set_num_threads(capped)
            # Restore the caller's numba thread count on scope exit.
            stack.callback(_nb.set_num_threads, prev)
        yield
30 changes: 24 additions & 6 deletions src/spikeinterface/core/time_series_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ def get_chunk_with_margin(
add_reflect_padding=False,
window_on_margin=False,
dtype=None,
max_threads: int = 1,
):
"""
Helper to get chunk with margin
Expand All @@ -586,12 +587,33 @@ def get_chunk_with_margin(
of `add_zeros` or `add_reflect_padding` is True. In the first
case zero padding is used, in the second case np.pad is called
with mod="reflect".

When ``max_threads > 1`` and the segment is a recording segment with a
``get_traces_multi_thread`` override, the upstream fetch goes through
that parallel kernel so a chained pipeline (e.g. Filter → CMR) gets
end-to-end parallelism per call. Snippets and other generic
``TimeSeriesSegment`` subtypes always use ``get_data`` (serial).
"""
length = int(chunkable_segment.get_num_samples())

if last_dimension_indices is None:
last_dimension_indices = slice(None)

# Local fetcher: branch on max_threads + recording-segment capability.
# Keeps ``get_data`` as a clean generic-TimeSeries API and pushes the
# "parallel if K>1" decision to the one call site that cares.
use_multi = max_threads > 1 and hasattr(chunkable_segment, "get_traces_multi_thread")

def _fetch(s0, s1):
    # Fetch the frame range s0..s1 from the upstream segment.
    # ``use_multi`` is decided once per call, just above: True only when
    # max_threads > 1 AND the segment exposes get_traces_multi_thread.
    if use_multi:
        # Parallel kernel path: forward the caller's thread budget so a
        # chained pipeline gets end-to-end parallelism within this call.
        return chunkable_segment.get_traces_multi_thread(
            start_frame=s0,
            end_frame=s1,
            channel_indices=last_dimension_indices,
            max_threads=max_threads,
        )
    # Generic serial path (snippets and other TimeSeriesSegment subtypes).
    return chunkable_segment.get_data(s0, s1, last_dimension_indices)

if not (add_zeros or add_reflect_padding):
if window_on_margin and not add_zeros:
raise ValueError("window_on_margin requires add_zeros=True")
Expand All @@ -612,11 +634,7 @@ def get_chunk_with_margin(
else:
right_margin = margin

data_chunk = chunkable_segment.get_data(
start_frame - left_margin,
end_frame + right_margin,
last_dimension_indices,
)
data_chunk = _fetch(start_frame - left_margin, end_frame + right_margin)

else:
# either add_zeros or reflect_padding
Expand All @@ -642,7 +660,7 @@ def get_chunk_with_margin(
end_frame2 = end_frame + margin
right_pad = 0

data_chunk = chunkable_segment.get_data(start_frame2, end_frame2, last_dimension_indices)
data_chunk = _fetch(start_frame2, end_frame2)

if dtype is not None or window_on_margin or left_pad > 0 or right_pad > 0:
need_copy = True
Expand Down
Loading
Loading