-
Notifications
You must be signed in to change notification settings - Fork 257
Implement get_unit_spike_trains and performance improvements
#4502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3dc5729
d1a0532
33c6769
2c94bac
a40d073
ef40b73
11c5812
ada53f8
22ff8fd
cbc36de
6b3e373
359b68b
85220e5
0efad83
b1911bf
0744705
c71550b
6a82577
1d4a3ce
9a139b5
832f44f
fe15764
329d220
db86f54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -180,6 +180,7 @@ def get_unit_spike_train( | |
| segment_index=segment_index, | ||
| start_time=start_time, | ||
| end_time=end_time, | ||
| use_cache=use_cache, | ||
| ) | ||
|
|
||
| segment_index = self._check_segment_index(segment_index) | ||
|
|
@@ -212,6 +213,7 @@ def get_unit_spike_train_in_seconds( | |
| segment_index: int | None = None, | ||
| start_time: float | None = None, | ||
| end_time: float | None = None, | ||
| use_cache: bool = True, | ||
| ) -> np.ndarray: | ||
| """ | ||
| Get spike train for a unit in seconds. | ||
|
|
@@ -236,6 +238,8 @@ def get_unit_spike_train_in_seconds( | |
| The start time in seconds for spike train extraction | ||
| end_time : float or None, default: None | ||
| The end time in seconds for spike train extraction | ||
| use_cache : bool, default: True | ||
| If True, precompute (or use) the reordered spike vector cache for fast access. | ||
|
|
||
| Returns | ||
| ------- | ||
|
|
@@ -246,7 +250,7 @@ def get_unit_spike_train_in_seconds( | |
| segment = self.segments[segment_index] | ||
|
|
||
| # If sorting has a registered recording, get the frames and get the times from the recording | ||
| # Note that this take into account the segment start time of the recording | ||
| # Note that this takes into account the segment start time of the recording | ||
| if self.has_recording(): | ||
|
|
||
| # Get all the spike times and then slice them | ||
|
|
@@ -258,7 +262,7 @@ def get_unit_spike_train_in_seconds( | |
| start_frame=start_frame, | ||
| end_frame=end_frame, | ||
| return_times=False, | ||
| use_cache=True, | ||
| use_cache=use_cache, | ||
| ) | ||
|
|
||
| spike_times = self.sample_index_to_time(spike_frames, segment_index=segment_index) | ||
|
|
@@ -288,13 +292,169 @@ def get_unit_spike_train_in_seconds( | |
| start_frame=start_frame, | ||
| end_frame=end_frame, | ||
| return_times=False, | ||
| use_cache=True, | ||
| use_cache=use_cache, | ||
| ) | ||
|
|
||
| t_start = segment._t_start if segment._t_start is not None else 0 | ||
| spike_times = spike_frames / self.get_sampling_frequency() | ||
| return t_start + spike_times | ||
|
|
||
| def get_unit_spike_trains( | ||
| self, | ||
| unit_ids: np.ndarray | list, | ||
| segment_index: int | None = None, | ||
| start_frame: int | None = None, | ||
| end_frame: int | None = None, | ||
| return_times: bool = False, | ||
| use_cache: bool = True, | ||
| ) -> dict[int | str, np.ndarray]: | ||
| """Return spike trains for multiple units. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| unit_ids : np.ndarray | list | ||
| Unit ids to retrieve spike trains for | ||
| segment_index : int or None, default: None | ||
| The segment index to retrieve spike train from. | ||
| For multi-segment objects, it is required | ||
| start_frame : int or None, default: None | ||
| The start frame for spike train extraction | ||
| end_frame : int or None, default: None | ||
| The end frame for spike train extraction | ||
| return_times : bool, default: False | ||
| If True, returns spike times in seconds instead of frames | ||
| use_cache : bool, default: True | ||
| If True, precompute (or use) the reordered spike vector cache for fast access. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[int | str, np.ndarray] | ||
| A dictionary where keys are unit ids and values are spike trains (arrays of spike times or frames) | ||
| """ | ||
| if return_times: | ||
| start_time = ( | ||
| self.sample_index_to_time(start_frame, segment_index=segment_index) if start_frame is not None else None | ||
| ) | ||
| end_time = ( | ||
| self.sample_index_to_time(end_frame, segment_index=segment_index) if end_frame is not None else None | ||
| ) | ||
|
|
||
| return self.get_unit_spike_trains_in_seconds( | ||
| unit_ids=unit_ids, | ||
| segment_index=segment_index, | ||
| start_time=start_time, | ||
| end_time=end_time, | ||
| use_cache=use_cache, | ||
| ) | ||
|
|
||
| segment_index = self._check_segment_index(segment_index) | ||
| segment = self.segments[segment_index] | ||
| if use_cache: | ||
| # TODO: speed things up | ||
| ordered_spike_vector, slices = self.to_reordered_spike_vector( | ||
| lexsort=("sample_index", "segment_index", "unit_index"), | ||
| return_order=False, | ||
| return_slices=True, | ||
| ) | ||
| unit_indices = self.ids_to_indices(unit_ids) | ||
| spike_trains = {} | ||
| for unit_index, unit_id in zip(unit_indices, unit_ids): | ||
| sl0, sl1 = slices[unit_index, segment_index, :] | ||
| spikes = ordered_spike_vector[sl0:sl1] | ||
| spike_frames = spikes["sample_index"] | ||
| if start_frame is not None: | ||
| start = np.searchsorted(spike_frames, start_frame) | ||
| spike_frames = spike_frames[start:] | ||
| if end_frame is not None: | ||
| end = np.searchsorted(spike_frames, end_frame) | ||
| spike_frames = spike_frames[:end] | ||
| spike_trains[unit_id] = spike_frames | ||
| else: | ||
| spike_trains = segment.get_unit_spike_trains( | ||
| unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame | ||
| ) | ||
|
Comment on lines
+352
to
+375
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this is overkill, and should be replaced with something like [code snippet not captured]. In my local testing, this gives the same speed results. The one gain is avoiding repeated time slicing, but I think it's a marginal gain for all this fairly confusing code. The spike train code is already fairly complex. Happy to be proven wrong with benchmarking from a very long recording, but I don't think it's worth doing this until we need it.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the
However, for the
|
||
| return spike_trains | ||
|
|
||
| def get_unit_spike_trains_in_seconds( | ||
| self, | ||
| unit_ids: np.ndarray | list, | ||
| segment_index: int | None = None, | ||
| start_time: float | None = None, | ||
| end_time: float | None = None, | ||
| use_cache: bool = True, | ||
| ) -> dict[int | str, np.ndarray]: | ||
| """Return spike trains for multiple units in seconds. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| unit_ids : np.ndarray | list | ||
| Unit ids to retrieve spike trains for | ||
| segment_index : int or None, default: None | ||
| The segment index to retrieve spike train from. | ||
| For multi-segment objects, it is required | ||
| start_time : float or None, default: None | ||
| The start time in seconds for spike train extraction | ||
| end_time : float or None, default: None | ||
| The end time in seconds for spike train extraction | ||
| use_cache : bool, default: True | ||
| If True, precompute (or use) the reordered spike vector cache for fast access. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[int | str, np.ndarray] | ||
| A dictionary where keys are unit ids and values are spike trains (arrays of spike times in seconds) | ||
| """ | ||
| segment_index = self._check_segment_index(segment_index) | ||
| segment = self.segments[segment_index] | ||
|
|
||
| # If sorting has a registered recording, get the frames and get the times from the recording | ||
| # Note that this takes into account the segment start time of the recording | ||
| spike_times = {} | ||
| if self.has_recording(): | ||
| # Get all the spike times and then slice them | ||
| start_frame = None | ||
| end_frame = None | ||
| spike_train_frames = self.get_unit_spike_trains( | ||
| unit_ids=unit_ids, | ||
| segment_index=segment_index, | ||
| start_frame=start_frame, | ||
| end_frame=end_frame, | ||
|
Comment on lines
+420
to
+421
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was expecting this code to use the times to figure out the start/end frames, and use them here. Instead, this code gets all spike trains then slices. Why?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I also was confused by this at first, but I think it is because there's no guarantee that the sample returned by [inline code not captured]. You do raise a good point, which is that it seems inefficient to scan the whole train, depending on the underlying representation, and in fact I did implement the bounded scan on [reference not captured]:
start_frame = None if start_time is None else first_frame_at_or_after(start_time)
end_frame = None if end_time is None else first_frame_at_or_after(end_time)
spike_frames = self.get_unit_spike_train(..., start_frame=start_frame, end_frame=end_frame)
spike_times = self.sample_index_to_time(spike_frames, ...)
spike_times = spike_times[spike_times >= start_time]
spike_times = spike_times[spike_times < end_time]
It seems plausible to me that this could save time. |
||
| return_times=False, | ||
| use_cache=use_cache, | ||
| ) | ||
|
|
||
| for unit_id in unit_ids: | ||
| spike_frames = self.sample_index_to_time(spike_train_frames[unit_id], segment_index=segment_index) | ||
|
|
||
| # Filter to return only the spikes within the specified time range | ||
| if start_time is not None: | ||
| spike_frames = spike_frames[spike_frames >= start_time] | ||
| if end_time is not None: | ||
| spike_frames = spike_frames[spike_frames <= end_time] | ||
|
|
||
| spike_times[unit_id] = spike_frames | ||
|
|
||
| return spike_times | ||
|
|
||
| # If no recording is attached, fall back to frame-based conversion | ||
| # Get spike train in frames and convert to times using traditional method | ||
| start_frame = self.time_to_sample_index(start_time, segment_index=segment_index) if start_time else None | ||
| end_frame = self.time_to_sample_index(end_time, segment_index=segment_index) if end_time else None | ||
|
|
||
| spike_frames = self.get_unit_spike_trains( | ||
| unit_ids=unit_ids, | ||
| segment_index=segment_index, | ||
| start_frame=start_frame, | ||
| end_frame=end_frame, | ||
| return_times=False, | ||
| use_cache=use_cache, | ||
| ) | ||
| for unit_id in unit_ids: | ||
| spike_frames_unit = spike_frames[unit_id] | ||
| t_start = segment._t_start if segment._t_start is not None else 0 | ||
| spike_times[unit_id] = spike_frames_unit / self.get_sampling_frequency() + t_start | ||
| return spike_times | ||
|
|
||
| def register_recording(self, recording, check_spike_frames: bool = True): | ||
| """ | ||
| Register a recording to the sorting. If the sorting and recording both contain | ||
|
|
@@ -978,7 +1138,7 @@ def to_reordered_spike_vector( | |
| s1 = seg_slices[segment_index + 1] | ||
| slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1] | ||
|
|
||
| elif ("sample_index", "unit_index", "segment_index"): | ||
| elif lexsort == ("sample_index", "unit_index", "segment_index"): | ||
| slices = np.zeros((num_segments, num_units, 2), dtype=np.int64) | ||
| seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left") | ||
| for segment_index in range(self.get_num_segments()): | ||
|
|
@@ -1083,26 +1243,59 @@ def __init__(self, t_start=None): | |
|
|
||
| def get_unit_spike_train( | ||
| self, | ||
| unit_id, | ||
| unit_id: int | str, | ||
| start_frame: int | None = None, | ||
| end_frame: int | None = None, | ||
| ) -> np.ndarray: | ||
| """Get the spike train for a unit. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| unit_id | ||
| unit_id : int | str | ||
| The unit id for which to get the spike train. | ||
| start_frame : int, default: None | ||
| The start frame for the spike train. If None, it is set to the beginning of the segment. | ||
| end_frame : int, default: None | ||
| The end frame for the spike train. If None, it is set to the end of the segment. | ||
|
|
||
|
|
||
| Returns | ||
| ------- | ||
| np.ndarray | ||
|
|
||
| The spike train for the given unit id and time interval. | ||
| """ | ||
| # must be implemented in subclass | ||
| raise NotImplementedError | ||
|
|
||
| def get_unit_spike_trains( | ||
| self, | ||
| unit_ids: np.ndarray | list, | ||
| start_frame: int | None = None, | ||
| end_frame: int | None = None, | ||
| ) -> dict[int | str, np.ndarray]: | ||
| """Get the spike trains for several units. | ||
| Can be implemented in subclass for performance but the default implementation is to call | ||
| get_unit_spike_train for each unit_id. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| unit_ids : numpy.array or list | ||
| The unit ids for which to get the spike trains. | ||
| start_frame : int, default: None | ||
| The start frame for the spike trains. If None, it is set to the beginning of the segment. | ||
| end_frame : int, default: None | ||
| The end frame for the spike trains. If None, it is set to the end of the segment. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[int | str, np.ndarray] | ||
| A dictionary where keys are unit_ids and values are the corresponding spike trains. | ||
| """ | ||
| spike_trains = {} | ||
| for unit_id in unit_ids: | ||
| spike_trains[unit_id] = self.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame) | ||
| return spike_trains | ||
|
|
||
|
|
||
| class SpikeVectorSortingSegment(BaseSortingSegment): | ||
| """ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd vote to get all spike trains if the user doesn't pass `unit_ids`. Surely almost all user use cases for `get_unit_spike_trains` are to get all unit spike trains?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm all for that. I actually think that unless we are going to cache the spike trains as such (rather than as a reordered spike vector) -- and I don't think we are [1] -- we should just call the function `get_all_spike_trains()` and return them all. That would most accurately reflect what the function does. It would make it less surprising for the user that getting 30 spike trains takes the same time as getting 300 spike trains. It would probably also encourage better access patterns. And they can easily filter the dict themselves with a 1-liner like: [code snippet not captured]
[1] I think caching both the spike trains and the spike vector would be bad, since the caches could drift out of sync with each other unless care were taken to avoid this, and syncing caches would presumably negate all the benefits of using one representation over another.