Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions src/spikeinterface/sortingcomponents/matching/nearest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Sorting components: template matching."""

import numpy as np
from spikeinterface.core.template_tools import get_template_extremum_channel
from spikeinterface.core import get_noise_levels, get_channel_distances
from spikeinterface.core.sparsity import compute_sparsity


from .base import BaseTemplateMatching, _base_matching_dtype
Expand All @@ -23,8 +25,15 @@ class NearestTemplatesPeeler(BaseTemplateMatching):
The threshold for peak detection in terms of k x MAD
noise_levels : None | array
If None the noise levels are estimated using random chunks of the recording. If array it should be an array of size (num_channels,) with the noise level of each channel
radius_um : float
The radius to define the neighborhood between channels in micrometers while detecting the peaks
detection_radius_um : float, default 100.0
The radius to define the neighborhood while detecting the peaks for locally exclusive detection.
neighborhood_radius_um : float, default 50.0
The radius to use to select neighbour templates when assigning a detected peak to a template.
The neighborhood is defined around the extremum channel of the templates.
sparsity_radius_um : float, default 100.0
The radius in micrometers to use to compute the sparsity of the templates when the templates are not already sparse.
support_radius_um : float, default 50.0
The radius in micrometers to use to define the support of the templates when computing the distance between templates and waveforms.
"""

def __init__(
Expand All @@ -39,6 +48,7 @@ def __init__(
detection_radius_um=100.0,
neighborhood_radius_um=50.0,
sparsity_radius_um=100.0,
support_radius_um=50.0,
):

BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output)
Expand All @@ -48,13 +58,12 @@ def __init__(
self.peak_sign = peak_sign
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance <= detection_radius_um
self.support_radius_um = support_radius_um

num_templates = len(self.templates.unit_ids)
num_channels = recording.get_num_channels()

if neighborhood_radius_um is not None:
from spikeinterface.core.template_tools import get_template_extremum_channel

best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index")
best_channels = np.array([best_channels[i] for i in templates.unit_ids])
channel_locations = recording.get_channel_locations()
Expand All @@ -65,34 +74,35 @@ def __init__(
else:
self.neighborhood_mask = np.ones((num_channels, num_templates), dtype=bool)

if sparsity_radius_um is not None:
if support_radius_um is not None:
if not templates.are_templates_sparse():
from spikeinterface.core.sparsity import compute_sparsity

sparsity = compute_sparsity(
templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign
)
if sparsity_radius_um is not None:
sparsity = compute_sparsity(
templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign
)
else:
raise ValueError("sparsity_radius_um should be provided if templates are not sparse")
else:
sparsity = templates.sparsity

self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool)
for channel_index in np.arange(num_channels):
mask = self.neighborhood_mask[channel_index]
self.sparsity_mask[channel_index] = np.sum(sparsity.mask[mask], axis=0) > 0
channel_locations = recording.get_channel_locations()
channel_distances = np.linalg.norm(channel_locations[:, None] - channel_locations[np.newaxis, :], axis=2)
self.sparsity_support_mask = channel_distances <= self.support_radius_um
else:
self.sparsity_mask = np.ones((num_channels, num_channels), dtype=bool)

self.templates_array = self.templates.get_dense_templates()
self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
self.nbefore = self.templates.nbefore
self.nafter = self.templates.nafter
self.margin = max(self.nbefore, self.nafter)
self.width = self.nbefore + self.nafter
self.margin = self.width + 1 + self.exclude_sweep_size
self.lookup_tables = {}
self.lookup_tables["templates"] = {}
self.lookup_tables["channels"] = {}
for i in range(num_channels):
self.lookup_tables["templates"][i] = np.flatnonzero(self.neighborhood_mask[i])
self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_mask[i])
self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_support_mask[i])

def get_trace_margin(self):
return self.margin
Expand All @@ -104,13 +114,14 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
from scipy.spatial.distance import cdist

if self.margin > 0:
peak_traces = traces[self.margin : -self.margin, :]
peak_traces = traces[self.width : -self.width, :]
else:
peak_traces = traces

peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk(
peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask
)
peak_sample_ind += self.margin
peak_sample_ind += self.width

spikes = np.empty(peak_sample_ind.size, dtype=_base_matching_dtype)
spikes["sample_index"] = peak_sample_ind
Expand Down Expand Up @@ -164,6 +175,7 @@ def __init__(
detection_radius_um=100.0,
neighborhood_radius_um=50.0,
sparsity_radius_um=100.0,
support_radius_um=50.0,
):

NearestTemplatesPeeler.__init__(
Expand All @@ -178,6 +190,7 @@ def __init__(
detection_radius_um=detection_radius_um,
neighborhood_radius_um=neighborhood_radius_um,
sparsity_radius_um=sparsity_radius_um,
support_radius_um=support_radius_um,
)

from spikeinterface.sortingcomponents.waveforms.waveform_utils import (
Expand Down Expand Up @@ -206,13 +219,13 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
)

if self.margin > 0:
peak_traces = traces[self.margin : -self.margin, :]
peak_traces = traces[self.width : -self.width, :]
else:
peak_traces = traces
peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk(
peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask
)
peak_sample_ind += self.margin
peak_sample_ind += self.width

spikes = np.empty(peak_sample_ind.size, dtype=_base_matching_dtype)
spikes["sample_index"] = peak_sample_ind
Expand Down
Loading