-
Notifications
You must be signed in to change notification settings - Fork 1
/
grid_2_print.py
executable file
·101 lines (84 loc) · 2.65 KB
/
grid_2_print.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
"""Grid search 2.
For MNIST <-> Fashion MNIST
"""
import re
# Command-line flags shared by the training and evaluation commands.
#
# Every line inside the literal ends with a backslash. In a normal (non-raw)
# Python string, a backslash followed by a newline is dropped, so the value is
# a single line of space-separated flags. The {plb}, {ualb}, {clb}, {ns}, {ni}
# and {ui} fields are filled in later with str.format (see add()).
# NOTE(review): "--n_latent_shared 8" is a literal here; keep it in sync with
# the n_latent_shared constant defined further down.
# NOTE(review): the "~" in --default_scratch is inside double quotes, so a
# POSIX shell will not expand it; presumably the receiving script expands it
# itself -- confirm.
shared_pattern = """\
--default_scratch "~/workspace/scratch/latent_transfer/" \
--config_A "mnist_0_nlatent100" --config_B "fashion_mnist_0_nlatent100" \
--config_classifier_A "mnist_classifier_0" --config_classifier_B "fashion_mnist_classifier_0" \
--n_latent 100 --n_latent_shared 8 \
--layers "512,512,512,512" \
--cls_layers "," \
--prior_loss_beta {plb} \
--unsup_align_loss_beta {ualb} \
--cls_loss_beta {clb} \
--n_sup {ns} \
--sig_extra "grid_2" \
--n_iters {ni} \
--use_interpolated {ui} \
--post_mortem=false \
"""

# Training command template: the training entry point followed by the shared
# flags.
train_pattern = """\
run_ml_docker --no-it python3 ./train_joint2_mnist_family.py \
""" + shared_pattern

# Evaluation command templates. They differ only in the label sequence passed
# to --interpolate_labels, so they are generated from this list instead of
# repeating the whole command once per sequence.
_EVAL_INTERPOLATE_LABELS = [
    "8,8,8,8,8,8,8",
    "6,6,6,6,6,6,6",
    "7,7,7,7,7,7,7",
    "0,1,7,8,9,3,2",
]
_EVAL_COMMAND = "run_ml_docker --no-it python3 ./evaluate_joint2_mnist_family.py "

eval_pattern_list = [
    # The doubled spaces around shared_pattern are harmless: add() collapses
    # runs of spaces after formatting.
    _EVAL_COMMAND + " " + shared_pattern + " "
    + "--load_ckpt_iter -1 "
    + '--interpolate_labels "' + labels + '" '
    + "--nb_images_between_labels 4 "
    + "--random_seed 1145141925 "
    for labels in _EVAL_INTERPOLATE_LABELS
]
# NOTE(review): n_latent_shared is not referenced by the code shown here; the
# value actually passed to the scripts is the literal "--n_latent_shared 8"
# inside shared_pattern. Keep the two in sync.
n_latent_shared = 8

# Base values that main() sweeps over.
plb_base = 0.005   # fills {plb}  -> --prior_loss_beta
ualb_base = 1.0    # fills {ualb} -> --unsup_align_loss_beta
clb_base = 0.05    # fills {clb}  -> --cls_loss_beta

# Accumulators for rendered commands; main() prints train_cmds, then eval_cmds.
train_cmds, eval_cmds = [], []
def add(plb, ualb, clb, ns, ni, ui, include_train=False):
    """Render the commands for one hyperparameter setting and queue them.

    Each template is filled in with the given values via ``str.format``, and
    runs of spaces are then collapsed to a single space. One evaluation
    command per entry of ``eval_pattern_list`` is appended to ``eval_cmds``.

    Args:
        plb: value substituted for ``{plb}`` (``--prior_loss_beta``).
        ualb: value substituted for ``{ualb}`` (``--unsup_align_loss_beta``).
        clb: value substituted for ``{clb}`` (``--cls_loss_beta``).
        ns: value substituted for ``{ns}`` (``--n_sup``).
        ni: value substituted for ``{ni}`` (``--n_iters``).
        ui: value substituted for ``{ui}`` (``--use_interpolated``).
        include_train: when True, also append the rendered training command
            to ``train_cmds``. Defaults to False, which matches the previous
            behaviour: the training command was built but never queued,
            because its ``append`` was commented out.
    """
    params = dict(plb=plb, ualb=ualb, clb=clb, ns=ns, ni=ni, ui=ui)

    def render(pattern):
        # Fill in the placeholders, then squeeze the repeated spaces left by
        # concatenating the template pieces.
        return re.sub(r' +', ' ', pattern.format(**params))

    if include_train:
        train_cmds.append(render(train_pattern))
    eval_cmds.extend(render(pattern) for pattern in eval_pattern_list)
def main():
    """Build the hyperparameter grid, queue its commands and print them.

    ``add()`` is called once for every combination in the Cartesian product
    of the value lists below, in nested-loop order (the last list varies
    fastest), except for combinations rejected by the skip rule. Afterwards
    ``train_cmds`` followed by ``eval_cmds`` are printed, one command per line.
    """
    import itertools

    plb_values = [plb_base]
    ualb_values = [ualb_base]
    clb_values = [clb_base]
    # Wider n_sup sweep kept from an earlier version: [-1, 0, 10, 100, 1000, 10000]
    ns_values = [-1]
    ni_values = [20000, 50000]
    ui_values = ['none']

    grid = itertools.product(
        plb_values, ualb_values, clb_values, ns_values, ni_values, ui_values)
    for plb, ualb, clb, ns, ni, ui in grid:
        # Skip combinations the author marked as wasteful: a zero
        # classifier-loss weight with any n_sup other than -1, or a positive
        # classifier-loss weight with n_sup == 0.
        if (clb == 0 and ns != -1) or (clb > 0.0 and ns == 0):
            continue
        add(plb=plb, ualb=ualb, clb=clb, ns=ns, ni=ni, ui=ui)

    for cmd in train_cmds + eval_cmds:
        print(cmd)
# Print the command grid only when executed as a script, not when imported.
if __name__ == '__main__':
    main()