-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdpsh.py
142 lines (123 loc) · 4.18 KB
/
dpsh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
import torch.optim as optim
import time
from torch.optim.lr_scheduler import CosineAnnealingLR
from models.model_loader import load_model
from loguru import logger
from models.dpsh_loss import DPSHLoss
from utils.evaluate import mean_average_precision
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    eta,
    lr,
    max_iter,
    topk,
    evaluate_interval,
):
    """
    Train a DPSH hashing model and track the best-mAP checkpoint.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        arch(str): CNN model name passed to load_model.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        eta(float): Hyper-parameter forwarded to DPSHLoss.
        lr(float): Learning rate.
        max_iter(int): Number of training iterations (epochs over train_dataloader).
        topk(int): Calculate mAP of top k retrieved items.
        evaluate_interval(int): Evaluate every this many iterations.

    Returns
        checkpoint(dict): Best checkpoint seen (hash codes, targets, model
            state dict and mAP), or None if no evaluation ever ran or mAP
            never exceeded 0.
    """
    # Create model, optimizer, criterion, scheduler
    model = load_model(arch, code_length).to(device)
    criterion = DPSHLoss(eta)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    scheduler = CosineAnnealingLR(optimizer, max_iter, 1e-7)

    # U caches the latest network output for every training sample;
    # DPSHLoss uses it to form pairwise terms against the whole train set.
    N = len(train_dataloader.dataset)
    U = torch.zeros(N, code_length).to(device)
    train_targets = train_dataloader.dataset.get_onehot_targets().to(device)

    # Training
    best_map = 0.0
    # BUGFIX: checkpoint was previously unbound, so `return checkpoint`
    # raised UnboundLocalError when no evaluation ever ran (e.g.
    # evaluate_interval > max_iter) or mAP never exceeded 0.
    checkpoint = None
    iter_time = time.time()
    for it in range(max_iter):
        model.train()
        running_loss = 0.
        for data, targets, index in train_dataloader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()

            # Pairwise similarity: 1 if two samples share at least one label.
            S = (targets @ train_targets.t() > 0).float()
            U_cnn = model(data)
            U[index, :] = U_cnn.data
            loss = criterion(U_cnn, U, S)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        scheduler.step()

        # Evaluate every `evaluate_interval` iterations (1-based).
        if it % evaluate_interval == evaluate_interval - 1:
            iter_time = time.time() - iter_time

            # Generate hash codes and one-hot targets
            query_code = generate_code(model, query_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute mAP
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Save checkpoint whenever mAP improves
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'qB': query_code,
                    'qL': query_targets,
                    'rB': retrieval_code,
                    'rL': retrieval_targets,
                    'model': model.state_dict(),
                    'map': best_map,
                }

            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss,
                mAP,
                iter_time,
            ))
            iter_time = time.time()

    return checkpoint
def generate_code(model, dataloader, code_length, device):
    """
    Compute binary hash codes for every sample yielded by the dataloader.

    Args
        model: Network producing real-valued codes; invoked as model(data).
        dataloader(torch.utils.data.dataloader.DataLoader): Yields
            (data, _, index) batches; index positions the batch in the output.
        code_length(int): Hash code length.
        device(torch.device): Device to run the forward pass on.

    Returns
        code(torch.Tensor, n*code_length): Sign-binarized hash codes on CPU.
    """
    model.eval()
    num_samples = len(dataloader.dataset)
    codes = torch.zeros([num_samples, code_length])
    with torch.no_grad():
        for batch, _, sample_idx in dataloader:
            outputs = model(batch.to(device))
            codes[sample_idx, :] = outputs.sign().cpu()
    # Restore training mode unconditionally, matching the caller's expectation.
    model.train()
    return codes