-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_single.py
128 lines (103 loc) · 3.85 KB
/
test_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#encoding:utf-8
import os, random, copy
import numpy as np
import torch
import torch.nn as nn
import argparse
import yaml
import shutil
import tensorboard_logger as tb_logger
import logging
import click
import utils
import data
import engine
from vocab import deserialize_vocab
def parser_options():
    """Parse command-line arguments and load the experiment options.

    Reads the YAML file given by ``--path_opt`` and overrides its
    ``optim.resume`` entry with the ``--resume`` checkpoint path.

    Returns:
        dict: the options dictionary loaded from the YAML file, with
        ``options['optim']['resume']`` set to the ``--resume`` argument.
    """
    # Hyper Parameters setting
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_opt', default='option/ERA_VCSR.yaml', type=str,
                        help='path to a yaml options file')
    # BUG FIX: help text was a copy-paste of --path_opt's; this flag is a
    # checkpoint path, not a yaml file.
    parser.add_argument('--resume', default='checkpoint/era_aba_mv_fliter_new13_i102a1_de_agg_div_ra1/1/AMFMN_best.pth.tar', type=str,
                        help='path to a model checkpoint to resume from')
    opt = parser.parse_args()

    # Load model options. FullLoader is acceptable here because the config
    # file is trusted local input (never use it on untrusted data).
    with open(opt.path_opt, 'r') as handle:
        options = yaml.load(handle, Loader=yaml.FullLoader)
    options['optim']['resume'] = opt.resume

    return options
def main(options):
    """Build the model described by *options*, load its checkpoint, and
    evaluate it on the test split.

    Args:
        options (dict): parsed configuration (see ``parser_options``).

    Returns:
        The similarity matrix produced by ``engine.validate_test``.
    """
    # Choose the model implementation; only VCSR is supported by this script.
    if options['model']['name'] != "VCSR":
        raise NotImplementedError
    from layers import VCSR as models

    # Build the vocabulary word list, ordered by token index.
    vocab = deserialize_vocab(options['dataset']['vocab_path'])
    vocab_word = [word for word, _ in
                  sorted(vocab.word2idx.items(), key=lambda item: item[1])]

    # Create the test data loader and the model.
    test_loader = data.get_test_loader(vocab, options)
    model = models.factory(options['model'],
                           vocab_word,
                           cuda=True,
                           data_parallel=False)
    print('Model has {} parameters'.format(utils.params_count(model)))

    # Optionally resume from a checkpoint (weights are mapped onto the CPU).
    resume_path = options['optim']['resume']
    if os.path.isfile(resume_path):
        print("=> loading checkpoint '{}'".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=torch.device("cpu"))
        start_epoch = checkpoint['epoch']
        best_rsum = checkpoint['best_rsum']
        model.load_state_dict(checkpoint['model'])
    else:
        print("=> no checkpoint found at '{}'".format(resume_path))

    # Evaluate on the test set and hand back the similarity scores.
    return engine.validate_test(test_loader, model)
def update_options_savepath(options, k):
    """Return a deep copy of *options* whose resume entry points at the
    best checkpoint of cross-validation fold *k*.

    Args:
        options (dict): base configuration; not mutated.
        k: fold index, interpolated into the checkpoint path.

    Returns:
        dict: copied options with ``['optim']['resume']`` rewritten.
    """
    updated_options = copy.deepcopy(options)
    checkpoint_path = "{}{}/{}/{}_best.pth.tar".format(
        options['logs']['ckpt_save_path'],
        options['k_fold']['experiment_name'],
        k,
        options['model']['name'],
    )
    updated_options['optim']['resume'] = checkpoint_path
    return updated_options
# Script entry point: run one evaluation pass and report retrieval metrics.
if __name__ == '__main__':
    options = parser_options()

    # run experiment
    one_sims = main(options)
    #print(one_sims.shape)
    #rs = np.argsort(one_sims.transpose(),axis=1)+1#.transpose()

    # NOTE(review): the two triple-quoted blocks below are disabled code for
    # dumping per-query top-5 retrieval results to a text file; they depend on
    # `rs`, which is only defined by the commented-out np.argsort line above,
    # so they cannot be re-enabled on their own.
    '''
    #img-text
    c=1
    result=""
    for i in rs[:,-5:]:
        result+=str(c)+"-"+str(c+4)+" "+str(i[::-1])+"\n"
        c+=5
    '''
    '''
    # text-img #need to t()
    c=1
    im=1
    result=""
    for i in rs[:,-5:]:
        result+=str(c)+"-"+str(im)+" "+str(i[::-1])+"\n"
        c+=1
        if c%6==0:
            im+=1
    f=open("ERA_result_text_img.txt","a+")
    f.write(result)
    '''

    # ave
    last_sims = one_sims

    # get indicators
    # NOTE(review): logging is never configured in this script, so these
    # logging.info calls are silent at the default WARNING level — the
    # metrics actually reach the console via the print(all_score) below.
    (r1i, r5i, r10i, medri, meanri), _ = utils.acc_i2t2(last_sims)
    logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1i, r5i, r10i, medri, meanri))
    (r1t, r5t, r10t, medrt, meanrt), _ = utils.acc_t2i2(last_sims)
    logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1t, r5t, r10t, medrt, meanrt))

    # Mean of the six recall figures (R@1/5/10 in both directions).
    currscore = (r1t + r5t + r10t + r1i + r5i + r10i)/6.0

    all_score = "r1i:{} r5i:{} r10i:{} medri:{} meanri:{}\n r1t:{} r5t:{} r10t:{} medrt:{} meanrt:{}\n sum:{}\n ------\n".format(
        r1i, r5i, r10i, medri, meanri, r1t, r5t, r10t, medrt, meanrt, currscore
    )
    print(all_score)