compare-diffusion-vs-interpolant.py
# This script performs the following steps:
# 1. loads a diffusion model from wandb
# 2. evaluates it on the test set
# 3. converts the model to an interpolant
# 4. evaluates the interpolant on the same test set
# 5. prints the results, sorted by loss
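#
# Example invocation (a sketch only; the config path, entity, project and run id
# below are placeholders, not real values from this repo):
#   python compare-diffusion-vs-interpolant.py \
#     --config configs/my-dataset.json \
#     --wandb-run my-entity/my-project/abcd1234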
import argparse, os, sys, json
# add the root folder of the project to the path
sys.path.append(os.path.abspath(os.path.dirname(__file__) + '/../'))
from Utils.utils import setupGPU, load_config, merge_configs, JSONHelper
setupGPU() # call it on startup to prevent OOM errors on my machine
from Utils import dataset_from_config
from NN import model_from_config
from Utils.WandBUtils import CWBRun
from NN.restorators import replace_diffusion_restorator_by_interpolant
def modelProviderFromArgs(args, config):
  run = CWBRun(args.wandb_run, None, 'tmp')
  modelConfigs = merge_configs(run.config, config) # override config with run config
  restorator = modelConfigs['model']['restorator']
  assert restorator['name'] == 'diffusion', 'Must be a diffusion model'
  weights = run.models()[-2].pathTo()
  modelNet = model_from_config(modelConfigs['model']) # all models in the run should have the same config
  modelNet.load_weights(weights)
  yield modelNet, run.name, run.fullId

  # convert to interpolant
  modelConfigs['model']['restorator'] = replace_diffusion_restorator_by_interpolant(restorator)
  modelNet = model_from_config(modelConfigs['model'])
  modelNet.load_weights(weights)
  yield modelNet, run.name + ' (interpolant)', run.fullId
  return
def datasetProviderFromArgs(args, configs):
  def datasetProvider():
    dataset = dataset_from_config(configs['dataset'])
    return dataset.make_dataset(configs['dataset']['test'], split='test')
  return datasetProvider
def evaluateModel(model, datasetProvider):
  return model.evaluate(datasetProvider(), return_dict=True, verbose=1)
if __name__ == '__main__':
  parser = argparse.ArgumentParser(
    description='Compare a diffusion model from a wandb run against its interpolant counterpart'
  )
  parser.add_argument(
    '--config', type=str, required=True,
    help='Path to a config file; can be specified multiple times (configs are merged in order of appearance)',
    default=[], action='append',
  )
  parser.add_argument('--wandb-run', type=str, help='Wandb run full id (entity/project/run_id)', required=True)

  args = parser.parse_args()
  configs = load_config(args.config, folder=os.getcwd())
  assert 'dataset' in configs, 'No dataset config found'
  datasetProvider = datasetProviderFromArgs(args, configs)

  results = []
  for model, modelName, runId in modelProviderFromArgs(args, configs):
    print(f'Starting evaluation of model "{modelName}" ({runId})')
    losses = evaluateModel(model, datasetProvider)
    results.append(dict(**losses, model=modelName, runId=runId))
    print()
    continue

  # print results sorted by loss
  results = sorted(results, key=lambda x: x['loss'])
  for r in results:
    print(f'{r["model"]} ({r["runId"]}) | loss: {r["loss"]}')
    continue
  print()
  pass