-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy pathtest.py
102 lines (73 loc) · 2.91 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import sys
import time
import copy
import shutil
import random
import pdb
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import config
import myutils
from torch.utils.data import DataLoader
##### Parse CmdLine Arguments #####
# Restrict this process to GPU 0; this must be set before CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
args, unparsed = config.get_args()
cwd = os.getcwd()
device = torch.device('cuda' if args.cuda else 'cpu')

# Seed torch (and CUDA when enabled) so any stochastic ops are repeatable.
torch.manual_seed(args.random_seed)
if args.cuda:
    torch.cuda.manual_seed(args.random_seed)

##### Build the test DataLoader #####
# Each dataset module is imported lazily so only the selected one has to be
# importable. Note that the loaders take different positional arguments.
if args.dataset == "vimeo90K_septuplet":
    from dataset.vimeo90k_septuplet import get_loader
    test_loader = get_loader('test', args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == "ucf101":
    from dataset.ucf101_test import get_loader
    test_loader = get_loader(args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == "davis":
    from dataset.Davis_test import get_loader
    test_loader = get_loader(args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == "snu":
    from dataset.snufilm import get_loader
    test_loader = get_loader(args.test_mode, args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == "gopro":
    from dataset.GoPro import get_loader
    test_loader = get_loader(args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers, test_mode=True, interFrames=args.n_outputs)
else:
    # Name the offending value so a typo on the command line is easy to spot.
    raise NotImplementedError(
        "Unsupported dataset %r; expected one of: "
        "vimeo90K_septuplet, ucf101, davis, snu, gopro" % (args.dataset,)
    )

##### Build the model #####
from model.Unet_3D_3D_interpolate import UNet_3D_3D

print("Building model: %s" % args.model.lower())
model = UNet_3D_3D(args.model.lower(), n_inputs=args.nbr_frame, n_outputs=args.n_outputs, joinType=args.joinType)
# Always wrap in DataParallel. Presumably this keeps state_dict keys (with the
# "module." prefix) matching the training checkpoints; confirm.
model = torch.nn.DataParallel(model).to(device)
print("#params", sum(p.numel() for p in model.parameters()))
def test(args):
    """Evaluate the global ``model`` on ``test_loader`` and print metrics.

    Prints the average PSNR/SSIM held by the meters that
    ``myutils.eval_metrics`` is fed each batch. It also prints the mean
    per-batch forward-pass time in seconds.

    Args:
        args: Parsed command-line arguments. Only ``args.loss`` is read here
            (it is forwarded to ``myutils.init_meters``).
    """
    time_taken = []
    # init_meters yields (losses, psnrs, ssims); the loss meter is unused here.
    _losses, psnrs, ssims = myutils.init_meters(args.loss)
    model.eval()

    def _sync():
        # CUDA kernels run asynchronously. Without a barrier the wall-clock
        # timer would measure only kernel launch, not execution.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    with torch.no_grad():
        for images, gt_image in tqdm(test_loader):
            # Use the configured device instead of hard-coding .cuda() so the
            # script also runs when args.cuda is False.
            images = [img_.to(device) for img_ in images]
            gt = [g_.to(device) for g_ in gt_image]

            # Time only the forward pass, with device work fully completed.
            _sync()
            start_time = time.perf_counter()
            out = model(images)
            _sync()
            time_taken.append(time.perf_counter() - start_time)

            # Assumes the model returns a sequence of tensors (one per output
            # frame), concatenated to line up with the concatenated ground truth.
            # TODO: confirm.
            out = torch.cat(out)
            gt = torch.cat(gt)
            myutils.eval_metrics(out, gt, psnrs, ssims)

    print("PSNR: %f, SSIM: %f\n" % (psnrs.avg, ssims.avg))
    # Guard against an empty loader, which would otherwise divide by zero.
    mean_time = sum(time_taken) / len(time_taken) if time_taken else float('nan')
    print("Time , ", mean_time)
    return


##### Entry Point #####
def main(args):
    """Load the checkpoint at ``args.load_from`` into ``model`` and run ``test``.

    Args:
        args: Parsed command-line arguments. ``args.load_from`` must be set.

    Raises:
        ValueError: If ``args.load_from`` is None.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if args.load_from is None:
        raise ValueError("args.load_from must be set to a checkpoint path")
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # map_location lets a GPU-saved checkpoint load on the selected device (e.g. CPU).
    checkpoint = torch.load(args.load_from, map_location=device)
    model.load_state_dict(checkpoint["state_dict"], strict=True)
    test(args)


if __name__ == "__main__":
    main(args)