# -*- coding: utf-8 -*-
# @Time : 2023/2/6 13:43
# @Author : Tory Deng
# @File : run_benchmark.py
# @Software: PyCharm
import os
from typing import List, Dict, Union, Literal, Callable, Optional
from loguru import logger
from ._metrics import compute_clustering_metrics
from ._recorder import create_records, write_records, store_metrics_to_records
from ._utils import rm_cache, set_logger
from .cluster import generally_cluster_obs
from .dataset import load_data
from .selection import generally_select_features
@logger.catch
def run_bench(
data_cfg: Dict[str, Dict[str, Union[os.PathLike, str, List, Dict[str, Union[List, str]]]]],
fs_cfg: Dict[Union[str, Callable], Union[List[int], List[Literal['auto']]]],
cl_cfg: Dict[Union[str, Callable], int],
metrics: List[Literal['ARI', 'NMI']],
modality: Literal['scrna', 'spatial'],
fs_kwarg: Optional[Dict] = None,
cl_kwarg: Optional[Dict] = None,
preprocess: bool = True,
clean_cache: bool = False,
verbosity: Literal[0, 1, 2] = 2,
log_path: Optional[Union[os.PathLike, str]] = None,
random_state: int = 0
):
"""
The main function of the benchmark. Before running it, you first need to prepare three configurations as
`dict` objects: `data_cfg`, `fs_cfg`, and `cl_cfg`. See the parameter descriptions below for details.
You can also evaluate custom feature selection/cell clustering/domain detection functions. Results generated
in each step (preprocessed data, selected features, and cluster labels) are cached in the `./cache` folder.
At the end of the benchmark run, an `XLSX` file named in the format `time modality.xlsx` is created; it
stores all evaluation results and can easily be read by `pandas`.
Parameters
----------
data_cfg
Configurations of datasets. It should be a dict in the format `{'data_name': {'property_name': data_property}}`.
Supported property names are listed below; a hypothetical example follows the list.
- 'adata_path': path to the `h5ad` file. The raw counts should be stored in `adata.X`. The benchmark will
perform quality control and normalization automatically when `preprocess=True`. If you want to use datasets
you have preprocessed yourself, set `preprocess=False`; in this case the normalized counts and the raw data
should be stored in `adata.X` and `adata.raw`, respectively.
- 'image_path': path to the image file, optional. It will be ignored when `modality='scrna'`.
- 'annot_key': a key in `adata.obs` that stores the annotations (cell types/domains).
- 'to_replace': replace some values in `adata.obs['annot_key']`, optional. It should be a dict in the format
`{'value': ['to_replace_1', 'to_replace_2']}`, where 'to_replace_x' are the values to be replaced and
'value' is the value that replaces them.
- 'to_remove': a list of values, optional. Cells/spots whose value in `adata.obs['annot_key']` appears in
this list will be removed.
- 'batch': a key in `adata.obs` that stores the batch labels, optional. Only valid when `modality='scrna'`.
If it is specified, the benchmark will use the algorithm in Seurat to combine the features selected in
each batch.
- 'shape': set to 'hexagon' for Visium data and 'square' for ST data when `modality='spatial'`, optional.
Default is 'hexagon'. Currently this parameter is only used in the cluster refinement of `spaGCN`.
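A hypothetical example for an scRNA-seq dataset (the dataset name, path, keys, and labels below are
placeholders, not values shipped with this benchmark)::

    data_cfg = {
        'my_pbmc': {
            'adata_path': 'data/my_pbmc.h5ad',   # raw counts in adata.X when preprocess=True
            'annot_key': 'cell_type',            # column in adata.obs holding the annotations
            'to_replace': {'NK cell': ['NK', 'natural killer cell']},
            'to_remove': ['unassigned'],
            'batch': 'donor',                    # optional; scRNA-seq only
        }
    }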
fs_cfg
Configurations of feature selection methods. It should be a dict in the format
`{fs_method: list_of_numbers_of_selected_genes}`. More specifically,
- fs_method: either a string that refers to a predefined function in this benchmark, or a custom
function. The benchmark will call the function as `custom_fs_function(adata, n_selected_genes, **kwargs)`,
and the return value must be an ndarray containing the features selected by the function. You can write a
wrapper function to work around incompatible parameters/return values (see the sketch after this list).
- list_of_numbers_of_selected_genes: a list of the numbers of genes to be selected. If the function
determines the number of selected genes internally (e.g. GeneClust, or no feature selection at all), write
the list as `['auto']`.
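A minimal sketch of a custom wrapper (the wrapper name and the variance-based selection are
illustrative only, not a method provided by this benchmark)::

    import numpy as np

    def my_fs_wrapper(adata, n_selected_genes, **kwargs):
        # toy selector: keep the genes with the highest variance in adata.X
        X = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
        top = np.argsort(X.var(axis=0))[::-1][:n_selected_genes]
        return adata.var_names[top].to_numpy()

    fs_cfg = {my_fs_wrapper: [500, 1000]}  # select 500 genes, then 1000 genes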
cl_cfg
Configurations of downstream cell clustering/domain detection methods. It should be a dict in the format
`{cl_method: n_runs}`. More specifically,
- cl_method: either a string that refers to a predefined function in this benchmark, or a custom
function. The benchmark will call the function as `custom_cl_function(fs_adata, img, **kwargs)`,
and the return value must be an ndarray containing the cluster labels generated by the function. The
parameter 'img' is always `None` when `modality='scrna'`. You can write a wrapper function to work around
incompatible parameters/return values (see the sketch after this list).
- n_runs: the number of times the function will be run with different random states.
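A minimal sketch of a custom wrapper (the wrapper name and the use of scikit-learn's KMeans are
illustrative only, not a method provided by this benchmark)::

    from sklearn.cluster import KMeans

    def my_cl_wrapper(fs_adata, img, **kwargs):
        # toy clustering: k-means on the selected-gene expression matrix; img is ignored
        n_clusters = kwargs.get('n_clusters', 10)
        return KMeans(n_clusters=n_clusters, n_init=10).fit_predict(fs_adata.X)

    cl_cfg = {my_cl_wrapper: 3}   # run the wrapper three times
    cl_kwarg = {'n_clusters': 8}  # forwarded to the wrapper via **kwargs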
metrics
Evaluation metrics. It should be a list containing metric names. Currently only 'ARI' and 'NMI' are supported.
modality
Which type of data the benchmark will run on. Currently only 'scrna' and 'spatial' are supported.
fs_kwarg
Additional keyword arguments which will be passed to the custom feature selection function.
cl_kwarg
Additional keyword arguments which will be passed to the custom cell clustering/domain detection function.
preprocess
Whether to preprocess the dataset (including quality control and log-normalization).
clean_cache
Whether to clean all cached information, including the preprocessed data, selected genes, and
generated cluster labels stored in the `./cache` folder.
verbosity
0: only print warnings and errors
1: also print info
2: also print debug messages
log_path
Path to the log file.
random_state
Random seed. Change it to use different initial states for feature selection and clustering.
Returns
-------
None
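Examples
--------
A minimal sketch of a full run, reusing the hypothetical `my_fs_wrapper` and `my_cl_wrapper`
defined above; the dataset name and path are placeholders::

    data_cfg = {'my_pbmc': {'adata_path': 'data/my_pbmc.h5ad', 'annot_key': 'cell_type'}}
    fs_cfg = {my_fs_wrapper: [1000]}
    cl_cfg = {my_cl_wrapper: 2}
    run_bench(data_cfg, fs_cfg, cl_cfg, metrics=['ARI', 'NMI'], modality='scrna',
              cl_kwarg={'n_clusters': 8})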
"""
set_logger(verbosity, log_path)
if cl_kwarg is None:
cl_kwarg = dict()
if fs_kwarg is None:
fs_kwarg = dict()
if clean_cache:
rm_cache("./cache")
records = create_records(data_cfg, fs_cfg, cl_cfg, metrics)
for data_name, data_props in data_cfg.items():
adata, img = load_data(data_name, data_props, modality, preprocess) # img is actually None for scRNA-seq data
for fs_method, n_genes_list in fs_cfg.items():
for n_selected_genes in n_genes_list:
selected_genes = generally_select_features(adata, img, fs_method, n_selected_genes, modality, random_state, **fs_kwarg)
fs_adata = adata[:, selected_genes].copy()
for cl_method, n_runs in cl_cfg.items():
for run in range(n_runs):
cluster_labels = generally_cluster_obs(fs_adata, img, fs_method, n_selected_genes, cl_method, random_state, run, modality, **cl_kwarg)
for metric in metrics:
value = compute_clustering_metrics(fs_adata.obs[fs_adata.uns['annot_key']], cluster_labels, metric)
store_metrics_to_records(records, metric, value, data_name, cl_method, run, fs_method, n_selected_genes)
write_records(records, modality)