-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmk_ete_jobs.py
82 lines (72 loc) · 2.78 KB
/
mk_ete_jobs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python
# Generate SLURM batch scripts for end-to-end ("ete") cross-modal training
# runs (audio -> video / depth / seg) over a small grid of VQ settings.

# SLURM header written at the top of every generated job script:
# 3h wall time, 16GB RAM, 4 CPUs, 1 GPU, charged to the rrg-kyi account,
# then cd into the source checkout before the training command runs.
head = "".join(
    (
        "#!/bin/bash \n",
        "#SBATCH --time 03:00:00\n",
        "#SBATCH --mem 16GB\n",
        "#SBATCH --cpus-per-task 4\n",
        "#SBATCH --gres=gpu:1\n",
        "#SBATCH --account=rrg-kyi\n",
        "cd ~/src/eth \n",
    )
)
def mk_script(args, out_size, vq_emb_num, vq_emb_dim):
    """Write one SLURM end-to-end training job script per modality pair.

    For each (source, target) pair in audio->video, audio->depth and
    audio->seg, build a ``train.py --mode endtoend`` command line, prepend
    the module-level SLURM header ``head`` and write the result to
    ``{dataset}/job/{in_size}/ete_{transform}{exp}_{frm}-{to}.sh``.

    Args:
        args: parsed CLI namespace; ``dataset``, ``transform`` and
            ``in_size`` are read from it.
        out_size: output resolution, passed as ``--out_size`` and used
            twice (``{out}x{out}``) in the experiment tag.
        vq_emb_num: passed to train.py as ``--vq_emb_num``.
        vq_emb_dim: passed to train.py as ``--vq_emb_dim``.
    """
    import os

    dataset = args.dataset
    trans = args.transform
    in_size = args.in_size
    # Experiment tag; note it already begins with "_".
    exp = "_res{}x{}_num{}_dim{}".format(
        out_size, out_size, vq_emb_num, vq_emb_dim
    )
    files = {
        "audio": "--audio_f {}/data/audio.h5 ".format(dataset),
        "video": "--video_f {}/data/video.h5 ".format(dataset),
        "depth": "--depth_f {}/data/disp.h5 ".format(dataset),
        "seg": "--seg_f {}/data/seg.h5 ".format(dataset),
    }
    # Only the ytasmr dataset ships a mask file.
    mask_arg = (
        "--mask_f {}/data/mask.h5 ".format(dataset)
        if dataset == "ytasmr"
        else ""
    )

    job_dir = "{}/job/{}".format(dataset, in_size)
    # The job directory may not exist yet; open(..., "w") would otherwise
    # raise FileNotFoundError.
    os.makedirs(job_dir, exist_ok=True)

    for frm, to in [("audio", "video"), ("audio", "depth"), ("audio", "seg")]:
        parts = [
            head,
            "python -W ignore train.py ",
            "--mode endtoend ",
            "--dataset {} ".format(dataset),
            files[frm],
            files[to],
            mask_arg,
            "--frm {} --to {} ".format(frm, to),
            # Pretrained VQ checkpoints are currently not passed:
            # "--vq_frm {}/log/{}/vq{}/{}/ckpt.pyth ".format(dataset, in_size, exp, frm),
            # "--vq_to {}/log/{}/vq{}/{}/ckpt.pyth ".format(dataset, in_size, exp, to),
            "--vq_emb_num {} ".format(vq_emb_num),
            "--vq_emb_dim {} ".format(vq_emb_dim),
            "--batch_size 64 ",
            "--ifr 250 ",
            "--in_size {} --out_size {} ".format(in_size, out_size),
            "--transform {} ".format(trans),
            # Disabled options: "--manifold ...", "--top_k 16", "--restore ".
            "--cpus 4 ",
            # NOTE(review): "ete_{}_{}" plus exp's leading "_" yields a double
            # underscore here, unlike the job filename below ("ete_{}{}").
            # Kept as-is so existing log directories still match.
            "--log_dir {}/log/{}/ete_{}_{}/{}-{} \n".format(
                dataset, in_size, trans, exp, frm, to
            ),
        ]
        job_path = "{}/ete_{}{}_{}-{}.sh".format(job_dir, trans, exp, frm, to)
        with open(job_path, "w", encoding="utf-8") as f:
            f.write("".join(parts))
def main(args):
    """Generate job scripts for every point of the hyperparameter grid.

    The grid is out_size in (8, 16, 32) crossed with a single
    vq_emb_num (64) and vq_emb_dim (64). Each combination is printed
    alongside the dataset name before its scripts are written.
    """
    grid = [
        (size, emb_num, emb_dim)
        for size in (8, 16, 32)
        for emb_num in (64,)
        for emb_dim in (64,)
    ]
    for size, emb_num, emb_dim in grid:
        print(args.dataset, size, emb_num, emb_dim)
        mk_script(args, size, emb_num, emb_dim)
def build_parser():
    """Return the command-line parser for this script.

    ``--dataset`` and ``--in_size`` are required; ``--transform`` is
    optional and defaults to ``"res"``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--in_size", type=int, required=True)
    # Previously ``required="res"``: a truthy string, which made the flag
    # mandatory instead of supplying "res" as the default.
    parser.add_argument("--transform", type=str, default="res")
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())