-
Notifications
You must be signed in to change notification settings - Fork 4
/
run_config.py
93 lines (74 loc) · 2.39 KB
/
run_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import json
import random
import numpy as np
import torch
from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast
from utils.print_args import print_args
def get_setting(args, ii):
    """Build the run identifier that is passed to ``exp.train`` / ``exp.test``.

    The identifier is made of the following parts, joined with single
    underscores:

    - the task, model id, model and dataset names;
    - the main hyper-parameters, each prefixed with a short tag (``ft``,
      ``sl``, ``ll``, ...);
    - ``args.des``;
    - the iteration index ``ii``.

    Args:
        args: Experiment namespace exposing every attribute read below.
        ii: Iteration index, appended as the last component.

    Returns:
        The underscore-joined setting string.
    """
    # (tag, attribute name) pairs, in the order they appear in the identifier.
    tagged_fields = (
        ('ft', 'features'),
        ('sl', 'seq_len'),
        ('ll', 'label_len'),
        ('pl', 'pred_len'),
        ('dm', 'd_model'),
        ('nh', 'n_heads'),
        ('el', 'e_layers'),
        ('dl', 'd_layers'),
        ('df', 'd_ff'),
        ('expand', 'expand'),
        ('dc', 'd_conv'),
        ('fc', 'factor'),
        ('eb', 'embed'),
        ('dt', 'distil'),
    )
    leading = [args.task_name, args.model_id, args.model, args.data]
    middle = [f'{tag}{getattr(args, name)}' for tag, name in tagged_fields]
    trailing = [args.des, ii]
    return '_'.join(f'{part}' for part in (*leading, *middle, *trailing))
def load_config(config_path):
    """Load a JSON experiment config into an ``argparse.Namespace``.

    Each top-level key of the JSON object becomes an attribute of the
    returned namespace, so the result can be used wherever parsed
    command-line ``args`` are expected.

    Args:
        config_path: Path to a JSON file whose top-level value is an object.

    Returns:
        argparse.Namespace with one attribute per top-level JSON key.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        TypeError: If the top-level JSON value is not an object.
    """
    # JSON text is UTF-8 by spec; do not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows), which would garble non-ASCII values.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    if not isinstance(config, dict):
        raise TypeError(
            f'config file {config_path!r} must contain a JSON object at the '
            f'top level, got {type(config).__name__}'
        )
    return argparse.Namespace(**config)
def main(argv=None):
    """Run training and/or testing for the experiment described by a JSON config.

    Usage::

        python run_config.py [CONFIG_PATH] [--seed SEED]

    ``CONFIG_PATH`` defaults to the placeholder path that used to be
    hard-coded here, so editing that default keeps working as before.

    Args:
        argv: Arguments to parse instead of the real command line
            (``None`` means ``sys.argv[1:]``).

    Raises:
        SystemExit: On an unsupported ``task_name``, or when multi-GPU is
            requested with an empty ``devices`` list (exit status 1).
    """
    parser = argparse.ArgumentParser(
        description='Run a forecasting experiment from a JSON config file.')
    parser.add_argument(
        'config_path', nargs='?', default='{your chosen config file path}',
        help='path to the JSON experiment config')
    parser.add_argument(
        '--seed', type=int, default=2021,
        help='seed for random, numpy and torch (default: 2021)')
    cli = parser.parse_args(argv)

    # Seed every RNG the script touches so repeated runs are reproducible.
    fix_seed = cli.seed
    random.seed(fix_seed)
    torch.manual_seed(fix_seed)
    np.random.seed(fix_seed)

    args = load_config(cli.config_path)

    # Respect the config's own use_gpu flag (default True when absent), but
    # only enable it when an accelerator actually exists.
    # torch.backends.mps does not exist on older torch releases.
    mps_backend = getattr(torch.backends, 'mps', None)
    accelerator_available = torch.cuda.is_available() or (
        mps_backend is not None and mps_backend.is_available())
    args.use_gpu = bool(getattr(args, 'use_gpu', True)) and accelerator_available
    print(f'Use GPU: {args.use_gpu}')

    if args.use_gpu and getattr(args, 'use_multi_gpu', False):
        # "0, 1" -> "0,1" -> [0, 1]; empty entries from stray commas are skipped.
        args.devices = args.devices.replace(' ', '')
        args.device_ids = [int(dev) for dev in args.devices.split(',') if dev]
        if not args.device_ids:
            raise SystemExit('use_multi_gpu is set but "devices" lists no device ids')
        args.gpu = args.device_ids[0]

    print('Args in experiment:')
    print_args(args)

    if args.task_name == 'long_term_forecast':
        Exp = Exp_Long_Term_Forecast
    else:
        # Fail loudly with a non-zero exit status instead of a silent exit(0).
        raise SystemExit(
            f"Unsupported task_name {args.task_name!r}; "
            "this script only runs 'long_term_forecast'.")

    if args.is_training:
        for ii in range(args.itr):
            exp = Exp(args)  # a new experiment object for every iteration
            setting = get_setting(args, ii)
            print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
            exp.train(setting)
            print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
            exp.test(setting)
            torch.cuda.empty_cache()
    else:
        setting = get_setting(args, 0)
        exp = Exp(args)
        print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
        # NOTE(review): test=1 presumably makes Exp.test load a saved
        # checkpoint for `setting` — confirm in Exp_Long_Term_Forecast.test.
        exp.test(setting, test=1)
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()