Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ spikeinterface.curation
.. autofunction:: remove_redundant_units
.. autofunction:: remove_duplicated_spikes
.. autofunction:: remove_excess_spikes
.. autofunction:: threshold_metrics_label_units
.. autofunction:: model_based_label_units
.. autofunction:: load_model
.. autofunction:: train_model
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/curation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .sortingview_curation import apply_sortingview_curation

# automated curation
from .threshold_metrics_curation import threshold_metrics_label_units
from .model_based_curation import model_based_label_units, load_model, auto_label_units
from .train_manual_curation import train_model, get_default_classifier_search_spaces
from .unitrefine_curation import unitrefine_label_units
9 changes: 9 additions & 0 deletions src/spikeinterface/curation/curation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
_methods_numpy = ("keep_first", "random", "keep_last")


def is_threshold_disabled(value):
    """Check if a threshold value is disabled (None or NaN).

    Accepts any NumPy floating scalar as well as Python floats: np.float64
    subclasses float, but np.float32 (and other np.floating types) do not,
    so a plain `isinstance(value, float)` check would miss their NaNs.
    """
    if value is None:
        return True
    if isinstance(value, (float, np.floating)) and np.isnan(value):
        return True
    return False


def _find_duplicated_spikes_numpy(
spike_train: np.ndarray,
censored_period: int,
Expand Down
221 changes: 221 additions & 0 deletions src/spikeinterface/curation/tests/test_threshold_metrics_curation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import pytest
import json

import numpy as np

from spikeinterface.curation import threshold_metrics_label_units


def test_threshold_metrics_label_units_with_dataframe():
    """Basic labeling from an in-memory DataFrame of metrics."""
    import pandas as pd

    # Unit 0 passes both thresholds, unit 1 fails snr,
    # unit 2 exceeds the firing-rate ceiling.
    unit_metrics = pd.DataFrame(
        data={"snr": [6.0, 4.0, 5.0], "firing_rate": [0.5, 0.2, 25.0]},
        index=[0, 1, 2],
    )
    rules = {"snr": {"min": 5.0}, "firing_rate": {"min": 0.1, "max": 20.0}}

    result = threshold_metrics_label_units(unit_metrics, rules)

    assert "label" in result.columns
    assert len(result) == len(unit_metrics.index)
    assert result["label"].to_dict() == {0: "good", 1: "noise", 2: "noise"}


def test_threshold_metrics_label_units_with_file(tmp_path):
    """Thresholds can be supplied as a path to a JSON file."""
    import pandas as pd

    unit_metrics = pd.DataFrame(
        data={"snr": [6.0, 4.0], "firing_rate": [0.5, 0.05]},
        index=[0, 1],
    )
    rules = {"snr": {"min": 5.0}, "firing_rate": {"min": 0.1}}

    # Serialize the thresholds and pass the file path instead of the dict.
    rules_path = tmp_path / "thresholds.json"
    rules_path.write_text(json.dumps(rules))

    result = threshold_metrics_label_units(unit_metrics, rules_path)

    assert result["label"].to_dict() == {0: "good", 1: "noise"}


def test_threshold_metrics_label_external_labels():
    """Custom pass/fail label strings are used verbatim in the output."""
    import pandas as pd

    metrics = pd.DataFrame(
        {
            "snr": [6.0, 4.0],
            "firing_rate": [0.5, 0.05],
        },
        index=[0, 1],
    )
    thresholds = {
        "snr": {"min": 5.0},
        "firing_rate": {"min": 0.1},
    }

    labels = threshold_metrics_label_units(
        metrics,
        thresholds=thresholds,
        pass_label="accepted",
        fail_label="rejected",
    )
    # Stronger than the previous subset check (which would pass even if every
    # unit got the same label): unit 0 passes both thresholds, unit 1 fails
    # both, so the exact per-unit labels are deterministic.
    assert labels["label"].to_dict() == {0: "accepted", 1: "rejected"}


def test_threshold_metrics_label_units_operator_or_with_dataframe():
    """'and' requires every metric to pass; 'or' requires at least one."""
    import pandas as pd

    # All four sign combinations of two metrics against a 0.0 floor.
    qm = pd.DataFrame(
        data={"m1": [1.0, 1.0, -1.0, -1.0], "m2": [1.0, -1.0, 1.0, -1.0]},
        index=[0, 1, 2, 3],
    )
    rules = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}

    with_and = threshold_metrics_label_units(qm, thresholds=rules, operator="and")
    assert with_and.index.equals(qm.index)
    assert with_and["label"].to_dict() == {0: "good", 1: "noise", 2: "noise", 3: "noise"}

    with_or = threshold_metrics_label_units(qm, thresholds=rules, operator="or")
    assert with_or.index.equals(qm.index)
    assert with_or["label"].to_dict() == {0: "good", 1: "good", 2: "good", 3: "noise"}


def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and():
    """Contrast nan_policy='fail' with nan_policy='ignore' under operator='and'."""
    import pandas as pd

    qm = pd.DataFrame(
        data={"m1": [np.nan, 1.0, np.nan], "m2": [1.0, -1.0, -1.0]},
        index=[10, 11, 12],
    )
    rules = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}

    # "fail": any NaN metric forces the fail label.
    failed = threshold_metrics_label_units(qm, thresholds=rules, operator="and", nan_policy="fail")
    assert failed["label"].to_dict() == {10: "noise", 11: "noise", 12: "noise"}

    # "ignore": NaN metrics are skipped entirely.
    # unit 10: m1 ignored (NaN), m2 passes -> good
    # unit 11: m2 fails -> noise
    # unit 12: m1 ignored but m2 fails -> noise
    ignored = threshold_metrics_label_units(qm, thresholds=rules, operator="and", nan_policy="ignore")
    assert ignored["label"].to_dict() == {10: "good", 11: "noise", 12: "noise"}


def test_threshold_metrics_label_units_nan_policy_ignore_with_or():
    """An ignored NaN cannot rescue a unit when combining with 'or'."""
    import pandas as pd

    qm = pd.DataFrame(
        data={"m1": [np.nan, -1.0], "m2": [-1.0, -1.0]},
        index=[20, 21],
    )
    rules = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}

    # unit 20: m1 is NaN and ignored; m2 fails => noise
    # unit 21: both metrics fail => noise
    result = threshold_metrics_label_units(qm, thresholds=rules, operator="or", nan_policy="ignore")
    assert result["label"].to_dict() == {20: "noise", 21: "noise"}


def test_threshold_metrics_label_units_nan_policy_pass_and_or():
    """nan_policy='pass' treats a NaN metric as passing its threshold."""
    import pandas as pd

    qm = pd.DataFrame(
        data={"m1": [np.nan, np.nan, 1.0, -1.0], "m2": [1.0, -1.0, np.nan, np.nan]},
        index=[30, 31, 32, 33],
    )
    rules = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}

    # unit 30: m1 NaN (pass), m2 pass => good
    # unit 31: m1 NaN (pass), m2 fail => noise
    # unit 32: m1 pass, m2 NaN (pass) => good
    # unit 33: m1 fail, m2 NaN (pass) => noise
    with_and = threshold_metrics_label_units(qm, thresholds=rules, operator="and", nan_policy="pass")
    assert with_and["label"].to_dict() == {30: "good", 31: "noise", 32: "good", 33: "noise"}

    # With "or", every unit has at least one passing (NaN) metric => all good.
    with_or = threshold_metrics_label_units(qm, thresholds=rules, operator="or", nan_policy="pass")
    assert with_or["label"].to_dict() == {30: "good", 31: "good", 32: "good", 33: "good"}


def test_threshold_metrics_label_units_invalid_operator_raises():
    """Anything other than 'and'/'or' is rejected with a ValueError."""
    import pandas as pd

    qm = pd.DataFrame({"m1": [1.0]}, index=[0])
    with pytest.raises(ValueError, match="operator must be 'and' or 'or'"):
        threshold_metrics_label_units(qm, {"m1": {"min": 0.0}}, operator="xor")


def test_threshold_metrics_label_units_invalid_nan_policy_raises():
    """An unknown nan_policy value is rejected with a ValueError."""
    import pandas as pd

    qm = pd.DataFrame({"m1": [1.0]}, index=[0])
    with pytest.raises(ValueError, match="nan_policy must be"):
        threshold_metrics_label_units(qm, {"m1": {"min": 0.0}}, nan_policy="omit")


def test_threshold_metrics_label_units_missing_metric_raises():
    """Referencing a metric column that does not exist is an error."""
    import pandas as pd

    qm = pd.DataFrame({"m1": [1.0]}, index=[0])
    with pytest.raises(ValueError, match="specified in thresholds are not present"):
        threshold_metrics_label_units(qm, {"does_not_exist": {"min": 0.0}})
124 changes: 124 additions & 0 deletions src/spikeinterface/curation/threshold_metrics_curation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import json
from pathlib import Path

import numpy as np

from spikeinterface.core.analyzer_extension_core import SortingAnalyzer

from .curation_tools import is_threshold_disabled


def threshold_metrics_label_units(
metrics: "pd.DataFrame",
thresholds: dict | str | Path,
pass_label: str = "good",
fail_label: str = "noise",
operator: str = "and",
nan_policy: str = "fail",
):
"""Label units based on metrics and thresholds.

Parameters
----------
metrics : pd.DataFrame
A DataFrame containing unit metrics with unit IDs as index.
thresholds : dict | str | Path
A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units.
Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values
should contain at least "min" and/or "max" keys to specify threshold ranges.
pass_label : str, default: "good"
The label to assign to units that pass all thresholds.
fail_label : str, default: "noise"
The label to assign to units that fail any threshold.
operator : "and" | "or", default: "and"
The logical operator to combine multiple metric thresholds. "and" means a unit must pass all thresholds to be
labeled as pass_label, while "or" means a unit must pass at least one threshold to be labeled as pass_label.
nan_policy : "fail" | "pass" | "ignore", default: "fail"
Policy for handling NaN values in metrics. If "fail", units with NaN values in any metric will be labeled as
fail_label. If "pass", units with NaN values in one metric will be labeled as pass_label.
If "ignore", NaN values will be ignored. Note that the "ignore" behavior will depend on the operator used.
If "and", NaNs will be treated as passing, since the initial mask is all true;
if "or", NaNs will be treated as failing, since the initial mask is all false.

Returns
-------
labels : pd.DataFrame
A DataFrame with unit IDs as index and a column 'label' containing the assigned labels (`fail_label` or `pass_label`)
"""
import pandas as pd

if not isinstance(metrics, pd.DataFrame):
raise ValueError("Only pd.DataFrame is supported for metrics.")

# Load thresholds from file if a path is provided
if isinstance(thresholds, (str, Path)):
with open(thresholds, "r") as f:
thresholds_dict = json.load(f)
elif isinstance(thresholds, dict):
thresholds_dict = thresholds
else:
raise ValueError("Thresholds must be a dictionary or a path to a JSON file containing the thresholds.")

# Check that all specified metrics are present in the quality metrics DataFrame
missing_metrics = []
for metric in thresholds_dict.keys():
if metric not in metrics.columns:
missing_metrics.append(metric)
if len(missing_metrics) > 0:
raise ValueError(
f"Metric(s) {missing_metrics} specified in thresholds are not present in the quality metrics DataFrame. "
f"Available metrics are: {metrics.columns.tolist()}"
)

if operator not in ("and", "or"):
raise ValueError("operator must be 'and' or 'or'")

if nan_policy not in ("fail", "pass", "ignore"):
raise ValueError("nan_policy must be 'fail', 'pass', or 'ignore'")

labels = pd.DataFrame(index=metrics.index, dtype=str)
labels["label"] = fail_label

# Key change: init depends on operator
pass_mask = np.ones(len(metrics), dtype=bool) if operator == "and" else np.zeros(len(metrics), dtype=bool)
any_threshold_applied = False

for metric_name, threshold in thresholds_dict.items():
min_value = threshold.get("min", None)
max_value = threshold.get("max", None)

# If both disabled, ignore this metric
if is_threshold_disabled(min_value) and is_threshold_disabled(max_value):
continue

values = metrics[metric_name].to_numpy()
is_nan = np.isnan(values)

metric_ok = np.ones(len(values), dtype=bool)
if not is_threshold_disabled(min_value):
metric_ok &= values >= min_value
if not is_threshold_disabled(max_value):
metric_ok &= values <= max_value

# Handle NaNs
nan_mask = slice(None)
if nan_policy == "fail":
metric_ok &= ~is_nan
elif nan_policy == "pass":
metric_ok |= is_nan
else:
# if nan_policy == "ignore", we only set values for non-nan entries
nan_mask = ~is_nan

any_threshold_applied = True

if operator == "and":
pass_mask[nan_mask] &= metric_ok[nan_mask]
else:
pass_mask[nan_mask] |= metric_ok[nan_mask]

if not any_threshold_applied:
pass_mask[:] = True

labels.loc[pass_mask, "label"] = pass_label
return labels