Skip to content

Commit

Permalink
good working state
Browse files Browse the repository at this point in the history
  • Loading branch information
rbgirshick committed Feb 6, 2015
1 parent eafc81a commit 3471eab
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 46 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pyc
20 changes: 20 additions & 0 deletions fast_rcnn_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np

# ---------------------------------------------------------------------------
# Image preprocessing
# ---------------------------------------------------------------------------

# Candidate target sizes (the scale set used in the SPP-net paper); one is
# picked per image and applied to the image's shorter side.
SCALES = (480, 576, 688, 864, 1200)

# Upper bound, in pixels, on the longer side of a rescaled input image.
MAX_SIZE = 2000

# Per-channel pixel means in BGR order, shaped (1, 1, 3) so they broadcast
# over a height x width x channel image.
PIXEL_MEANS = np.array([102.9801, 115.9465, 122.7717]).reshape((1, 1, 3))

# Number of input-image pixels per step at the ROI pooling level.
FEAT_STRIDE = 16

# ---------------------------------------------------------------------------
# Minibatch composition
# ---------------------------------------------------------------------------

# Total number of ROIs in one minibatch.
BATCH_SIZE = 128

# Share of the minibatch that carries a foreground label (class > 0).
FG_FRACTION = 0.25

# ---------------------------------------------------------------------------
# ROI labeling thresholds
# ---------------------------------------------------------------------------

# A ROI is foreground when its overlap is >= FG_THRESH.
FG_THRESH = 0.5

# A ROI is background (class = 0) when its overlap falls in
# [BG_THRESH_LO, BG_THRESH_HI), i.e. [0.1, 0.5).
BG_THRESH_LO = 0.1
BG_THRESH_HI = 0.5
67 changes: 34 additions & 33 deletions finetuning.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
import numpy as np
import cv2
import matplotlib.pyplot as plt
import fast_rcnn_config as conf
from keyboard import keyboard

SCALES = (480, 576, 688, 864, 1200)
BATCH_SIZE = 128
FG_FRACTION = 0.25
FG_THRESH = 0.5
BG_THRESH_HI = 0.5
BG_THRESH_LO = 0.1
PIXEL_MEANS = np.array([[[102.9801, 115.9465, 122.7717]]])
FEAT_STRIDE = 16
MAX_SIZE = 2000

def sample_rois(labels, overlaps, rois, fg_rois_per_image, rois_per_image):
"""Generate a random sample of ROIs comprising foreground and background
examples.
Expand All @@ -30,13 +21,13 @@ def sample_rois(labels, overlaps, rois, fg_rois_per_image, rois_per_image):
rois (2d np array)
"""
# Select foreground ROIs
fg_inds = np.where(overlaps >= FG_THRESH)[0]
fg_inds = np.where(overlaps >= conf.FG_THRESH)[0]
fg_rois_per_this_image = np.minimum(fg_rois_per_image, fg_inds.size)
fg_inds = np.random.choice(fg_inds, size=fg_rois_per_this_image,
replace=False)
# Select background ROIs
bg_inds = np.where((overlaps < BG_THRESH_HI) &
(overlaps >= BG_THRESH_LO))[0]
bg_inds = np.where((overlaps < conf.BG_THRESH_HI) &
(overlaps >= conf.BG_THRESH_LO))[0]
bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.size)
Expand All @@ -59,67 +50,77 @@ def get_image_blob(window_db, scale_inds, do_flip):
im_scale_factors = []
for i in xrange(num_images):
im = cv2.imread(window_db[i]['image'])
# if do_flip:
# im = im[:, ::-1, :]
if do_flip:
im = im[:, ::-1, :]
im = im.astype(np.float32, copy=False)
im -= PIXEL_MEANS
im -= conf.PIXEL_MEANS
im_shape = im.shape
im_size = np.min(im_shape[0:2])
im_size_big = np.max(im_shape[0:2])
target_size = SCALES[scale_inds[i]]
target_size = conf.SCALES[scale_inds[i]]
im_scale = float(target_size) / float(im_size)
# Prevent the biggest axis from being more than MAX_SIZE
if np.round(im_scale * im_size_big) > MAX_SIZE:
im_scale = float(MAX_SIZE) / float(im_size_big)
if np.round(im_scale * im_size_big) > conf.MAX_SIZE:
im_scale = float(conf.MAX_SIZE) / float(im_size_big)
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale,
interpolation=cv2.INTER_LINEAR)
im_scale_factors.append(im_scale)
processed_ims.append(im)
max_shape = np.maximum(max_shape, im.shape)

blob = np.zeros((num_images, max_shape[0],
max_shape[1], max_shape[2]))
max_shape[1], max_shape[2]), dtype=np.float32)
for i in xrange(num_images):
im = processed_ims[i]
blob[i, 0:im.shape[0], 0:im.shape[1], :] = im
# Move channels (axis 3) to axis 1
# Axis order will become: (batch elem, channel, height, width)
channel_swap = (0, 3, 1, 2)
blob = blob.transpose(channel_swap)
return blob, im_scale_factors

def map_im_rois_to_feat_rois(im_rois, im_scale_factor):
    """Map ROI coordinates from image pixels to ROI-pooling feature cells.

    Every coordinate is multiplied by the image's scale factor, divided by
    ``conf.FEAT_STRIDE`` and rounded to the nearest integer value.

    Args:
        im_rois: array of ROI coordinates in image pixels (the caller in
            this file passes the x1, y1, x2, y2 columns of a window table).
        im_scale_factor: factor by which the image was resized.

    Returns:
        Array of rounded feature-map coordinates, elementwise over im_rois.
    """
    # The stride is read from the shared config module only; the earlier
    # assignment that used a module-level FEAT_STRIDE was a dead store whose
    # result was immediately overwritten.
    return np.round(im_rois * im_scale_factor / conf.FEAT_STRIDE)

def get_minibatch(window_db, random_flip=True):
def get_minibatch(window_db, random_flip=False):
# Decide to flip the entire batch or not
# do_flip = False if not random_flip else bool(np.random.randint(0, high=2))
do_flip = False if not random_flip else bool(np.random.randint(0, high=2))
assert(not do_flip)
num_images = len(window_db)
# Sample random scales to use for each image in this batch
random_scale_inds = np.random.randint(0, high=len(SCALES), size=num_images)
assert(BATCH_SIZE % num_images == 0), 'num_images must divide BATCH_SIZE'
rois_per_image = BATCH_SIZE / num_images
fg_rois_per_image = np.round(FG_FRACTION * rois_per_image)
random_scale_inds = \
np.random.randint(0, high=len(conf.SCALES), size=num_images)
assert(conf.BATCH_SIZE % num_images == 0), \
'num_images must divide BATCH_SIZE'
rois_per_image = conf.BATCH_SIZE / num_images
fg_rois_per_image = np.round(conf.FG_FRACTION * rois_per_image)
# Get the input blob, formatted for caffe
# Takes care of random scaling and flipping
do_flip = False
im_blob, im_scale_factors = get_image_blob(window_db,
random_scale_inds, do_flip)
# Now, build the region of interest and label blobs
rois_blob = np.zeros((0, 5), dtype=np.float32)
labels_blob = np.zeros((0), dtype=np.float32)
all_overlaps = []
for im_i in xrange(num_images):
# (labels, overlaps, x1, y1, x2, y2)
labels = window_db[im_i]['windows'][:, 0]
overlaps = window_db[im_i]['windows'][:, 1]
im_rois = window_db[im_i]['windows'][:, 2:]
# if do_flip:
# im_rois[:, (0, 2)] = window_db[im_i]['width'] - \
# im_rois[:, (2, 0)] - 1
if do_flip:
im_rois[:, (0, 2)] = window_db[im_i]['width'] - \
im_rois[:, (2, 0)] - 1
labels, overlaps, im_rois = sample_rois(labels, overlaps, im_rois,
fg_rois_per_image,
rois_per_image)
feat_rois = map_im_rois_to_feat_rois(im_rois, im_scale_factors[im_i])
# Assert various bounds
assert((feat_rois[:, 2] >= feat_rois[:, 0]).all())
assert((feat_rois[:, 3] >= feat_rois[:, 1]).all())
assert((feat_rois >= 0).all())
assert((feat_rois < np.max(im_blob.shape[2:4]) *
im_scale_factors[im_i] / conf.FEAT_STRIDE).all())
rois_blob_this_image = \
np.append(im_i * np.ones((feat_rois.shape[0], 1)), feat_rois,
axis=1)
Expand All @@ -134,9 +135,9 @@ def vis_minibatch(im_blob, rois_blob, labels_blob, overlaps):
for i in xrange(rois_blob.shape[0]):
rois = rois_blob[i, :]
im_ind = rois[0]
roi = rois[1:] * FEAT_STRIDE
roi = rois[1:] * conf.FEAT_STRIDE
im = im_blob[im_ind, :, :, :].transpose((1, 2, 0)).copy()
im += PIXEL_MEANS
im += conf.PIXEL_MEANS
im = im[:, :, (2, 1, 0)]
im = im.astype(np.uint8)
cls = labels_blob[i]
Expand Down
Loading

0 comments on commit 3471eab

Please sign in to comment.