-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathmain.py
56 lines (43 loc) · 1.58 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import os
from datetime import datetime
from config import Config
from util.trainingprocess_stage1 import TrainingProcessStage1
from util.trainingprocess_stage3 import TrainingProcessStage3
from util.knn import KNN
def _timestamp():
    """Return the current local wall-clock time formatted as ``HH:MM:SS``."""
    return datetime.now().strftime('%H:%M:%S')


def _train_epochs(model, n_epochs):
    """Call ``model.train(epoch)`` for each ``epoch`` in ``range(n_epochs)``.

    The epoch index is printed before each call.
    """
    for epoch in range(n_epochs):
        print('Epoch:', epoch)
        model.train(epoch)


def main(knn_neighbors=30, knn_rna_samples=20000):
    """Run the pipeline: stage-1 training, KNN, stage-3 training, KNN.

    Each step prints progress messages and an ``HH:MM:SS`` completion time.

    Args:
        knn_neighbors: value passed as ``neighbors`` to both ``KNN`` calls
            (default 30, the value previously hard-coded twice).
        knn_rna_samples: value passed as ``knn_rna_samples`` to both ``KNN``
            calls (default 20000, the value previously hard-coded twice).
    """
    # hardware constraint for speed test
    torch.set_num_threads(1)
    # NOTE(review): OMP_NUM_THREADS is normally read when the OpenMP runtime
    # initialises. `torch` is imported at module load, before this line runs,
    # so this assignment likely does not affect torch's own thread pool;
    # torch.set_num_threads(1) above is what limits it. The variable is still
    # inherited by any child process started afterwards. To make it effective
    # for torch, set it before `import torch`.
    os.environ['OMP_NUM_THREADS'] = '1'

    # initialization
    config = Config()
    torch.manual_seed(config.seed)
    print('Start time: ', _timestamp())

    # stage1 training: train for all epochs, then write embeddings once
    print('Training start [Stage1]')
    model_stage1 = TrainingProcessStage1(config)
    _train_epochs(model_stage1, config.epochs_stage1)
    print('Write embeddings')
    model_stage1.write_embeddings()
    print('Stage 1 finished: ', _timestamp())

    # KNN (after stage 1)
    print('KNN')
    KNN(config, neighbors=knn_neighbors, knn_rna_samples=knn_rna_samples)
    print('KNN finished: ', _timestamp())

    # stage3 training: train for all epochs, then write embeddings once
    print('Training start [Stage3]')
    model_stage3 = TrainingProcessStage3(config)
    _train_epochs(model_stage3, config.epochs_stage3)
    print('Write embeddings [Stage3]')
    model_stage3.write_embeddings()
    print('Stage 3 finished: ', _timestamp())

    # KNN (after stage 3)
    # NOTE(review): called with exactly the same arguments as the first KNN
    # run; presumably it consumes whatever write_embeddings() wrote most
    # recently -- confirm in util.knn.
    print('KNN stage3')
    KNN(config, neighbors=knn_neighbors, knn_rna_samples=knn_rna_samples)
    print('KNN finished: ', _timestamp())


if __name__ == "__main__":
    main()