Add forwarding step for the triplet inputs
liliangqi committed Nov 5, 2018
1 parent 3382d00 commit 5045580
Showing 2 changed files with 124 additions and 10 deletions.
106 changes: 103 additions & 3 deletions models/model.py
@@ -3,7 +3,7 @@
#
# Author: Liangqi Li and Xinlei Chen
# Creating Date: Apr 1, 2018
# Latest rectified: Oct 27, 2018
# Latest rectified: Nov 5, 2018
# -----------------------------------------------------
import torch
import torch.nn as nn
@@ -14,7 +14,7 @@
from .resnet import MyResNet
from .densenet import DenseNet
from .strpn import STRPN
from utils.losses import oim_loss, smooth_l1_loss
from utils.losses import oim_loss, smooth_l1_loss, TripletLoss


class SIPN(nn.Module):
@@ -64,10 +64,110 @@ def __init__(self, net_name, dataset_name, pre_model=''):
self.reid_feat_net = nn.Linear(self.fc7_channels, self.reid_feat_dim)
self.init_linear_weight(False)

self.triplet_loss = TripletLoss()

def forward(self, im_data, gt_boxes, im_info, mode='gallery'):
if self.training:
# ###############################################################
# ========================Triplet Loss===========================
# ###############################################################
if isinstance(im_data, tuple):
assert isinstance(gt_boxes, tuple)
assert isinstance(im_info, tuple)

# Extract the feature of the query
q_im = im_data[0]
q_box = gt_boxes[0]
q_info = im_info[0]

q_box = torch.cat((torch.zeros(1, 1).cuda(), q_box[:, :4]), 1)
q_net_conv = self.head(q_im)
q_pool_feat = self.strpn(q_net_conv, q_box, q_info, 'query')
if self.net_name == 'vgg16':
q_pool_feat = q_pool_feat.view(q_pool_feat.size(0), -1)
q_fc7 = self.tail(q_pool_feat)
else:
q_fc7 = self.tail(q_pool_feat).mean(3).mean(2)
q_reid_feat = self.reid_feat_net(q_fc7)

# Extract the feature of the positive gallery
p_im = im_data[1]
p_boxes = gt_boxes[1]
p_info = im_info[1]

p_net_conv = self.head(p_im)
p_pool_feat, p_tr_feat, p_rpn_loss, p_label, p_bbox_info = \
self.strpn(p_net_conv, p_boxes, p_info)
if self.net_name == 'vgg16':
p_pool_feat = p_pool_feat.view(p_pool_feat.size(0), -1)
p_fc7 = self.tail(p_pool_feat)
else:
p_fc7 = self.tail(p_pool_feat).mean(3).mean(2)
p_cls_score = self.cls_score_net(p_fc7)
p_bbox_pred = self.bbox_pred_net(p_fc7)
p_reid_feat = self.reid_feat_net(p_fc7)

p_det_label, p_pid_label = p_label
p_det_label = p_det_label.view(-1)
p_cls_loss = func.cross_entropy(
p_cls_score.view(-1, 2), p_det_label)
p_bbox_loss = smooth_l1_loss(p_bbox_pred, p_bbox_info)
p_rpn_cls_loss, p_rpn_box_loss = p_rpn_loss

# Extract the feature of the negative gallery
n_im = im_data[2]
n_boxes = gt_boxes[2]
n_info = im_info[2]

n_net_conv = self.head(n_im)
n_pool_feat, n_tr_feat, n_rpn_loss, n_label, n_bbox_info = \
self.strpn(n_net_conv, n_boxes, n_info)
if self.net_name == 'vgg16':
n_pool_feat = n_pool_feat.view(n_pool_feat.size(0), -1)
n_fc7 = self.tail(n_pool_feat)
else:
n_fc7 = self.tail(n_pool_feat).mean(3).mean(2)
n_cls_score = self.cls_score_net(n_fc7)
n_bbox_pred = self.bbox_pred_net(n_fc7)
n_reid_feat = self.reid_feat_net(n_fc7)

n_det_label, n_pid_label = n_label
n_det_label = n_det_label.view(-1)
n_cls_loss = func.cross_entropy(
n_cls_score.view(-1, 2), n_det_label)
n_bbox_loss = smooth_l1_loss(n_bbox_pred, n_bbox_info)
n_rpn_cls_loss, n_rpn_box_loss = n_rpn_loss

# Compute loss
rpn_cls_loss = p_rpn_cls_loss + n_rpn_cls_loss
rpn_box_loss = p_rpn_box_loss + n_rpn_box_loss
cls_loss = p_cls_loss + n_cls_loss
bbox_loss = p_bbox_loss + n_bbox_loss

query_pid = int(gt_boxes[0][:, -1].item())
p_mask = (p_pid_label.squeeze() != self.num_pid).nonzero(
).squeeze().view(-1)
p_pid_label_drop = p_pid_label[p_mask]
p_reid_feat_drop = p_reid_feat[p_mask]
n_mask = (n_pid_label.squeeze() != self.num_pid).nonzero(
).squeeze().view(-1)
n_pid_label_drop = n_pid_label[n_mask]
n_reid_feat_drop = n_reid_feat[n_mask]

tri_label = torch.cat(
(p_pid_label_drop, n_pid_label_drop)).squeeze()
tri_feat = torch.cat((p_reid_feat_drop, n_reid_feat_drop), 0)
reid_loss = self.triplet_loss(
q_reid_feat, query_pid, tri_feat, tri_label, mode='hard')

return rpn_cls_loss, rpn_box_loss, cls_loss, bbox_loss,\
reid_loss

# ###############################################################
# ###############################################################

net_conv = self.head(im_data)
# returned parameters contain 3 tuples here
# Returned parameters contain 3 tuples here
pooled_feat, trans_feat, rpn_loss, label, bbox_info = self.strpn(
net_conv, gt_boxes, im_info)
if self.net_name == 'vgg16':
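The `TripletLoss` imported from `utils.losses` is defined outside this diff. For context only, a minimal sketch of a hard-example-mining triplet loss matching the call `self.triplet_loss(q_reid_feat, query_pid, tri_feat, tri_label, mode='hard')` might look as follows; the class name, the 0.3 margin, and the L2 normalization are assumptions for illustration, not the repository's actual implementation.

import torch.nn as nn
import torch.nn.functional as F


class HardTripletLossSketch(nn.Module):
    """Illustrative stand-in for utils.losses.TripletLoss (not shown in this commit)."""

    def __init__(self, margin=0.3):
        super(HardTripletLossSketch, self).__init__()
        self.margin = margin

    def forward(self, q_feat, q_pid, gal_feats, gal_labels, mode='hard'):
        # L2-normalize so that query-to-proposal distances are comparable
        q_feat = F.normalize(q_feat, dim=1)        # (1, D) query (anchor) feature
        gal_feats = F.normalize(gal_feats, dim=1)  # (N, D) gallery proposal features
        dist = (gal_feats - q_feat).pow(2).sum(dim=1).sqrt()  # (N,) distances to the query

        pos_mask = gal_labels.view(-1) == q_pid  # proposals sharing the query identity
        neg_mask = gal_labels.view(-1) != q_pid  # proposals of other identities
        if not pos_mask.any() or not neg_mask.any():
            # No valid positive/negative pair in this batch
            return dist.new_zeros(())

        if mode == 'hard':
            d_ap = dist[pos_mask].max()  # hardest (farthest) positive
            d_an = dist[neg_mask].min()  # hardest (closest) negative
        else:
            d_ap = dist[pos_mask].mean()
            d_an = dist[neg_mask].mean()

        return F.relu(d_ap - d_an + self.margin)

In the forward pass above, `q_reid_feat` plays the anchor role, `query_pid` is its identity, and `tri_feat`/`tri_label` are the id-labelled re-id features pooled from the positive and negative galleries after the unlabeled (`num_pid`) entries are dropped.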
28 changes: 21 additions & 7 deletions train.py
@@ -3,7 +3,7 @@
#
# Author: Liangqi Li
# Creating Date: Mar 31, 2018
# Latest rectified: Oct 25, 2018
# Latest rectified: Nov 5, 2018
# -----------------------------------------------------
import os
import time
@@ -16,7 +16,8 @@

from utils.utils import clock_non_return, AverageMeter
from utils.logger import TensorBoardLogger
from dataset.sipn_dataset import SIPNDataset
from dataset.sipn_dataset import SIPNDataset, sipn_fn, \
PersonSearchTripletSampler, PersonSearchTripletFn
import dataset.sipn_transforms as sipn_transforms
from models.model import SIPN

@@ -51,10 +52,17 @@ def train_model(dataloader, net, optimizer, epoch):
config = yaml.load(f)

for iter_idx, data in enumerate(dataloader):
im, (gt_boxes, im_info) = data
im = im.to(device)
gt_boxes = gt_boxes.squeeze(0).to(device)
im_info = im_info.numpy().squeeze(0)
im, gt_boxes, im_info = data

if isinstance(im, tuple):
assert isinstance(gt_boxes, tuple)
assert isinstance(im_info, tuple)
im = tuple([x.to(device) for x in im])
gt_boxes = tuple([x.to(device) for x in gt_boxes])
else:
im = im.to(device)
gt_boxes = gt_boxes.squeeze(0).to(device)
im_info = im_info.ravel()

# Forward and backward
losses = net(im, gt_boxes, im_info)
@@ -112,6 +120,8 @@ def main():
opt = parse_args()
global device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
torch.manual_seed(1024)

save_dir = os.path.join(opt.out_dir, opt.dataset_name)
print('Trained models will be saved to {}\n'.format(
@@ -143,7 +153,10 @@ def main():

# Load the dataset
dataset = SIPNDataset(opt.data_dir, opt.dataset_name, 'train', transform)
dataloader = DataLoader(dataset, shuffle=True, num_workers=8)
sampler = PersonSearchTripletSampler(dataset)
collate_fn = PersonSearchTripletFn(dataset, sampler.batch_pids)
dataloader = DataLoader(
dataset, batch_sampler=sampler, collate_fn=collate_fn)

# Choose parameters to be updated during training
lr = opt.lr
@@ -190,6 +203,7 @@

train_model(dataloader, model, optimizer, epoch)
scheduler.step()
collate_fn.called_times = 0

epoch_end = time.time()
print('\nEntire epoch time cost: {:.2f} hours\n'.format(
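`PersonSearchTripletSampler` and `PersonSearchTripletFn` are defined in `dataset/sipn_dataset.py` and are not part of this commit. As a rough sketch under that caveat, a collate function producing the tuple layout that the updated `train_model` loop and `SIPN.forward` unpack could look like this; the per-sample structure `(im, gt_boxes, im_info)` is an assumption for illustration.

def triplet_collate_sketch(samples):
    # Hypothetical stand-in for PersonSearchTripletFn: pack one
    # (query, positive gallery, negative gallery) triple into the
    # (im, gt_boxes, im_info) tuples the training loop unpacks.
    assert len(samples) == 3  # query image, positive gallery, negative gallery
    ims, gt_boxes, im_infos = zip(*samples)  # each sample: (im, gt_boxes, im_info)
    return ims, gt_boxes, im_infos  # zip already yields tuples

With this layout, `isinstance(im, tuple)` in `train_model` routes triplet batches through the new branch of `SIPN.forward`, while single-image batches keep the original path.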
