-
Notifications
You must be signed in to change notification settings - Fork 0
/
minibatch.py
executable file
·98 lines (82 loc) · 3.48 KB
/
minibatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick and Xinlei Chen
# --------------------------------------------------------
"""Compute minibatch blobs for training a Fast R-CNN network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import numpy.random as npr
from scipy.misc import imread
from model.utils.blob import prep_im_for_blob, im_list_to_blob
import pdb
from model.utils.config import cfg
def get_minibatch(roidb, num_classes, target_size576=1, target_size600=1, target_size800=1):
    """Given a roidb, construct a minibatch sampled from it.

    Only single-image batches are supported (enforced by assertions below).

    Args:
        roidb: sequence of roidb entries. The first entry is read through the
            keys 'gt_classes', 'gt_overlaps' (only when cfg.TRAIN.USE_ALL_GT
            is false; must provide ``.toarray()``), 'boxes' and 'img_id'.
        num_classes: not used by this function; kept so existing callers
            keep working.
        target_size576, target_size600, target_size800: forwarded unchanged
            to ``_get_image_blob``.

    Returns:
        dict with the keys
            'data'     -- image blob returned by ``_get_image_blob``;
            'gt_boxes' -- float32 array of shape (num_gt, 5) laid out as
                          (x1, y1, x2, y2, cls), box coordinates multiplied
                          by the image scale;
            'im_info'  -- float32 array [[blob.shape[1], blob.shape[2], scale]];
            'img_id'   -- roidb[0]['img_id'].

    Raises:
        AssertionError: if len(roidb) does not divide cfg.TRAIN.BATCH_SIZE,
            or if the batch does not contain exactly one image.
    """
    num_images = len(roidb)
    # Sample random scales to use for each image in this batch
    random_scale_inds = npr.randint(0, high=len(cfg.TRAIN.SCALES),
                                    size=num_images)
    assert (cfg.TRAIN.BATCH_SIZE % num_images == 0), \
        'num_images ({}) must divide BATCH_SIZE ({})'. \
        format(num_images, cfg.TRAIN.BATCH_SIZE)

    # Get the input image blob and the per-image scale factors
    im_blob, im_scales = _get_image_blob(roidb, random_scale_inds,
                                         target_size576, target_size600,
                                         target_size800)
    blobs = {'data': im_blob}

    assert len(im_scales) == 1, "Single batch only"
    assert len(roidb) == 1, "Single batch only"

    entry = roidb[0]
    # Class 0 is skipped in both branches (presumably background -- confirm
    # against the roidb builder).
    is_foreground = entry['gt_classes'] != 0
    if cfg.TRAIN.USE_ALL_GT:
        # Include all ground truth boxes
        keep = is_foreground
    else:
        # For the COCO ground truth boxes, exclude the ones that are
        # ''iscrowd'': keep a row only if every overlap is > -1.0.
        # The comparison must be parenthesized: `&` binds tighter than `!=`,
        # so `a != 0 & mask` evaluates as `a != (0 & mask)` == `a != 0`,
        # which silently drops the crowd filter.
        not_crowd = np.all(entry['gt_overlaps'].toarray() > -1.0, axis=1)
        keep = is_foreground & not_crowd
    gt_inds = np.where(keep)[0]

    # gt boxes: (x1, y1, x2, y2, cls)
    gt_boxes = np.empty((len(gt_inds), 5), dtype=np.float32)
    gt_boxes[:, 0:4] = entry['boxes'][gt_inds, :] * im_scales[0]
    gt_boxes[:, 4] = entry['gt_classes'][gt_inds]
    blobs['gt_boxes'] = gt_boxes
    blobs['im_info'] = np.array(
        [[im_blob.shape[1], im_blob.shape[2], im_scales[0]]],
        dtype=np.float32)
    blobs['img_id'] = entry['img_id']

    return blobs
def _get_image_blob(roidb, scale_inds, target_size576=1,target_size600=1, target_size800=1):
    """Load, orient and rescale every roidb image, then pack them into a blob.

    For entry ``i`` the image at ``roidb[i]['image']`` is read, a 2-D
    (single-plane) image is expanded to three identical channels, the channel
    axis is reversed, and the image is mirrored along its second axis when
    ``roidb[i]['flipped']`` is truthy. The resize target starts as
    ``cfg.TRAIN.SCALES[scale_inds[i]]`` and, when ``cfg.random_resize`` and
    ``cfg.universal`` are both set, is adjusted as follows:

        * target == 600  -> target * target_size600
        * target >= 780  -> target * target_size800
        * target == 576  -> int(target_size576)

    Returns:
        (blob, im_scales): the result of ``im_list_to_blob`` on the processed
        images, and the list of scale factors returned by
        ``prep_im_for_blob`` (one per image).
    """
    prepared_images = []
    scale_factors = []

    for idx, entry in enumerate(roidb):
        image = imread(entry['image'])

        if image.ndim == 2:
            # Grayscale input: replicate the single plane three times.
            plane = image[:, :, np.newaxis]
            image = np.concatenate((plane, plane, plane), axis=2)

        # The pipeline was originally written around cv2, so convert
        # rgb -> bgr by reversing the channel axis.
        image = image[:, :, ::-1]

        if entry['flipped']:
            image = image[:, ::-1, :]

        size = cfg.TRAIN.SCALES[scale_inds[idx]]
        if cfg.random_resize:
            # cfg.universal is consulted only after the size test matches,
            # exactly as a short-circuiting `and` would.
            if size == 600 and cfg.universal:
                size = size * target_size600
            elif size >= 780 and cfg.universal:
                size = size * target_size800
            elif size == 576 and cfg.universal:
                size = int(target_size576)

        image, factor = prep_im_for_blob(image, cfg.PIXEL_MEANS, size,
                                         cfg.TRAIN.MAX_SIZE)
        scale_factors.append(factor)
        prepared_images.append(image)

    # Create a blob to hold the input images
    return im_list_to_blob(prepared_images), scale_factors