-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_sta.py
109 lines (56 loc) · 2.4 KB
/
run_sta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import STAGATE_pyG as STAGATE
import scanpy as sc
import anndata as ad
import pandas as pd
import numpy as np
import time
from Taichi.model import Taichi
def stagate(adata_list, *, rad_cutoff=150, n_top_genes=3000, target_sum=1e4):
    """Jointly embed several spatial slices with STAGATE.

    For every slice this builds a radius-based spatial graph, then concatenates
    the slices (recording each input's list position in ``obs['slice_id']``),
    flags highly variable genes on the raw counts, normalises and
    log-transforms, and trains STAGATE on the merged object.

    Parameters
    ----------
    adata_list : list of AnnData
        Slices to embed. Side effect: ``Cal_Spatial_Net`` writes
        ``.uns['Spatial_Net']`` into every element.
    rad_cutoff : float, keyword-only
        Radius passed to ``STAGATE.utils.Cal_Spatial_Net`` (default 150, the
        previously hard-coded value).
    n_top_genes : int, keyword-only
        Number of highly variable genes flagged with the seurat_v3 flavor.
    target_sum : float, keyword-only
        Per-cell total used by ``sc.pp.normalize_total``.

    Returns
    -------
    AnnData
        The object returned by ``train_STAGATE``. Its obs_names are exactly
        those produced by ``ad.concat`` (duplicates included).
    """
    nets = []
    for sub in adata_list:
        STAGATE.utils.Cal_Spatial_Net(sub, rad_cutoff=rad_cutoff)
        nets.append(sub.uns['Spatial_Net'])

    adata = ad.concat(adata_list, label='slice_id')
    original_names = adata.obs_names.copy()

    # Spatial_Net stores edges as obs_names (Cell1/Cell2). When slices share
    # names (e.g. a control that is a subset of the condition slice), an edge
    # name cannot identify which copy of the cell it belongs to. So every name
    # is temporarily suffixed with its slice position, and the edge lists are
    # renamed to match.
    # NOTE(review): relies on STAGATE resolving edges through obs_names;
    # confirm against the installed STAGATE_pyG version.
    rename = not original_names.is_unique
    if rename:
        adata.obs_names = [
            f'{name}-{idx}'
            for idx, sub in enumerate(adata_list)
            for name in sub.obs_names
        ]
        nets = [
            net.assign(
                Cell1=net['Cell1'].astype(str) + f'-{idx}',
                Cell2=net['Cell2'].astype(str) + f'-{idx}',
            )
            for idx, net in enumerate(nets)
        ]
    adata.uns['Spatial_Net'] = pd.concat(nets)

    # The seurat_v3 HVG flavor is documented to expect counts, so it runs
    # before normalisation.
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=n_top_genes)
    sc.pp.normalize_total(adata, target_sum=target_sum)
    sc.pp.log1p(adata)
    adata = STAGATE.Train_STAGATE.train_STAGATE(adata)

    if rename:
        # Hand callers the same obs_names ad.concat produced.
        adata.obs_names = original_names
    return adata
# MERFISH: pair the single control section with each condition slice in turn,
# embed each pair with STAGATE, run Taichi, and keep the condition cells.
ctrl_adata = sc.read_h5ad('merfish_control.h5ad')
full_adata = sc.read_h5ad('merfish_condition.h5ad')
res_list = []
for i in full_adata.obs['slice_id'].unique():
    print(f'running {i}')
    cond_adata = full_adata[full_adata.obs['slice_id'] == i].copy()
    # ctrl_adata is reused (and relabelled) on every iteration.
    ctrl_adata.obs['condition'] = 0
    cond_adata.obs['condition'] = 1
    run_adata = stagate([ctrl_adata, cond_adata])
    run_adata.obs['condition'] = run_adata.obs['condition'].astype('category')
    # NOTE(review): the timer starts after STAGATE here, so only Taichi is
    # timed, whereas the STARmap loop also times STAGATE. Confirm which is
    # intended.
    start_time = time.time()
    model = Taichi(run_adata, ct_obs='cell_type', slice_id='slice_id')
    model.label_refinement(use_rep='STAGATE')
    res = model.graph_diffusion()
    # Keep only condition cells. .copy() materialises the subset so res_list
    # does not hold views that keep every full diffusion result alive.
    res = res[res.obs['condition'] == 1].copy()
    end_time = time.time()
    print(f'Total Running Time {end_time - start_time}')
    res_list.append(res)
# NOTE(review): label='slice_id' overwrites obs['slice_id'] with each result's
# list position ("0", "1", ...) rather than the original slice id. Confirm
# downstream evaluation expects that.
ad.concat(res_list, label='slice_id').write_h5ad('taichi_stagate_merfish.h5ad')
# STARmap: within each slice, use the cells of Regions 2 and 3 as the control,
# embed control + full slice with STAGATE, run Taichi, keep the condition cells.
adata = sc.read_h5ad('gt_starmap.h5ad')
res_list = []
for i in adata.obs['slice_id'].unique():
    cond_adata = adata[adata.obs['slice_id'] == i].copy()
    # The control is a subset of the same slice, so it shares obs_names with
    # cond_adata.
    ctrl_adata = cond_adata[cond_adata.obs['Region'].isin([2, 3])].copy()
    ctrl_adata.obs['condition'] = 0
    cond_adata.obs['condition'] = 1
    # Timing here covers STAGATE as well as Taichi.
    start_time = time.time()
    run_adata = stagate([ctrl_adata, cond_adata])
    run_adata.obs['condition'] = run_adata.obs['condition'].astype('category')
    model = Taichi(run_adata, ct_obs='ct', slice_id='slice_id')
    model.label_refinement(use_rep='STAGATE')
    res = model.graph_diffusion()
    res = res[res.obs['condition'] == 1].copy()
    end_time = time.time()
    res_list.append(res)
    print(f'Total Running Time {end_time - start_time}')
ad.concat(res_list, label='slice_id').write_h5ad('taichi_stagate_starmap.h5ad')