Skip to content

Commit

Permalink
modified: experiments/script_faster_rcnn_demo.m
Browse files (browse the repository at this point in the history)
	modified:   functions/fast_rcnn/fast_rcnn_conv_feat_detect.m
	modified:   functions/rpn/proposal_im_detect.m
	modified:   utils/prep_im_for_blob.m
	modified:   utils/showboxes.m
  • Loading branch information
ShaoqingRen committed Aug 11, 2015
1 parent f48c799 commit 4354a00
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 26 deletions.
12 changes: 10 additions & 2 deletions experiments/script_faster_rcnn_demo.m
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ function script_faster_rcnn_demo()
opts.test_scales = 600;

%% -------------------- INIT_MODEL --------------------
model_dir = fullfile(pwd, 'output', 'faster_rcnn_final', 'faster_rcnn_VOC2007_vgg_16layers');
model_dir = fullfile(pwd, 'output', 'faster_rcnn_final', 'faster_rcnn_VOC0712_vgg_16layers');
proposal_detection_model = load_proposal_detection_model(model_dir);

proposal_detection_model.conf_proposal.test_scales = opts.test_scales;
proposal_detection_model.conf_detection.test_scales = opts.test_scales;
if opts.use_gpu
proposal_detection_model.conf_proposal.image_means = gpuArray(proposal_detection_model.conf_proposal.image_means);
proposal_detection_model.conf_detection.image_means = gpuArray(proposal_detection_model.conf_detection.image_means);
end

% caffe.init_log(fullfile(pwd, 'caffe_log'));
% proposal net
Expand All @@ -41,6 +45,10 @@ function script_faster_rcnn_demo()
%% -------------------- TESTING --------------------
im = imread(fullfile(pwd, '004545.jpg'));

if opts.use_gpu
im = gpuArray(im);
end

for j = 1:10
% test proposal
th = tic();
Expand All @@ -54,7 +62,7 @@ function script_faster_rcnn_demo()
th = tic();
if proposal_detection_model.is_share_feature
[boxes, scores] = fast_rcnn_conv_feat_detect(proposal_detection_model.conf_detection, fast_rcnn_net, im, ...
rpn_net.blobs(proposal_detection_model.last_shared_output_blob_name).get_data(), ...
rpn_net.blobs(proposal_detection_model.last_shared_output_blob_name), ...
aboxes(:, 1:4), opts.after_nms_topN);
else
[boxes, scores] = fast_rcnn_im_detect(proposal_detection_model.conf_detection, fast_rcnn_net, im, ...
Expand Down
22 changes: 7 additions & 15 deletions functions/fast_rcnn/fast_rcnn_conv_feat_detect.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function [pred_boxes, scores] = fast_rcnn_conv_feat_detect(conf, caffe_net, im, conv_feat, boxes, max_rois_num_in_gpu)
% [pred_boxes, scores] = fast_rcnn_conv_feat_detect(conf, caffe_net, im, conv_feat, boxes, max_rois_num_in_gpu)
function [pred_boxes, scores] = fast_rcnn_conv_feat_detect(conf, caffe_net, im, conv_feat_blob, boxes, max_rois_num_in_gpu)
% [pred_boxes, scores] = fast_rcnn_conv_feat_detect(conf, caffe_net, im, conv_feat_blob, boxes, max_rois_num_in_gpu)
% --------------------------------------------------------
% Fast R-CNN
% Reimplementation based on Python Fast R-CNN (https://github.com/rbgirshick/fast-rcnn)
Expand All @@ -9,19 +9,14 @@

[rois_blob, ~] = get_blobs(conf, im, boxes);

% When mapping from image ROIs to feature map ROIs, there's some aliasing
% (some distinct image ROIs get mapped to the same feature ROI).
% Here, we identify duplicate feature ROIs, so we only compute features
% on the unique subset.
[~, index, inv_index] = unique(rois_blob, 'rows');
rois_blob = rois_blob(index, :);
boxes = boxes(index, :);

% permute data into caffe c++ memory, thus [num, channels, height, width]
rois_blob = rois_blob - 1; % to c's index (start from 0)
rois_blob = permute(rois_blob, [3, 4, 2, 1]);
rois_blob = single(rois_blob);

% set conv feature map as 'data'
caffe_net.blobs('data').copy_data_from(conv_feat_blob);

total_rois = size(rois_blob, 4);
total_scores = cell(ceil(total_rois / max_rois_num_in_gpu), 1);
total_box_deltas = cell(ceil(total_rois / max_rois_num_in_gpu), 1);
Expand All @@ -31,7 +26,8 @@
sub_ind_end = min(total_rois, i * max_rois_num_in_gpu);
sub_rois_blob = rois_blob(:, :, :, sub_ind_start:sub_ind_end);

net_inputs = {conv_feat, sub_rois_blob};
% only set rois blob here
net_inputs = {[], sub_rois_blob};

% Reshape net's input blobs
caffe_net.reshape_as_input(net_inputs);
Expand Down Expand Up @@ -62,10 +58,6 @@

pred_boxes = fast_rcnn_bbox_transform_inv(boxes, box_deltas);
pred_boxes = clip_boxes(pred_boxes, size(im, 2), size(im, 1));

% Map scores and predictions back to the original set of boxes
scores = scores(inv_index, :);
pred_boxes = pred_boxes(inv_index, :);

% remove scores and boxes for back-ground
pred_boxes = pred_boxes(:, 5:end);
Expand Down
1 change: 1 addition & 0 deletions functions/rpn/proposal_im_detect.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
% Licensed under The MIT License [see LICENSE for details]
% --------------------------------------------------------

im = single(im);
[im_blob, im_scales] = get_image_blob(conf, im);
im_size = size(im);
scaled_im_size = round(im_size * im_scales);
Expand Down
36 changes: 27 additions & 9 deletions utils/prep_im_for_blob.m
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
function [im, im_scale] = prep_im_for_blob(im, im_means, target_size, max_size)
im = single(im);
try
im = bsxfun(@minus, im, im_means);
catch
im_means = imresize(im_means, [size(im, 1), size(im, 2)], 'bilinear', 'antialiasing', false);
im = bsxfun(@minus, im, im_means);

if ~isa(im, 'gpuArray')
try
im = bsxfun(@minus, im, im_means);
catch
im_means = imresize(im_means, [size(im, 1), size(im, 2)], 'bilinear', 'antialiasing', false);
im = bsxfun(@minus, im, im_means);
end
im_scale = prep_im_for_blob_size(size(im), target_size, max_size);

target_size = round([size(im, 1), size(im, 2)] * im_scale);
im = imresize(im, target_size, 'bilinear', 'antialiasing', false);
else
% for im as gpuArray
try
im = bsxfun(@minus, im, im_means);
catch
im_means_scale = max(double(size(im, 1)) / size(im_means, 1), double(size(im, 2)) / size(im_means, 2));
im_means = imresize(im_means, im_means_scale);
y_start = floor((size(im_means, 1) - size(im, 1)) / 2) + 1;
x_start = floor((size(im_means, 2) - size(im, 2)) / 2) + 1;
im_means = im_means(y_start:(y_start+size(im, 1)-1), x_start:(x_start+size(im, 2)-1));
im = bsxfun(@minus, im, im_means);
end

im_scale = prep_im_for_blob_size(size(im), target_size, max_size);
im = imresize(im, im_scale);
end
im_scale = prep_im_for_blob_size(size(im), target_size, max_size);

target_size = round([size(im, 1), size(im, 2)] * im_scale);
im = imresize(im, target_size, 'bilinear', 'antialiasing', false);
end
3 changes: 3 additions & 0 deletions utils/showboxes.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ function showboxes(im, boxes, legends, color_conf)
% -------------------------------------------------------

fix_width = 800;
if isa(im, 'gpuArray')
im = gather(im);
end
imsz = size(im);
scale = fix_width / imsz(2);
im = imresize(im, scale);
Expand Down

0 comments on commit 4354a00

Please sign in to comment.