diff --git a/ecephys_spike_sorting/modules/noise_templates/__main__.py b/ecephys_spike_sorting/modules/noise_templates/__main__.py index e0cc61e4..c4509179 100644 --- a/ecephys_spike_sorting/modules/noise_templates/__main__.py +++ b/ecephys_spike_sorting/modules/noise_templates/__main__.py @@ -7,7 +7,7 @@ from .id_noise_templates import id_noise_templates, id_noise_templates_rf -from ...common.utils import write_cluster_group_tsv, load_kilosort_data +from ...common.utils import write_cluster_group_tsv, load_kilosort_data, read_cluster_group_tsv def classify_noise_templates(args): @@ -41,10 +41,27 @@ def classify_noise_templates(args): cluster_ids, is_noise = id_noise_templates(cluster_ids, templates, np.squeeze(channel_map), \ args['noise_waveform_params']) - mapping = {False: 'good', True: 'noise'} - labels = [mapping[value] for value in is_noise] + #mapping = {False: 'good', True: 'noise'} + #labels = [mapping[value] for value in is_noise] + ci_tmp, cluster_group = read_cluster_group_tsv(os.path.join(args['directories']['kilosort_output_directory'], \ + 'cluster_KSLabel.tsv')) - write_cluster_group_tsv(cluster_ids, + print('KS output ' + args['directories']['kilosort_output_directory']) + + labels = [ ] + for i, ci in enumerate(ci_tmp): + if is_noise[cluster_ids==ci]: + labels.append('noise') + else: + labels.append(cluster_group[i]) + + #write_cluster_group_tsv(cluster_ids, + # labels, + # args['directories']['kilosort_output_directory'], + # args['ephys_params']['cluster_group_file_name']) + print(f"{sum([x=='good' for x in labels])} remaining good units") + + write_cluster_group_tsv(ci_tmp, labels, args['directories']['kilosort_output_directory'], args['ephys_params']['cluster_group_file_name'])