Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3dc5729
Test IBL extractors tests failing for PI update
alejoe91 Dec 29, 2025
d1a0532
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 6, 2026
33c6769
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 16, 2026
2c94bac
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 20, 2026
a40d073
Merge branch 'main' of github.com:alejoe91/spikeinterface
alejoe91 Feb 24, 2026
ef40b73
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 17, 2026
11c5812
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 24, 2026
ada53f8
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 24, 2026
22ff8fd
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 25, 2026
cbc36de
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 31, 2026
6b3e373
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Apr 9, 2026
359b68b
Implement get_unit_spike_trains function
alejoe91 Apr 9, 2026
85220e5
oups
alejoe91 Apr 9, 2026
0efad83
Fix tests
alejoe91 Apr 10, 2026
b1911bf
add tests and fixes
alejoe91 Apr 10, 2026
0744705
Fix bugs in get_unit_spike_trains_in_seconds and segment keying
grahamfindlay Apr 10, 2026
c71550b
Fix lexsort avoidance check in UnitsSelectionSorting (USS)
grahamfindlay Apr 10, 2026
6a82577
Override _compute_and_cache_spike_vector in Phy/Kilosort extractors
grahamfindlay Apr 10, 2026
1d4a3ce
Optimize get_unit_spike_trains on PhySortingSegment
grahamfindlay Apr 10, 2026
9a139b5
Add tests for UnitSelectionSorting & Phy spike vector and train optim…
grahamfindlay Apr 13, 2026
832f44f
Move is_spike_vector_sorted to sorting_tools
alejoe91 Apr 14, 2026
fe15764
Merge pull request #28 from grahamfindlay/pr4502-graham
alejoe91 Apr 14, 2026
329d220
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2026
db86f54
Merge branch 'main' into get-unit-spike-trains
alejoe91 Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 200 additions & 7 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def get_unit_spike_train(
segment_index=segment_index,
start_time=start_time,
end_time=end_time,
use_cache=use_cache,
)

segment_index = self._check_segment_index(segment_index)
Expand Down Expand Up @@ -212,6 +213,7 @@ def get_unit_spike_train_in_seconds(
segment_index: int | None = None,
start_time: float | None = None,
end_time: float | None = None,
use_cache: bool = True,
) -> np.ndarray:
"""
Get spike train for a unit in seconds.
Expand All @@ -236,6 +238,8 @@ def get_unit_spike_train_in_seconds(
The start time in seconds for spike train extraction
end_time : float or None, default: None
The end time in seconds for spike train extraction
use_cache : bool, default: True
If True, precompute (or use) the reordered spike vector cache for fast access.

Returns
-------
Expand All @@ -246,7 +250,7 @@ def get_unit_spike_train_in_seconds(
segment = self.segments[segment_index]

# If sorting has a registered recording, get the frames and get the times from the recording
# Note that this take into account the segment start time of the recording
# Note that this takes into account the segment start time of the recording
if self.has_recording():

# Get all the spike times and then slice them
Expand All @@ -258,7 +262,7 @@ def get_unit_spike_train_in_seconds(
start_frame=start_frame,
end_frame=end_frame,
return_times=False,
use_cache=True,
use_cache=use_cache,
)

spike_times = self.sample_index_to_time(spike_frames, segment_index=segment_index)
Expand Down Expand Up @@ -288,13 +292,169 @@ def get_unit_spike_train_in_seconds(
start_frame=start_frame,
end_frame=end_frame,
return_times=False,
use_cache=True,
use_cache=use_cache,
)

t_start = segment._t_start if segment._t_start is not None else 0
spike_times = spike_frames / self.get_sampling_frequency()
return t_start + spike_times

def get_unit_spike_trains(
    self,
    unit_ids: np.ndarray | list,
    segment_index: int | None = None,
    start_frame: int | None = None,
    end_frame: int | None = None,
    return_times: bool = False,
    use_cache: bool = True,
) -> dict[int | str, np.ndarray]:
    """Return spike trains for multiple units.

    Parameters
    ----------
    unit_ids : np.ndarray | list
        Unit ids to retrieve spike trains for
    segment_index : int or None, default: None
        The segment index to retrieve spike train from.
        For multi-segment objects, it is required
    start_frame : int or None, default: None
        The start frame for spike train extraction
    end_frame : int or None, default: None
        The end frame for spike train extraction (exclusive)
    return_times : bool, default: False
        If True, returns spike times in seconds instead of frames
    use_cache : bool, default: True
        If True, precompute (or use) the reordered spike vector cache for fast access.

    Returns
    -------
    dict[int | str, np.ndarray]
        A dictionary where keys are unit ids and values are spike trains (arrays of spike times or frames)
    """
    if return_times:
        # Convert the frame bounds to time bounds and delegate to the
        # seconds-based variant.
        start_time = (
            self.sample_index_to_time(start_frame, segment_index=segment_index) if start_frame is not None else None
        )
        end_time = (
            self.sample_index_to_time(end_frame, segment_index=segment_index) if end_frame is not None else None
        )

        return self.get_unit_spike_trains_in_seconds(
            unit_ids=unit_ids,
            segment_index=segment_index,
            start_time=start_time,
            end_time=end_time,
            use_cache=use_cache,
        )

    segment_index = self._check_segment_index(segment_index)
    segment = self.segments[segment_index]
    if use_cache:
        # Reorder the spike vector so each (unit, segment) pair occupies a
        # contiguous slice; per-unit extraction then becomes a cheap view.
        ordered_spike_vector, slices = self.to_reordered_spike_vector(
            lexsort=("sample_index", "segment_index", "unit_index"),
            return_order=False,
            return_slices=True,
        )
        unit_indices = self.ids_to_indices(unit_ids)
        spike_trains = {}
        for unit_index, unit_id in zip(unit_indices, unit_ids):
            sl0, sl1 = slices[unit_index, segment_index, :]
            spike_frames = ordered_spike_vector[sl0:sl1]["sample_index"]
            # Frames within each slice are sorted, so the requested window
            # [start_frame, end_frame) can be located with binary searches.
            if start_frame is not None:
                start = np.searchsorted(spike_frames, start_frame)
                spike_frames = spike_frames[start:]
            if end_frame is not None:
                end = np.searchsorted(spike_frames, end_frame)
                spike_frames = spike_frames[:end]
            spike_trains[unit_id] = spike_frames
    else:
        # Dispatch to the segment (not per-unit get_unit_spike_train) so that
        # subclasses can provide an optimized multi-unit implementation.
        spike_trains = segment.get_unit_spike_trains(
            unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame
        )
    return spike_trains

def get_unit_spike_trains_in_seconds(
    self,
    unit_ids: np.ndarray | list,
    segment_index: int | None = None,
    start_time: float | None = None,
    end_time: float | None = None,
    use_cache: bool = True,
) -> dict[int | str, np.ndarray]:
    """Return spike trains for multiple units in seconds.

    Parameters
    ----------
    unit_ids : np.ndarray | list
        Unit ids to retrieve spike trains for
    segment_index : int or None, default: None
        The segment index to retrieve spike train from.
        For multi-segment objects, it is required
    start_time : float or None, default: None
        The start time in seconds for spike train extraction
    end_time : float or None, default: None
        The end time in seconds for spike train extraction
    use_cache : bool, default: True
        If True, precompute (or use) the reordered spike vector cache for fast access.

    Returns
    -------
    dict[int | str, np.ndarray]
        A dictionary where keys are unit ids and values are spike trains (arrays of spike times in seconds)
    """
    segment_index = self._check_segment_index(segment_index)
    segment = self.segments[segment_index]

    spike_times = {}
    if self.has_recording():
        # If sorting has a registered recording, get all the spike frames and
        # map them through the recording's time vector (this takes into
        # account the segment start time of the recording). We fetch the full
        # trains and slice by time afterwards: time_to_sample_index() returns
        # the last frame at or before a time, so deriving frame bounds from
        # start_time/end_time could include spikes just outside the window.
        spike_train_frames = self.get_unit_spike_trains(
            unit_ids=unit_ids,
            segment_index=segment_index,
            start_frame=None,
            end_frame=None,
            return_times=False,
            use_cache=use_cache,
        )

        for unit_id in unit_ids:
            unit_times = self.sample_index_to_time(spike_train_frames[unit_id], segment_index=segment_index)

            # Filter to return only the spikes within the specified time range
            if start_time is not None:
                unit_times = unit_times[unit_times >= start_time]
            if end_time is not None:
                unit_times = unit_times[unit_times <= end_time]

            spike_times[unit_id] = unit_times

        return spike_times

    # If no recording attached, fall back to frame-based conversion using the
    # sampling frequency and the segment's t_start.
    # NOTE: compare against None explicitly so a bound of 0.0 (falsy) is not
    # silently treated as "no bound".
    start_frame = (
        self.time_to_sample_index(start_time, segment_index=segment_index) if start_time is not None else None
    )
    end_frame = self.time_to_sample_index(end_time, segment_index=segment_index) if end_time is not None else None

    spike_frames = self.get_unit_spike_trains(
        unit_ids=unit_ids,
        segment_index=segment_index,
        start_frame=start_frame,
        end_frame=end_frame,
        return_times=False,
        use_cache=use_cache,
    )
    # t_start and the sampling frequency are loop-invariant; hoist them.
    t_start = segment._t_start if segment._t_start is not None else 0
    sampling_frequency = self.get_sampling_frequency()
    for unit_id in unit_ids:
        spike_times[unit_id] = spike_frames[unit_id] / sampling_frequency + t_start
    return spike_times

def register_recording(self, recording, check_spike_frames: bool = True):
"""
Register a recording to the sorting. If the sorting and recording both contain
Expand Down Expand Up @@ -978,7 +1138,7 @@ def to_reordered_spike_vector(
s1 = seg_slices[segment_index + 1]
slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1]

elif ("sample_index", "unit_index", "segment_index"):
elif lexsort == ("sample_index", "unit_index", "segment_index"):
slices = np.zeros((num_segments, num_units, 2), dtype=np.int64)
seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left")
for segment_index in range(self.get_num_segments()):
Expand Down Expand Up @@ -1083,26 +1243,59 @@ def __init__(self, t_start=None):

def get_unit_spike_train(
self,
unit_id,
unit_id: int | str,
start_frame: int | None = None,
end_frame: int | None = None,
) -> np.ndarray:
"""Get the spike train for a unit.

Parameters
----------
unit_id
unit_id : int | str
The unit id for which to get the spike train.
start_frame : int, default: None
The start frame for the spike train. If None, it is set to the beginning of the segment.
end_frame : int, default: None
The end frame for the spike train. If None, it is set to the end of the segment.


Returns
-------
np.ndarray

The spike train for the given unit id and time interval.
"""
# must be implemented in subclass
raise NotImplementedError

def get_unit_spike_trains(
    self,
    unit_ids: np.ndarray | list,
    start_frame: int | None = None,
    end_frame: int | None = None,
) -> dict[int | str, np.ndarray]:
    """Get the spike trains for several units.

    This default implementation simply delegates to ``get_unit_spike_train``
    once per unit id; subclasses may override it with a faster batched version.

    Parameters
    ----------
    unit_ids : numpy.array or list
        The unit ids for which to get the spike trains.
    start_frame : int, default: None
        The start frame for the spike trains. If None, it is set to the beginning of the segment.
    end_frame : int, default: None
        The end frame for the spike trains. If None, it is set to the end of the segment.

    Returns
    -------
    dict[int | str, np.ndarray]
        A dictionary where keys are unit_ids and values are the corresponding spike trains.
    """
    return {
        unit_id: self.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame)
        for unit_id in unit_ids
    }


class SpikeVectorSortingSegment(BaseSortingSegment):
"""
Expand Down
25 changes: 25 additions & 0 deletions src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,3 +996,28 @@ def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, kee
new_vector["unit_index"] = mapping[new_vector["unit_index"]]

return new_vector, keep_mask_vector


def is_spike_vector_sorted(spike_vector: np.ndarray) -> bool:
    """Return True iff the spike vector is sorted by (segment_index, sample_index, unit_index).

    O(n) sequential scan. Used to avoid an O(n log n) lexsort when the vector already
    happens to be in canonical order.
    """
    # Vectors of length 0 or 1 are trivially sorted.
    if len(spike_vector) <= 1:
        return True

    segments = spike_vector["segment_index"]
    samples = spike_vector["sample_index"]
    units = spike_vector["unit_index"]

    # Primary key: segment indices must never decrease.
    segment_steps = np.diff(segments)
    if (segment_steps < 0).any():
        return False

    # Secondary key: within a segment, sample indices must never decrease.
    same_segment = segment_steps == 0
    sample_steps = np.diff(samples)
    if (sample_steps[same_segment] < 0).any():
        return False

    # Tertiary key: for ties on (segment, sample), unit indices must never decrease.
    same_sample = same_segment & (sample_steps == 0)
    return not (np.diff(units)[same_sample] < 0).any()
25 changes: 25 additions & 0 deletions src/spikeinterface/core/tests/test_basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,31 @@ def test_select_periods():
np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector())


@pytest.mark.parametrize("use_cache", [False, True])
def test_get_unit_spike_trains(use_cache):
sampling_frequency = 10_000.0
duration = 1.0
num_units = 10
sorting = generate_sorting(durations=[duration], sampling_frequency=sampling_frequency, num_units=num_units)

all_spike_trains = sorting.get_unit_spike_trains(unit_ids=sorting.unit_ids, use_cache=use_cache)
assert isinstance(all_spike_trains, dict)
assert set(all_spike_trains.keys()) == set(sorting.unit_ids)
for unit_id in sorting.unit_ids:
spiketrain = sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id, use_cache=use_cache)
assert np.array_equal(all_spike_trains[unit_id], spiketrain)

# test with times
spike_trains_times = sorting.get_unit_spike_trains_in_seconds(unit_ids=sorting.unit_ids, use_cache=use_cache)
assert isinstance(spike_trains_times, dict)
assert set(spike_trains_times.keys()) == set(sorting.unit_ids)
for unit_id in sorting.unit_ids:
spiketrain_times = sorting.get_unit_spike_train_in_seconds(
segment_index=0, unit_id=unit_id, use_cache=use_cache
)
assert np.allclose(spike_trains_times[unit_id], spiketrain_times)


if __name__ == "__main__":
import tempfile

Expand Down
Loading
Loading