diff --git a/doc/api.rst b/doc/api.rst index adfdb85470..f4a97caabe 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -373,6 +373,7 @@ spikeinterface.curation .. autofunction:: remove_redundant_units .. autofunction:: remove_duplicated_spikes .. autofunction:: remove_excess_spikes + .. autofunction:: threshold_metrics_label_units .. autofunction:: model_based_label_units .. autofunction:: load_model .. autofunction:: train_model diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 730481937c..e00629086b 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -20,6 +20,7 @@ from .sortingview_curation import apply_sortingview_curation # automated curation +from .threshold_metrics_curation import threshold_metrics_label_units from .model_based_curation import model_based_label_units, load_model, auto_label_units from .train_manual_curation import train_model, get_default_classifier_search_spaces from .unitrefine_curation import unitrefine_label_units diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index f1d4eba3b5..3b5cb046f6 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -14,6 +14,15 @@ _methods_numpy = ("keep_first", "random", "keep_last") +def is_threshold_disabled(value): + """Check if a threshold value is disabled (None or np.nan).""" + if value is None: + return True + if isinstance(value, float) and np.isnan(value): + return True + return False + + def _find_duplicated_spikes_numpy( spike_train: np.ndarray, censored_period: int, diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py new file mode 100644 index 0000000000..82e0400b29 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -0,0 +1,221 @@ +import pytest 
+import json + +import numpy as np + +from spikeinterface.curation import threshold_metrics_label_units + + +def test_threshold_metrics_label_units_with_dataframe(): + import pandas as pd + + metrics = pd.DataFrame( + { + "snr": [6.0, 4.0, 5.0], + "firing_rate": [0.5, 0.2, 25.0], + }, + index=[0, 1, 2], + ) + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1, "max": 20.0}, + } + + labels = threshold_metrics_label_units(metrics, thresholds) + + assert "label" in labels.columns + assert labels.shape[0] == len(metrics.index) + assert labels["label"].to_dict() == {0: "good", 1: "noise", 2: "noise"} + + +def test_threshold_metrics_label_units_with_file(tmp_path): + import pandas as pd + + metrics = pd.DataFrame( + { + "snr": [6.0, 4.0], + "firing_rate": [0.5, 0.05], + }, + index=[0, 1], + ) + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1}, + } + + thresholds_file = tmp_path / "thresholds.json" + with open(thresholds_file, "w") as f: + json.dump(thresholds, f) + + labels = threshold_metrics_label_units(metrics, thresholds_file) + + assert labels["label"].to_dict() == {0: "good", 1: "noise"} + + +def test_threshold_metrics_label_external_labels(): + import pandas as pd + + metrics = pd.DataFrame( + { + "snr": [6.0, 4.0], + "firing_rate": [0.5, 0.05], + }, + index=[0, 1], + ) + thresholds = { + "snr": {"min": 5.0}, + "firing_rate": {"min": 0.1}, + } + + labels = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + pass_label="accepted", + fail_label="rejected", + ) + assert set(labels["label"]).issubset({"accepted", "rejected"}) + + +def test_threshold_metrics_label_units_operator_or_with_dataframe(): + import pandas as pd + + metrics = pd.DataFrame( + { + "m1": [1.0, 1.0, -1.0, -1.0], + "m2": [1.0, -1.0, 1.0, -1.0], + }, + index=[0, 1, 2, 3], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_and = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="and", + ) + assert 
labels_and.index.equals(metrics.index) + assert labels_and["label"].to_dict() == {0: "good", 1: "noise", 2: "noise", 3: "noise"} + + labels_or = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="or", + ) + assert labels_or.index.equals(metrics.index) + assert labels_or["label"].to_dict() == {0: "good", 1: "good", 2: "good", 3: "noise"} + + +def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): + import pandas as pd + + metrics = pd.DataFrame( + { + "m1": [np.nan, 1.0, np.nan], + "m2": [1.0, -1.0, -1.0], + }, + index=[10, 11, 12], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_fail = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="and", + nan_policy="fail", + ) + assert labels_fail["label"].to_dict() == {10: "noise", 11: "noise", 12: "noise"} + + labels_ignore = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="and", + nan_policy="ignore", + ) + # unit 10: m1 ignored (NaN), m2 passes -> good + # unit 11: m2 fails -> noise + # unit 12: m1 ignored but m2 fails -> noise + assert labels_ignore["label"].to_dict() == {10: "good", 11: "noise", 12: "noise"} + + +def test_threshold_metrics_label_units_nan_policy_ignore_with_or(): + import pandas as pd + + metrics = pd.DataFrame( + { + "m1": [np.nan, -1.0], + "m2": [-1.0, -1.0], + }, + index=[20, 21], + ) + thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_ignore_or = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="or", + nan_policy="ignore", + ) + # unit 20: m1 is NaN and ignored; m2 fails => noise + # unit 21: both metrics fail => noise + assert labels_ignore_or["label"].to_dict() == {20: "noise", 21: "noise"} + + +def test_threshold_metrics_label_units_nan_policy_pass_and_or(): + import pandas as pd + + metrics = pd.DataFrame( + { + "m1": [np.nan, np.nan, 1.0, -1.0], + "m2": [1.0, -1.0, np.nan, np.nan], + }, + index=[30, 31, 32, 33], + ) + 
thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + + labels_and = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="and", + nan_policy="pass", + ) + # unit 30: m1 NaN (pass), m2 pass => good + # unit 31: m1 NaN (pass), m2 fail => noise + # unit 32: m1 pass, m2 NaN (pass) => good + # unit 33: m1 fail, m2 NaN (pass) => noise + assert labels_and["label"].to_dict() == {30: "good", 31: "noise", 32: "good", 33: "noise"} + + labels_or = threshold_metrics_label_units( + metrics, + thresholds=thresholds, + operator="or", + nan_policy="pass", + ) + # any NaN counts as pass => good unless all metrics fail without NaN + assert labels_or["label"].to_dict() == {30: "good", 31: "good", 32: "good", 33: "good"} + + +def test_threshold_metrics_label_units_invalid_operator_raises(): + import pandas as pd + + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"m1": {"min": 0.0}} + with pytest.raises(ValueError, match="operator must be 'and' or 'or'"): + threshold_metrics_label_units(metrics, thresholds, operator="xor") + + +def test_threshold_metrics_label_units_invalid_nan_policy_raises(): + import pandas as pd + + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"m1": {"min": 0.0}} + with pytest.raises(ValueError, match="nan_policy must be"): + threshold_metrics_label_units(metrics, thresholds, nan_policy="omit") + + +def test_threshold_metrics_label_units_missing_metric_raises(): + import pandas as pd + + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"does_not_exist": {"min": 0.0}} + with pytest.raises(ValueError, match="specified in thresholds are not present"): + threshold_metrics_label_units(metrics, thresholds) diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py new file mode 100644 index 0000000000..32a6a48d91 --- /dev/null +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -0,0 +1,124 @@ +import json 
+from pathlib import Path
+
+import numpy as np
+
+from spikeinterface.core.analyzer_extension_core import SortingAnalyzer
+
+from .curation_tools import is_threshold_disabled
+
+
+def threshold_metrics_label_units(
+    metrics: "pd.DataFrame",
+    thresholds: dict | str | Path,
+    pass_label: str = "good",
+    fail_label: str = "noise",
+    operator: str = "and",
+    nan_policy: str = "fail",
+):
+    """Label units based on metrics and thresholds.
+
+    Parameters
+    ----------
+    metrics : pd.DataFrame
+        A DataFrame containing unit metrics with unit IDs as index.
+    thresholds : dict | str | Path
+        A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units.
+        Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values
+        should contain at least "min" and/or "max" keys to specify threshold ranges.
+    pass_label : str, default: "good"
+        The label to assign to units that pass all thresholds.
+    fail_label : str, default: "noise"
+        The label to assign to units that fail any threshold.
+    operator : "and" | "or", default: "and"
+        The logical operator to combine multiple metric thresholds. "and" means a unit must pass all thresholds to be
+        labeled as pass_label, while "or" means a unit must pass at least one threshold to be labeled as pass_label.
+    nan_policy : "fail" | "pass" | "ignore", default: "fail"
+        Policy for handling NaN values in metrics. If "fail", a NaN value is treated as failing that metric's
+        threshold. If "pass", a NaN value is treated as passing that metric's threshold.
+        If "ignore", NaN values will be ignored. Note that the "ignore" behavior will depend on the operator used.
+        If "and", NaNs will be treated as passing, since the initial mask is all true;
+        if "or", NaNs will be treated as failing, since the initial mask is all false. 
+ + Returns + ------- + labels : pd.DataFrame + A DataFrame with unit IDs as index and a column 'label' containing the assigned labels (`fail_label` or `pass_label`) + """ + import pandas as pd + + if not isinstance(metrics, pd.DataFrame): + raise ValueError("Only pd.DataFrame is supported for metrics.") + + # Load thresholds from file if a path is provided + if isinstance(thresholds, (str, Path)): + with open(thresholds, "r") as f: + thresholds_dict = json.load(f) + elif isinstance(thresholds, dict): + thresholds_dict = thresholds + else: + raise ValueError("Thresholds must be a dictionary or a path to a JSON file containing the thresholds.") + + # Check that all specified metrics are present in the quality metrics DataFrame + missing_metrics = [] + for metric in thresholds_dict.keys(): + if metric not in metrics.columns: + missing_metrics.append(metric) + if len(missing_metrics) > 0: + raise ValueError( + f"Metric(s) {missing_metrics} specified in thresholds are not present in the quality metrics DataFrame. 
" + f"Available metrics are: {metrics.columns.tolist()}" + ) + + if operator not in ("and", "or"): + raise ValueError("operator must be 'and' or 'or'") + + if nan_policy not in ("fail", "pass", "ignore"): + raise ValueError("nan_policy must be 'fail', 'pass', or 'ignore'") + + labels = pd.DataFrame(index=metrics.index, dtype=str) + labels["label"] = fail_label + + # Key change: init depends on operator + pass_mask = np.ones(len(metrics), dtype=bool) if operator == "and" else np.zeros(len(metrics), dtype=bool) + any_threshold_applied = False + + for metric_name, threshold in thresholds_dict.items(): + min_value = threshold.get("min", None) + max_value = threshold.get("max", None) + + # If both disabled, ignore this metric + if is_threshold_disabled(min_value) and is_threshold_disabled(max_value): + continue + + values = metrics[metric_name].to_numpy() + is_nan = np.isnan(values) + + metric_ok = np.ones(len(values), dtype=bool) + if not is_threshold_disabled(min_value): + metric_ok &= values >= min_value + if not is_threshold_disabled(max_value): + metric_ok &= values <= max_value + + # Handle NaNs + nan_mask = slice(None) + if nan_policy == "fail": + metric_ok &= ~is_nan + elif nan_policy == "pass": + metric_ok |= is_nan + else: + # if nan_policy == "ignore", we only set values for non-nan entries + nan_mask = ~is_nan + + any_threshold_applied = True + + if operator == "and": + pass_mask[nan_mask] &= metric_ok[nan_mask] + else: + pass_mask[nan_mask] |= metric_ok[nan_mask] + + if not any_threshold_applied: + pass_mask[:] = True + + labels.loc[pass_mask, "label"] = pass_label + return labels