main.py
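"""Training entry point: configures UTempoNet and launches training,
splitting each batch across GPUs with BalancedDataParallel when more
than one GPU is available."""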
import torch
from utils import train
from utils.BalancedDataParallel import BalancedDataParallel
from model import model as utemponet
# Training hyperparameters.
batch_size = 16
lr = 1e-3
MAX_EPOCH = 150
NUM_WORKERS = 4
GPU0_BSZ = 8   # per-step batch share assigned to GPU 0 (see BalancedDataParallel below)
ACC_GRAD = 1   # gradient accumulation steps

# Model configuration.
IN_CHANNELS = 4
NUM_CLASSES = 8
NUM_LAYERS = 2
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
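# Instantiate UTempoNet with the configured input channels, class count, and layer count.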
model = utemponet.UTempoNet(
    IN_CHANNELS,
    NUM_CLASSES,
    NUM_LAYERS,
)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # BalancedDataParallel splits each batch along dim 0 (e.g. [30, ...] ->
    # [10, ...], [10, ...], [10, ...] on 3 GPUs), but caps GPU 0 at
    # GPU0_BSZ // ACC_GRAD samples, since GPU 0 also gathers the outputs.
    model = BalancedDataParallel(GPU0_BSZ // ACC_GRAD, model,
                                 device_ids=[0, 1], output_device=0)
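# Dataset locations and run name (placeholders; point these at your data).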
train_folder = r'**/train'
val_folder = r'**/val'
model_name = 'model_name'
model = model.to(device)
train_kwargs = {
    'net': model,
    'devices': device,
    'batchsize': batch_size,
    'lr': lr,
    'num_classes': NUM_CLASSES,
    'max_epoch': MAX_EPOCH,
    'train_folder': train_folder,
    'val_folder': val_folder,
    'num_workers': NUM_WORKERS,
    'data_aug': True,
    'model_name': model_name,
    'resume': False,
    'hyper_params': {'th_a': 0.99,
                     'th_b': 0.15},
}
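# Launch training with the configuration above.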
train.train_model(**train_kwargs)