-
Notifications
You must be signed in to change notification settings - Fork 0
/
MainCondition.py
39 lines (35 loc) · 1.05 KB
/
MainCondition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from DiffusionFreeGuidence.TrainCondition import train, eval
def main(model_config=None):
    """Run conditional-diffusion training or evaluation from a config dict.

    Starts from the default configuration defined below and overlays any
    caller-supplied ``model_config`` entries on top of it. It then dispatches
    on the resulting ``"state"`` value: ``"train"`` calls ``train``, and any
    other value calls ``eval``. Both are the functions imported from
    ``DiffusionFreeGuidence.TrainCondition``; that ``eval`` shadows the
    builtin in this module.

    Args:
        model_config: Optional mapping of overrides. Keys present here
            replace the corresponding defaults. Keys it omits keep their
            default values, so a partial mapping such as
            ``{"state": "eval"}`` is sufficient. ``None`` (the default)
            runs with the defaults unchanged. The caller's mapping is not
            mutated; a merged copy is what ``train``/``eval`` receive.
    """
    defaults = {
        "state": "train",  # "train" -> train(); any other value -> eval()
        "epoch": 70,
        "batch_size": 80,
        "T": 500,  # NOTE(review): presumably the number of diffusion steps
        "channel": 128,
        "channel_mult": [1, 2, 2, 2],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.5,
        "beta_1": 1e-4,
        "beta_T": 0.028,
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:0",
        "w": 1.8,  # NOTE(review): presumably the guidance weight; confirm in TrainCondition
        "save_dir": "./CheckpointsCondition/",
        "training_load_weight": None,
        # NOTE(review): presumably a file name resolved against save_dir; confirm.
        "test_load_weight": "ckpt_63_.pt",
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyGuidenceImgs.png",
        "sampledImgName": "SampledGuidenceImgs.png",
        "nrow": 8,
    }
    # Overlay rather than replace: a partial override keeps the remaining
    # defaults instead of failing later with a KeyError downstream.
    # Copying also keeps the module's defaults and the caller's dict untouched.
    modelConfig = dict(defaults)
    if model_config is not None:
        modelConfig.update(model_config)

    if modelConfig["state"] == "train":
        train(modelConfig)
    else:
        eval(modelConfig)
# Script entry point: run with the built-in default configuration when this
# file is executed directly (not when it is imported).
if __name__ == '__main__':
    main()