Commit: gradient clip

longcw committed Feb 12, 2017
1 parent bef72e3 commit b5772e6
Showing 5 changed files with 55 additions and 28 deletions.
demo.py (5 changes: 3 additions & 2 deletions)

@@ -7,13 +7,14 @@

 def test():
     import os
-    im_file = 'demo/004545.jpg'
+    # im_file = 'demo/004545.jpg'
+    im_file = 'data/VOCdevkit2007/VOC2007/JPEGImages/009936.jpg'
-    im_file = '/media/longc/Data/data/2DMOT2015/test/ETH-Crossing/img1/000100.jpg'
+    # im_file = '/media/longc/Data/data/2DMOT2015/test/ETH-Crossing/img1/000100.jpg'
     image = cv2.imread(im_file)

     # model_file = '/media/longc/Data/models/VGGnet_fast_rcnn_iter_70000.h5'
     model_file = '/media/longc/Data/models/faster_rcnn_pytorch/faster_rcnn_30000.h5'
     # model_file = '/media/longc/Data/models/faster_rcnn_pytorch2/faster_rcnn_2000.h5'
     detector = FasterRCNN()
     network.load_net(model_file, detector)
     detector.cuda()
Expand Down
experiments/cfgs/faster_rcnn_end2end.yml (2 changes: 1 addition & 1 deletion)

@@ -12,7 +12,7 @@ TRAIN:
   DISPLAY: 10
   SNAPSHOT_ITERS: 5000
   HAS_RPN: True
-  LEARNING_RATE: 0.0001
+  LEARNING_RATE: 0.001
   MOMENTUM: 0.9
   GAMMA: 0.1
   STEPSIZE: 60000
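The base learning rate moves from 1e-4 to 1e-3. Assuming the usual py-faster-rcnn interpretation of these keys (the diff itself does not spell it out), GAMMA and STEPSIZE define a step-decay schedule on top of LEARNING_RATE. A minimal sketch:

def lr_at(step, base_lr=0.001, gamma=0.1, stepsize=60000):
    """Step decay: multiply the base rate by gamma every stepsize iterations."""
    return base_lr * (gamma ** (step // stepsize))

lr_at(0)       # 0.001
lr_at(60000)   # ~0.0001 after the first decay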
faster_rcnn/faster_rcnn.py (12 changes: 6 additions & 6 deletions)

@@ -226,7 +226,7 @@ def loss(self):
         # print self.loss_box
         # print self.rpn.cross_entropy
         # print self.rpn.loss_box
-        return self.cross_entropy + self.loss_box * 10 + self.rpn.loss
+        return self.cross_entropy + self.loss_box * 10

     def forward(self, im_data, im_info, gt_boxes=None, dontcare_areas=None):
         features, rois = self.rpn(im_data, im_info, gt_boxes, dontcare_areas)
@@ -258,18 +258,18 @@ def build_loss(self, cls_score, bbox_pred, roi_data):
         fg_cnt = torch.sum(label.data.ne(0))
         bg_cnt = label.data.numel() - fg_cnt

-        # ce_weights = torch.ones(cls_score.size()[1])
-        # ce_weights[0] = float(fg_cnt) / bg_cnt / 3.0
-        # # ce_weights[0] = 1./50
-        # ce_weights = ce_weights.cuda()
+        ce_weights = torch.ones(cls_score.size()[1])
+        ce_weights[0] = float(fg_cnt) / bg_cnt
+        # ce_weights[0] = 1./50
+        ce_weights = ce_weights.cuda()

         maxv, predict = cls_score.data.max(1)
         self.tp = torch.sum(predict[:fg_cnt].eq(label.data[:fg_cnt]))
         self.tf = torch.sum(predict[fg_cnt:].eq(label.data[fg_cnt:]))
         self.fg_cnt = fg_cnt
         self.bg_cnt = bg_cnt
         # print predict
-        cross_entropy = F.cross_entropy(cls_score, label)
+        cross_entropy = F.cross_entropy(cls_score, label, weight=ce_weights)
         # print cross_entropy

         # bounding box regression L1 loss
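The now-active ce_weights down-weight the background class (index 0) by fg_cnt / bg_cnt, so the many sampled background RoIs no longer dominate the classification loss. A self-contained sketch of the same weighting; the shapes (256 RoIs, 21 classes) and random data are placeholders, not values from this repo:

import torch
import torch.nn.functional as F

cls_score = torch.randn(256, 21)              # placeholder logits
label = torch.zeros(256, dtype=torch.long)    # class 0 = background
label[:64] = torch.randint(1, 21, (64,))      # pretend the first 64 RoIs are foreground

fg_cnt = int(label.ne(0).sum())
bg_cnt = label.numel() - fg_cnt

ce_weights = torch.ones(21)
ce_weights[0] = float(fg_cnt) / bg_cnt        # shrink the background class weight
loss = F.cross_entropy(cls_score, label, weight=ce_weights)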
faster_rcnn/network.py (29 changes: 24 additions & 5 deletions)

@@ -62,8 +62,27 @@ def set_trainable(model, requires_grad):


 def weights_normal_init(model, dev=0.01):
-    for m in model.modules():
-        if isinstance(m, nn.Conv2d):
-            m.weight.data.normal_(0.0, dev)
-        elif isinstance(m, nn.Linear):
-            m.weight.data.normal_(0.0, dev)
+    if isinstance(model, list):
+        for m in model:
+            weights_normal_init(m, dev)
+    else:
+        for m in model.modules():
+            if isinstance(m, nn.Conv2d):
+                m.weight.data.normal_(0.0, dev)
+            elif isinstance(m, nn.Linear):
+                m.weight.data.normal_(0.0, dev)
+
+
+def clip_gradient(model, clip_norm):
+    """Computes a gradient clipping coefficient based on gradient norm."""
+    totalnorm = 0
+    for p in model.parameters():
+        if p.requires_grad:
+            modulenorm = p.grad.data.norm()
+            totalnorm += modulenorm ** 2
+    totalnorm = np.sqrt(totalnorm)
+
+    norm = clip_norm / max(totalnorm, clip_norm)
+    for p in model.parameters():
+        if p.requires_grad:
+            p.grad.mul_(norm)
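The new clip_gradient implements global-norm clipping: it accumulates the squared L2 norms of all trainable gradients, and when the combined norm exceeds clip_norm every gradient is rescaled by clip_norm / total_norm (otherwise the factor is 1 and nothing changes). A standalone sketch of the same math inside a training step; the Linear model and random data are placeholders, and in current PyTorch the built-in torch.nn.utils.clip_grad_norm_ does the equivalent job:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                       # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
loss = nn.functional.cross_entropy(model(x), y)

optimizer.zero_grad()
loss.backward()

# Same effect as network.clip_gradient(model, 10.)
clip_norm = 10.
total_norm = torch.sqrt(sum((p.grad.detach() ** 2).sum()
                            for p in model.parameters() if p.grad is not None))
scale = clip_norm / max(float(total_norm), clip_norm)
for p in model.parameters():
    if p.grad is not None:
        p.grad.mul_(scale)                     # no-op when total_norm <= clip_norm

optimizer.step()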
train.py (35 changes: 21 additions & 14 deletions)

@@ -32,11 +32,12 @@
 data_layer = RoIDataLayer(roidb, imdb.num_classes)

 net = FasterRCNN(classes=imdb.classes)
-network.weights_normal_init(net, dev=0.01)
-net.rpn.features.load_from_npy_file(pretrained_model)
-# model_file = '/media/longc/Data/models/VGGnet_fast_rcnn_iter_70000.h5'
+# network.weights_normal_init(net, dev=0.01)
+# net.rpn.features.load_from_npy_file(pretrained_model)
+model_file = '/media/longc/Data/models/VGGnet_fast_rcnn_iter_70000.h5'
 # model_file = '/media/longc/Data/models/faster_rcnn_pytorch/faster_rcnn_10000.h5'
-# network.load_net(model_file, net)
+network.load_net(model_file, net)
+network.weights_normal_init([net.bbox_fc, net.score_fc, net.fc6, net.fc7], dev=0.01)

 # net = net.rpn

@@ -46,10 +47,12 @@
 params = list(net.parameters())
 for p in params:
     print p.size()
-# optimizer = torch.optim.Adam(params[8:], lr=lr)
-optimizer = torch.optim.SGD(params[8:-8], lr=lr, momentum=0.9, weight_decay=0.0005)
-train_all = False
-target_net = net.rpn
+# optimizer = torch.optim.Adam(params[-8:], lr=lr)
+optimizer = torch.optim.SGD(params[-8:], lr=lr, momentum=0.9, weight_decay=0.0005)
+train_all = True
+# target_net = net.rpn
+target_net = net
+network.set_trainable(net.rpn, False)

 if not os.path.exists(output_dir):
     os.mkdir(output_dir)
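Training now targets the full network while freezing the RPN, and the optimizer is handed params[-8:] only. Those eight tensors are presumably the weight and bias of the four freshly re-initialized layers (fc6, fc7, score_fc, bbox_fc), though the diff does not spell that out. A hedged sketch of the same last-N-layers pattern with placeholder layers:

import torch
import torch.nn as nn

# Placeholder stack standing in for the backbone plus fc6/fc7/score_fc/bbox_fc.
net = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16), nn.Linear(16, 16),
                    nn.Linear(16, 16), nn.Linear(16, 4))
params = list(net.parameters())

head_params = params[-8:]        # last 4 layers x (weight, bias)
for p in params[:-8]:            # freeze everything upstream
    p.requires_grad = False

optimizer = torch.optim.SGD(head_params, lr=0.001, momentum=0.9, weight_decay=0.0005)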
@@ -85,6 +88,7 @@
     # backward
     optimizer.zero_grad()
     loss.backward()
+    network.clip_gradient(target_net, 10.)
     optimizer.step()

     if step % log_interval == 0:
@@ -103,12 +107,15 @@
         tp, tf, fg, bg = 0., 0., 0, 0
         t.tic()

-    if step % 20000 == 0 and step > 0:
+    if step % 10000 == 0 and step > 0:
         save_name = os.path.join(output_dir, 'faster_rcnn_{}.h5'.format(step))
         network.save_net(save_name, net)
         print('save model: {}'.format(save_name))
-        lr /= 3
-        optimizer = torch.optim.SGD(params[8:], lr=lr, momentum=0.9, weight_decay=0.0005)
-        # optimizer = torch.optim.Adam(params[8:], lr=lr)
-        train_all = True
-        target_net = net
+        # lr /= 10
+        # optimizer = torch.optim.SGD(params[-8:], lr=lr, momentum=0.9, weight_decay=0.0005)
+        # if step >= 20000:
+        #     lr /= 3
+        #     optimizer = torch.optim.SGD(params[8:], lr=lr, momentum=0.9, weight_decay=0.0005)
+        #     # optimizer = torch.optim.Adam(params[8:], lr=lr)
+        #     train_all = True
+        #     target_net = net