# script1.py (forked from Aytien/DLGN_Models)
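"""Sweep driver for the DLGN kernel VT-solver experiments.

Loads a pre-generated dataset from data/dataset3b, defines a Weights & Biases
grid sweep over the vt_fit1 solver index (LINEARSVC, PEGASOS, NPKSVC,
PEGASOSKERNEL), and trains one model per sweep run via train_model.
"""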
import numpy as np
import torch
import wandb
from argparse import Namespace
from sklearn.model_selection import train_test_split

from training_methods import train_model
from DLGN_kernel import *
from data_gen import *
from DLGN_enums import *
if __name__ == "__main__":
    # Set the random seed for reproducibility
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Set the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load the pre-generated dataset and its config
    DATASET = "dataset3b"
    DATA_DIR = 'data/' + DATASET
    data = {}
    data['train_data'] = torch.tensor(np.load(DATA_DIR + '/x_train.npy'))
    data['train_labels'] = torch.tensor(np.load(DATA_DIR + '/y_train.npy'))
    data['test_data'] = torch.tensor(np.load(DATA_DIR + '/x_test.npy'))
    data['test_labels'] = torch.tensor(np.load(DATA_DIR + '/y_test.npy'))
    data_config = np.load(DATA_DIR + '/config.npy', allow_pickle=True).item()

    # Weights & Biases project settings
    WANDB_NOTEBOOK_NAME = 'DLGN_VT'
    WANDB_PROJECT_NAME = 'DLGN_KERNEL_BTP'
    WANDB_ENTITY = 'cs20b004'
    wandb.login()
    sweep_config = {
        "name": "VT_solvers",
        "method": "grid",
        "parameters": {
            "num_hidden_nodes": {"values": [[10] * 4]},
            "beta": {"values": [50]},
            "alpha_init": {"values": [None]},
            "log_features": {"values": [False]},
            "BN": {"values": [True]},
            "prod": {"values": ['op']},
            "lr": {"values": [0.001]},
            "epochs": {"values": [1000]},
            "reg": {"values": [0.1]},
            "value_freq": {"values": [100]},
            "num_iter": {"values": [5e4]},
            "weight_decay": {"values": [0.01]},
            "use_wandb": {"values": [True]},
            "feat": {"values": ['cf']},
            "vt_fit1": {"values": [2, 3, 4, 5]},
        },
    }
    sweep_id = wandb.sweep(sweep_config, entity=WANDB_ENTITY, project=WANDB_PROJECT_NAME)
    # Fixed settings shared by every sweep run
    const_config = {
        "device": device,
        "model_type": ModelTypes.VT,
        "loss_fn_type": LossTypes.HINGE,
        "value_scale": 500,
        "optimizer_type": Optim.ADAM,
        "num_data": len(data['train_data']),
        "dim_in": data_config.dim_in,
        "mode": "pwc",
        "save_freq": 100,
    }
    def wb_sweep_sf():
        run = wandb.init()
        config = wandb.config

        # Select the VT solver (and matching loss / regularisation) from the swept index
        if config.vt_fit1 == 1:
            const_config["vt_fit"] = VtFit.LOGISTIC
            config.reg = 0.1
        elif config.vt_fit1 == 2:
            const_config["vt_fit"] = VtFit.LINEARSVC
            const_config["loss_fn_type"] = LossTypes.CE
            config.reg = 0.1
        elif config.vt_fit1 == 3:
            const_config["vt_fit"] = VtFit.PEGASOS
            const_config["loss_fn_type"] = LossTypes.HINGE
            config.reg = 0.1
        elif config.vt_fit1 == 4:
            const_config["vt_fit"] = VtFit.NPKSVC
            const_config["loss_fn_type"] = LossTypes.HINGE
            config.reg = 0.001
        elif config.vt_fit1 == 5:
            const_config["vt_fit"] = VtFit.PEGASOSKERNEL
            const_config["loss_fn_type"] = LossTypes.HINGE
            config.reg = 0.1

        # Merge swept and fixed settings into a single Namespace for train_model
        config = {**config, **const_config}
        config = Namespace(**config)
        filename_suffix = str(config.vt_fit1)
        run.name = filename_suffix
        model = train_model(data, config)
        run.finish()
        torch.cuda.empty_cache()
        return
    wandb.agent(sweep_id, wb_sweep_sf, entity=WANDB_ENTITY, project=WANDB_PROJECT_NAME)
    wandb.finish()