Skip to content

Implement get_unit_spike_trains and performance improvements#4502

Open
alejoe91 wants to merge 24 commits intoSpikeInterface:mainfrom
alejoe91:get-unit-spike-trains
Open

Implement get_unit_spike_trains and performance improvements#4502
alejoe91 wants to merge 24 commits intoSpikeInterface:mainfrom
alejoe91:get-unit-spike-trains

Conversation

@alejoe91
Copy link
Copy Markdown
Member

@alejoe91 alejoe91 commented Apr 9, 2026

  • expose and propagate use_cache (to get_unit_spike_train_in_seconds)
  • fix wrong check in to_reordered_spike_vector
  • avoid lexsort when not needed in select_units

@grahamfindlay

TODO

  • Implement numpy/numba get_unit_spike_trains for PhyKilosortSortingExtractor

(maybe in follow up)

@alejoe91 alejoe91 requested a review from chrishalcrow April 9, 2026 15:41
@alejoe91 alejoe91 added core Changes to core module performance Performance issues/improvements labels Apr 9, 2026
alejoe91 and others added 8 commits April 9, 2026 17:58
- Drop unused `return_times` parameter from get_unit_spike_trains_in_seconds
- Clean up stale/truncated docstrings on get_unit_spike_train_in_seconds,
  get_unit_spike_trains, and get_unit_spike_trains_in_seconds
- Fix UnitsSelectionSortingSegment.get_unit_spike_trains to re-key the
  returned dict with child unit ids (was returning parent-keyed dict,
  breaking whenever renamed_unit_ids differ from parent ids)
- Fix test_get_unit_spike_trains: drop unused return_times kwarg, remove
  unused local variable, fix assertion.
The previous check `np.diff(self.ids_to_indices(self._renamed_unit_ids)).min() < 0`
was never `True`, because `ids_to_indices(self._renamed_unit_ids)` on a USS
always returns `[0, 1, ..., k-1]` (since `_main_ids == _renamed_unit_ids`), so the
diff was always positive and the lexsort branch was unreachable. Therefore the
cached spike vector was wrong whenever two units had co-temporal spikes and the
selection reordered them relative to the parent.

Replaced with a two-step check that attempt to avoid unneccessary lexsorts:
  1. O(k) `_is_order_preserving_selection()` -- Checks if USS `._unit_ids` is
     in the same relative order as in the parent. When True, the remapped vector
     is guaranteed sorted (boolean filtering preserves order; the remap only
     relabels unit_index values). This is the common case via `select_units()`
     with a boolean mask.
  2. O(n) `_is_spike_vector_sorted()` -- Checks if the remapped vector is still
     sorted by (segment, sample, unit). Catches the case where the selection is
     not order-preserving but no co-temporal (same exact sample) spikes exist.

Falls back to the original O(n log n) lexsort only when both checks fail.
`BaseSorting` builds the spike vector with a per-unit boolean scan
over spike_clusters, which is (O(N*K)).

If we already have the full flat spike time and spike cluster arrays, we can
do a lot better by building the spike vector in one shot.
(I think O(N log N) from the lexsort, which is also pessimistic,
because the lexsort doesn't always need to happen.
Under any circumstances I can dream of, K >> log N.)

Since Phy/Kilosort segments already load the full flat arrays when the
`PhyKilosortSorting` object is created, and keep them around  as
`._all_spikes` and `._all_clusters`, we can just use those! :)

Also populates `_cached_spike_vector_segment_slices` directly, so
that `BaseSorting`'s `_get_spike_vector_segment_slices()` lazy
recomputation is skipped.
`BaseSortingSegment.get_unit_spike_trains()` loops over
`get_unit_spike_train`, which is O(N*K) because each call is a
boolean scan over _all_clusters/_all_spikes.

If we know we are going to be getting all the trains, we can do it
much faster. And if we can use numba, even faster still.

In fact, even if we only want _some_ spike trains, it is still often
faster to get all the trains and just discard the ones we don't need,
than to get only the trains we need do unit-by-unit (because we
only ever store or cache flat arrays of spike times/clusters).

Note that **only the use_cache=False path is affected**; the
use_cache=True triggers the computation of the spike vector, which
I don't think can ever be the most efficient way to get spike trains.
…izations

- Fixed test_compute_and_cache_spike_vector: was comparing an array to
  itself (to_spike_vector use_cache=False still returns the cached
  vector). Now explicitly calls the USS override and the BaseSorting
  implementation, and compares the two.
- Added test_uss_get_unit_spike_trains_with_renamed_ids: also not a test
  of the optimization commits per se, but would have caught a mistake made
  along the way. Verifies get_unit_spike_trains returns child-keyed dicts
  (not parent-keyed).
- Added test_spike_vector_sorted_after_reorder_with_cotemporal_spikes:
  verifies the USS spike vector is correctly sorted when the selection
  reverses unit order and co-temporal spikes exist.
- Added test_phy_sorting_segment_get_unit_spike_trains: validates the
  new fast methods on PhySortingSegment.
- Added test_phy_compute_and_cache_spike_vector: verifies the Phy
  override of _compute_and_cache_spike_vector matches BaseSorting
  implementation.
@grahamfindlay
Copy link
Copy Markdown
Contributor

@alejoe91 my changes PR'd to your fork whenever you're ready.

The only thing I should point out that isn't in the commit messages:
I mocked a minimal Phy folder for testing instead of using the phy_example_0 GIN dataset, just because it was quick, easy, and lightweight. I did feel a little guilty doing it, but I'm also not convinced it was a bad idea.

@alejoe91 alejoe91 marked this pull request as ready for review April 14, 2026 09:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

core Changes to core module performance Performance issues/improvements

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants