-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrank_list.py
124 lines (109 loc) · 4.22 KB
/
rank_list.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import scipy.io
import torch
import numpy as np
import os
from torchvision import datasets
import matplotlib
#matplotlib.use('agg')
import matplotlib.pyplot as plt
#######################################################################
# Evaluate
# Expose only GPU '3' to this process through CUDA_VISIBLE_DEVICES.
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
# parser = argparse.ArgumentParser(description='Demo')
# parser.add_argument('--query_index', default=0, type=int, help='test_image_index')
# parser.add_argument('--test_dir',default='/home/wangtyu/datasets/University-Release/test',type=str, help='./test_data')
# opts = parser.parse_args()
# Sub-directories of data_dir used as the gallery set and the query set.
gallery_name = 'gallery_satellite'
query_name = 'query_drone'
# Swap in the two lines below to use the opposite gallery/query pairing.
# gallery_name = 'gallery_drone'
# query_name = 'query_satellite'
# data_dir = opts.test_dir
data_dir = '/home/wangtyu/datasets/University-Release/test'
# One ImageFolder per split, keyed by the split's directory name.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split))
    for split in (gallery_name, query_name)
}
#####################################################################
#Show result
def imshow(path, title=None):
    """Read the image file at ``path`` and draw it with matplotlib.

    Args:
        path: Location of the image file handed to ``plt.imread``.
        title: Optional caption; applied with ``plt.title`` when not ``None``.
    """
    image = plt.imread(path)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    # Short pause so that the figure has a chance to refresh.
    plt.pause(0.1)
######################################################################
# Load the saved query/gallery features and labels produced by the test run.
result = scipy.io.loadmat('pytorch_result.mat')
query_feature = torch.FloatTensor(result['query_f'])
# [0] selects the first row of the stored label array.
query_label = result['query_label'][0]
gallery_feature = torch.FloatTensor(result['gallery_f'])
gallery_label = result['gallery_label'][0]

# Use the GPU when one is available, otherwise stay on the CPU so the
# script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Multi-query features are optional; they are loaded only if the file exists.
multi = os.path.isfile('multi_query.mat')
if multi:
    m_result = scipy.io.loadmat('multi_query.mat')
    mquery_feature = torch.FloatTensor(m_result['mquery_f'])
    mquery_cam = m_result['mquery_cam'][0]
    mquery_label = m_result['mquery_label'][0]
    mquery_feature = mquery_feature.to(device)

query_feature = query_feature.to(device)
gallery_feature = gallery_feature.to(device)
#######################################################################
# sort the images
def sort_img(qf, ql, gf, gl):
    """Rank gallery entries by similarity to a single query feature.

    Each gallery row is scored by its dot product with the flattened query
    feature. Row indices are returned best-first, with junk entries
    (gallery label ``-1``) removed.

    Args:
        qf: Query feature tensor. It is flattened to a column vector, so it
            must hold as many elements as ``gf`` has columns.
        ql: Query label. Not used by the ranking; kept so that existing
            callers keep working.
        gf: 2-D gallery feature tensor, one row per gallery entry.
        gl: NumPy array of gallery labels aligned with the rows of ``gf``.

    Returns:
        1-D NumPy array of gallery row indices ordered from the highest
        score to the lowest, excluding rows whose label is ``-1``.
    """
    query = qf.view(-1, 1)
    # (num_gallery, dim) @ (dim, 1) -> one similarity score per gallery row.
    score = torch.mm(gf, query).squeeze(1).cpu().numpy()
    # argsort is ascending; reverse it so the best match comes first.
    index = np.argsort(score)[::-1]
    # Drop junk gallery entries (label -1) while preserving the ranking order.
    junk_index = np.argwhere(gl == -1)
    mask = np.isin(index, junk_index, invert=True)
    return index[mask]
# Rank the gallery for one hand-picked query image and print its top-5 matches.
query_path = '/home/wangtyu/datasets/University-Release/test/query_drone/0561/image-24.jpeg'
# Locate the query among the paths stored alongside the saved features.
i = np.where(result['query_path']==query_path)
if i[0].size == 0:
    # An empty selection would otherwise be passed straight on to sort_img;
    # stop here with a message that names the missing path instead.
    raise ValueError('query path not found in pytorch_result.mat: %s' % query_path)
index = sort_img(query_feature[i],query_label[i],gallery_feature,gallery_label)
# Keep the five best-ranked gallery entries and look up their file paths.
R5_index = index[0:5]
R5_path = result['gallery_path'][R5_index]
print(R5_path)
########################################################################
# Visualize the rank result
# query_path, _ = image_datasets[query_name].imgs[i]
# query_label = query_label[i]
# print(query_path)
# print('Top 10 images are as follow:')
# save_folder = 'image_show/%02d'%opts.query_index
# if not os.path.isdir(save_folder):
# os.mkdir(save_folder)
# os.system('cp %s %s/query.jpg'%(query_path, save_folder))
# try: # Visualize Ranking Result
# # Graphical User Interface is needed
# fig = plt.figure(figsize=(16,4))
# ax = plt.subplot(1,11,1)
# ax.axis('off')
# imshow(query_path,'query')
# for i in range(10):
# ax = plt.subplot(1,11,i+2)
# ax.axis('off')
# img_path, _ = image_datasets[gallery_name].imgs[index[i]]
# label = gallery_label[index[i]]
# print(label)
# imshow(img_path)
# os.system('cp %s %s/s%02d.jpg'%(img_path, save_folder, i))
# if label == query_label:
# ax.set_title('%d'%(i+1), color='green')
# else:
# ax.set_title('%d'%(i+1), color='red')
# print(img_path)
# #plt.pause(100) # pause a bit so that plots are updated
# except RuntimeError:
# for i in range(10):
# img_path = image_datasets.imgs[index[i]]
# print(img_path[0])
# print('If you want to see the visualization of the ranking result, graphical user interface is needed.')
#
# fig.savefig("show.png")