Spikesorting v0 QualityMetrics._compute_metric removes units with num_spikes < min_spikes #1202

Open
sytseng opened this issue Dec 20, 2024 · 0 comments

sytseng commented Dec 20, 2024

Describe the bug
In the make function of spikesorting.v0.spikesorting_curation.QualityMetrics, calling _compute_metric with metric_name = "num_spikes" returns a metric dictionary that omits every unit with num_spikes < min_spikes, whereas for all the other metrics these units are preserved.

See lines 609-622 of spikesorting.v0.spikesorting_curation.QualityMetrics._compute_metric:

```python
for unit_id in waveform_extractor.sorting.get_unit_ids():
    # checks to avoid bug in spikeinterface 0.98.2
    if num_spikes[unit_id] < min_spikes:
        if is_nn_iso:
            metric[str(unit_id)] = (np.nan, np.nan)
        elif is_nn_overlap:
            metric[str(unit_id)] = np.nan

    else:
        metric[str(unit_id)] = metric_func(
            waveform_extractor,
            this_unit_id=int(unit_id),
            **metric_params,
        )
```

This triggers an error in QualityMetrics.make at lines 571-573:

```python
key["object_id"] = AnalysisNwbfile().add_units_metrics(
    key["analysis_file_name"], metrics=qm
)
```

because, when the unit columns are written to the NWB file, the num_spikes column has fewer rows than the other metrics.
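The loop quoted above writes a NaN placeholder for a low-count unit only when the metric is nn_isolation or nn_noise_overlap; for any other metric routed through it, the unit simply gets no entry. The standalone sketch below reproduces that asymmetry. It is not the spyglass code: `compute_metric_sketch` is a hypothetical name, `metric_func` is a stub, `num_spikes` is a plain dict keyed by unit id (as the quoted `num_spikes[unit_id]` lookup implies), and `min_spikes=10` and the spike counts are assumed toy values.

```python
import math


def compute_metric_sketch(metric_name, unit_ids, num_spikes, metric_func, min_spikes=10):
    """Simplified, hypothetical copy of the quoted loop; returns {str(unit_id): value}."""
    is_nn_iso = metric_name == "nn_isolation"
    is_nn_overlap = metric_name == "nn_noise_overlap"
    metric = {}
    for unit_id in unit_ids:
        if num_spikes[unit_id] < min_spikes:
            # Only the two nearest-neighbor metrics get a NaN placeholder.
            if is_nn_iso:
                metric[str(unit_id)] = (math.nan, math.nan)
            elif is_nn_overlap:
                metric[str(unit_id)] = math.nan
            # Any other metric falls through: no key is written for this unit.
        else:
            metric[str(unit_id)] = metric_func(unit_id)
    return metric


# Toy data (not from the real sort): unit 78 has 9 spikes, below the assumed min_spikes=10.
unit_ids = [77, 78, 79]
num_spikes = {77: 120, 78: 9, 79: 45}

nn_overlap = compute_metric_sketch("nn_noise_overlap", unit_ids, num_spikes, lambda u: 0.1)
spike_counts = compute_metric_sketch("num_spikes", unit_ids, num_spikes, lambda u: num_spikes[u])
print(len(nn_overlap), len(spike_counts))  # 3 2
```

The nn_noise_overlap dictionary keeps all three units, while the num_spikes dictionary silently loses unit "78", which is the same 94-vs-93 mismatch seen in the real sort.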

To Reproduce
Steps to reproduce the behavior:

1. I was populating QualityMetrics with the key:

   ```python
   key = {
       'curation_id': 0,
       'nwb_file_name': 'peanut20201107_.nwb',
       'sort_group_id': 27,
       'sort_interval_name': 'raw data valid times no premaze no home_v2',
       'preproc_params_name': 'default_min_seg',
       'team_name': 'JG_DG',
       'sorter': 'mountainsort4',
       'sorter_params_name': 'franklab_probe_ctx_30KHz_115rad_30clip',
       'artifact_removed_interval_list_name': 'peanut20201107_.nwb_raw data valid times no premaze no home_v2_27_default_min_seg_0.25_500_8_2_artifact_removed_valid_times',
       'waveform_params_name': 'default_whitened_20000spikes_20jobs',
       'metric_params_name': 'peak_offset_num_spikes_20000spikes_v2',
   }
   ```

2. In this sort there are 94 units, and unit 78 has num_spikes = 9. The metric dictionary returned by QualityMetrics._compute_metric excludes this unit, so the num_spikes metric covers only 93 units, whereas each of the other metrics covers all 94.
3. This triggered an error when adding the metrics to the NWB file:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 160
    158 metric_key.update({"metric_params_name": metric_params_name})
    159 MetricSelection.insert1(metric_key, skip_duplicates=True)
--> 160 QualityMetrics.populate([(MetricSelection & metric_key).proj()])
    162 # Perform automatic curation
    163 autocuration_key = metric_key.copy()

File ~/Src/spyglass/src/spyglass/utils/dj_mixin.py:608, in SpyglassMixin.populate(self, *restrictions, **kwargs)
    606 if use_transact:  # Pass single-process populate to super
    607     kwargs["processes"] = processes
--> 608     return super().populate(*restrictions, **kwargs)
    609 else:  # No transaction protection, use bare make
    610     for key in keys:

File ~/miniforge3/envs/spyglass_sort/lib/python3.9/site-packages/datajoint/autopopulate.py:254, in AutoPopulate.populate(self, keys, suppress_errors, return_exception_objects, reserve_jobs, order, limit, max_calls, display_progress, processes, make_kwargs, *restrictions)
    248 if processes == 1:
    249     for key in (
    250         tqdm(keys, desc=self.__class__.__name__)
    251         if display_progress
    252         else keys
    253     ):
--> 254         status = self._populate1(key, jobs, **populate_kwargs)
    255         if status is True:
    256             success_list.append(1)

File ~/miniforge3/envs/spyglass_sort/lib/python3.9/site-packages/datajoint/autopopulate.py:322, in AutoPopulate._populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs)
    320 self.__class__._allow_insert = True
    321 try:
--> 322     make(dict(key), **(make_kwargs or {}))
    323 except (KeyboardInterrupt, SystemExit, Exception) as error:
    324     try:

File ~/Src/spyglass/src/spyglass/spikesorting/v0/spikesorting_curation.py:571, in QualityMetrics.make(self, key)
    568 logger.info(f"Computed all metrics: {qm}")
    569 self._dump_to_json(qm, key["quality_metrics_path"])
--> 571 key["object_id"] = AnalysisNwbfile().add_units_metrics(
    572     key["analysis_file_name"], metrics=qm
    573 )
    574 AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"])
    575 AnalysisNwbfile().log(key, table=self.full_table_name)

File ~/Src/spyglass/src/spyglass/common/common_nwbfile.py:652, in AnalysisNwbfile.add_units_metrics(self, analysis_file_name, metrics)
    650     logger.info(f"Adding metric {metric_name} : {metric_dict}")
    651     metric_data = list(metric_dict.values())
--> 652     nwbf.add_unit_column(
    653         name=metric_name, description=metric_name, data=metric_data
    654     )
    656 io.write(nwbf)
    657 return nwbf.units.object_id

File ~/miniforge3/envs/spyglass_sort/lib/python3.9/site-packages/hdmf/utils.py:668, in docval.<locals>.dec.<locals>.func_call(*args, **kwargs)
    666 def func_call(*args, **kwargs):
    667     pargs = _check_args(args, kwargs)
--> 668     return func(args[0], **pargs)

File ~/miniforge3/envs/spyglass_sort/lib/python3.9/site-packages/pynwb/file.py:757, in NWBFile.add_unit_column(self, **kwargs)
    752 """
    753 Add a column to the unit table.
    754 See :py:meth:`~hdmf.common.table.DynamicTable.add_column` for more details
    755 """
    756 self.__check_units()
--> 757 self.units.add_column(**kwargs)

File ~/miniforge3/envs/spyglass_sort/lib/python3.9/site-packages/hdmf/utils.py:668, in docval.<locals>.dec.<locals>.func_call(*args, **kwargs)
    666 def func_call(*args, **kwargs):
    667     pargs = _check_args(args, kwargs)
--> 668     return func(args[0], **pargs)

File ~/miniforge3/envs/spyglass_sort/lib/python3.9/site-packages/hdmf/common/table.py:925, in DynamicTable.add_column(self, **kwargs)
    922     col = col_index
    924 if len(col) != len(self.id):
--> 925     raise ValueError("column must have the same number of rows as 'id'")
    926 self.__colids[name] = len(self.__df_cols)
    927 self.fields['colnames'] = tuple(list(self.colnames) + [name])

ValueError: column must have the same number of rows as 'id'
```

Here len(col) = 93 (for num_spikes), while len(self.id) = 94.
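As the traceback shows, add_units_metrics turns each metric dictionary into list(metric_dict.values()) and hands it to add_unit_column, and DynamicTable.add_column rejects any column whose length differs from len(self.id). A guard like the one below could run on qm before that call, so the failure names the offending metric and unit instead of only reporting a row-count mismatch. This is a hypothetical sketch, not part of spyglass: `find_missing_units` and `assert_units_aligned` are illustrative names, and the qm contents are toy values shaped like the dictionaries _compute_metric returns.

```python
def find_missing_units(qm):
    """Map each metric name to the unit ids it lacks, relative to the union of all metrics."""
    all_units = set()
    for metric_dict in qm.values():
        all_units.update(metric_dict)
    missing = {}
    for name, metric_dict in qm.items():
        gaps = all_units - set(metric_dict)
        if gaps:
            missing[name] = sorted(gaps)
    return missing


def assert_units_aligned(qm):
    """Raise before writing if any metric would produce a shorter unit column."""
    missing = find_missing_units(qm)
    if missing:
        raise ValueError(f"metrics missing units: {missing}")


# Toy qm mirroring the bug: num_spikes lacks unit "78".
qm = {
    "peak_offset": {"77": 1.0, "78": 2.0, "79": 0.0},
    "num_spikes": {"77": 120, "79": 45},
}
print(find_missing_units(qm))  # {'num_spikes': ['78']}
```

The check compares unit membership only; because add_units_metrics writes values positionally, the dictionaries also need to share the same key order for the columns to line up.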

Expected behavior
QualityMetrics.make should complete without error, with num_spikes reported for all 94 units, as the other metrics are.
Suggested change:
In _compute_metric, check whether metric_name == "num_spikes" and, if so, return the output of line 603 directly as the function's result:

```python
num_spikes = sq.compute_num_spikes(waveform_extractor)
```
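A minimal sketch of the suggested early return, using stand-ins rather than the real spyglass and SpikeInterface objects. `compute_metric_fixed` is a hypothetical name, `count_spikes` is a stub standing in for sq.compute_num_spikes(waveform_extractor) (assumed to return a dict keyed by unit id), and `min_spikes=10` and the counts are toy values. Casting keys with str() is an assumption chosen to match the str(unit_id) keys the quoted loop produces.

```python
import math


def compute_metric_fixed(metric_name, unit_ids, count_spikes, metric_func, min_spikes=10):
    """Hypothetical simplified _compute_metric with the suggested early return."""
    # Stand-in for: num_spikes = sq.compute_num_spikes(waveform_extractor)
    num_spikes = count_spikes()
    if metric_name == "num_spikes":
        # Suggested change: return counts for every unit, skipping the min_spikes filter.
        # Keys are cast to str to match the loop's str(unit_id) keys (an assumption).
        return {str(unit_id): num_spikes[unit_id] for unit_id in unit_ids}

    is_nn_iso = metric_name == "nn_isolation"
    is_nn_overlap = metric_name == "nn_noise_overlap"
    metric = {}
    for unit_id in unit_ids:
        if num_spikes[unit_id] < min_spikes:
            if is_nn_iso:
                metric[str(unit_id)] = (math.nan, math.nan)
            elif is_nn_overlap:
                metric[str(unit_id)] = math.nan
        else:
            metric[str(unit_id)] = metric_func(unit_id)
    return metric


unit_ids = [77, 78, 79]
counts = {77: 120, 78: 9, 79: 45}  # toy counts
result = compute_metric_fixed("num_spikes", unit_ids, lambda: counts, metric_func=None)
print(result)  # {'77': 120, '78': 9, '79': 45}
```

Iterating over unit_ids, rather than over the counts dictionary, keeps the returned entries in the same unit order as the other metrics; that matters because add_units_metrics writes list(metric_dict.values()) positionally into the units table.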

sytseng changed the title from "Spikesorting v0 QualityMetrics._compute_metric removes units with num_spikes < min_spikes and causes errors in AnalysisNwbfile.add_units_metrics" to "Spikesorting v0 QualityMetrics._compute_metric removes units with num_spikes < min_spikes" on Dec 20, 2024.