forked from zhangliliang/RPN_BF
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add files which support the training for RPN on Caltech dataset
- Loading branch information
1 parent
47e8fe2
commit 8f50c86
Showing
18 changed files
with
1,711 additions
and
16 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
function dataset = caltech_test(dataset, usage)
% dataset = caltech_test(dataset, usage)
%   Attaches the Caltech 'test' split to DATASET.
%
%   usage = 'train' : sets dataset.imdb_train / dataset.roidb_train, each a
%                     one-element cell array; the boolean flag passed to
%                     imdb_from_caltech and roidb_func is true
%                     (presumably "use flipped images" -- confirm in
%                     imdb_from_caltech).
%   usage = 'test'  : sets dataset.imdb_test / dataset.roidb_test (plain
%                     values, not cells); the boolean flag is false.
%   Any other value raises an error.

caltech_root = './datasets/caltech';
split_name   = 'test';

switch usage
    case {'test'}
        imdb = imdb_from_caltech(caltech_root, split_name, false);
        dataset.imdb_test  = imdb;
        dataset.roidb_test = imdb.roidb_func(imdb, false);
    case {'train'}
        imdb = imdb_from_caltech(caltech_root, split_name, true);
        % training entries are stored as cell arrays
        dataset.imdb_train  = { imdb };
        dataset.roidb_train = { imdb.roidb_func(imdb, true) };
    otherwise
        error('usage = ''train'' or ''test''');
end

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
function dataset = caltech_trainval(dataset, usage)
% dataset = caltech_trainval(dataset, usage)
%   Attaches the Caltech 'train' split to DATASET.
%
%   usage = 'train' : sets dataset.imdb_train / dataset.roidb_train, each a
%                     one-element cell array.
%   usage = 'test'  : sets dataset.imdb_test / dataset.roidb_test (plain
%                     values, not cells).
%   In both cases the boolean flag passed to imdb_from_caltech and
%   roidb_func is false (presumably "use flipped images" -- confirm in
%   imdb_from_caltech). Any other usage raises an error.

caltech_root = './datasets/caltech';
split_name   = 'train';

switch usage
    case {'test'}
        imdb = imdb_from_caltech(caltech_root, split_name, false);
        dataset.imdb_test  = imdb;
        dataset.roidb_test = imdb.roidb_func(imdb, false);
    case {'train'}
        imdb = imdb_from_caltech(caltech_root, split_name, false);
        % training entries are stored as cell arrays
        dataset.imdb_train  = { imdb };
        dataset.roidb_train = { imdb.roidb_func(imdb, false) };
    otherwise
        error('usage = ''train'' or ''test''');
end

end
100 changes: 100 additions & 0 deletions
100
experiments/+Faster_RCNN_Train/do_proposal_test_caltech.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
function aboxes = do_proposal_test_caltech(conf, model_stage, imdb, roidb, cache_name, method_name)
% aboxes = do_proposal_test_caltech(conf, model_stage, imdb, roidb, cache_name, method_name)
%   Runs the trained RPN over IMDB, filters the proposals (top-N / NMS),
%   prints the ground-truth recall at overlap >= 0.5, writes the proposals
%   as text files for the Caltech evaluation code and launches that
%   evaluation.
%
%   conf         proposal config; conf.use_gpu and conf.exp_name are read here
%   model_stage  struct with test_net_def_file, output_model_file,
%                cache_name and nms.{per_nms_topN, nms_overlap_thres,
%                after_nms_topN}
%   imdb, roidb  test image database and its annotations; roidb.rois(i)
%                must carry .boxes and .ignores
%   cache_name   sub-folder of output/<exp_name>/rpn_cachedir for results
%   method_name  name of the result folder (also used under
%                external/code3.2.1/data-USA/res)
%
%   Returns the filtered proposals, one cell per image, in the original
%   [x1 y1 x2 y2 ...] layout.

    aboxes = proposal_test_caltech(conf, imdb, ...
                                   'net_def_file',     model_stage.test_net_def_file, ...
                                   'net_file',         model_stage.output_model_file, ...
                                   'cache_name',       model_stage.cache_name);

    fprintf('Doing nms ... ');
    aboxes = boxes_filter(aboxes, model_stage.nms.per_nms_topN, model_stage.nms.nms_overlap_thres, model_stage.nms.after_nms_topN, conf.use_gpu);

    % eval the gt recall
    gt_num = 0;
    gt_re_num = 0;
    for i = 1:length(roidb.rois)
        % only non-ignored ground truths count
        gts = roidb.rois(i).boxes(roidb.rois(i).ignores~=1, :);
        if isempty(gts)
            continue;
        end
        gt_num = gt_num + size(gts, 1);
        if isempty(aboxes{i})
            % no proposals survived for this image: every gt is a miss
            continue;
        end
        rois = aboxes{i}(:, 1:4);
        % Assumes boxoverlap returns [num_rois x num_gts] -- TODO confirm.
        % Reduce explicitly along dim 1 so that a single proposal row is
        % not collapsed to one scalar by max's default dimension choice.
        max_ols = max(boxoverlap(rois, gts), [], 1);
        gt_re_num = gt_re_num + sum(max_ols >= 0.5);
    end
    fprintf('gt recall rate = %.4f\n', gt_re_num / gt_num);

    fprintf('Preparing the results for Caltech evaluation ...');
    cache_dir = fullfile(pwd, 'output', conf.exp_name, 'rpn_cachedir', cache_name);
    res_dir = fullfile(cache_dir, method_name);
    res_boxes = aboxes;
    mkdir_if_missing(res_dir);
    % remove all the former results (files are opened in append mode below)
    DIRS = dir(res_dir);
    for i = 1:length(DIRS)
        if (DIRS(i).isdir && ~strcmp(DIRS(i).name,'.') && ~strcmp(DIRS(i).name,'..') )
            rmdir(fullfile(res_dir, DIRS(i).name), 's');
        end
    end

    assert(length(imdb.image_ids) == size(res_boxes, 1));
    for i = 1:size(res_boxes, 1)
        if ~isempty(res_boxes{i})
            % image id is split on '_' into three parts: parts 1 and 2 name
            % the output folder and file, part 3 carries the frame index
            sstr = strsplit(imdb.image_ids{i}, '_');
            mkdir_if_missing(fullfile(res_dir, sstr{1}));
            out_file = fullfile(res_dir, sstr{1}, [sstr{2} '.txt']);
            fid = fopen(out_file, 'a');
            if fid < 0
                error('do_proposal_test_caltech:fopen', 'cannot open %s for writing', out_file);
            end
            % transform [x1 y1 x2 y2] to [x1 y1 x2-x1 y2-y1]
            res_boxes{i}(:, 3) = res_boxes{i}(:, 3) - res_boxes{i}(:, 1);
            res_boxes{i}(:, 4) = res_boxes{i}(:, 4) - res_boxes{i}(:, 2);
            % frame number: drop the first character of part 3, then +1
            frame_id = str2double(sstr{3}(2:end)) + 1;
            for j = 1:size(res_boxes{i}, 1)
                fprintf(fid, '%d,%f,%f,%f,%f,%f\n', frame_id, res_boxes{i}(j, :));
            end
            fclose(fid);
        end
    end
    fprintf('Done.');

    % copy results to eval folder and eval to get figure.
    eval_root = fullfile(pwd, 'external', 'code3.2.1');
    copyfile(res_dir, fullfile(eval_root, 'data-USA', 'res', method_name));
    tmp_dir = pwd;
    % restore the working directory even if the evaluation throws
    restore_dir = onCleanup(@() cd(tmp_dir)); %#ok<NASGU>
    cd(eval_root);
    dbEval_RPNBF;
end
|
||
function aboxes = boxes_filter(aboxes, per_nms_topN, nms_overlap_thres, after_nms_topN, use_gpu)
% aboxes = boxes_filter(aboxes, per_nms_topN, nms_overlap_thres, after_nms_topN, use_gpu)
%   Filters per-image proposal arrays (one cell per image, one box per row):
%     1. keep the first per_nms_topN rows       (skipped if per_nms_topN <= 0)
%     2. apply nms with nms_overlap_thres       (only if 0 < thres < 1)
%     3. keep the first after_nms_topN rows     (skipped if after_nms_topN <= 0)
%   Rows are taken in their stored order; this function does not sort them.
%   With use_gpu the nms calls run in a plain loop with progress output,
%   otherwise they run in a parfor loop.

    % first-n-rows truncation shared by the pre- and post-nms steps
    keep_top = @(boxes, n) boxes(1:min(size(boxes, 1), n), :);

    % to speed up nms
    if per_nms_topN > 0
        aboxes = cellfun(@(x) keep_top(x, per_nms_topN), aboxes, 'UniformOutput', false);
    end

    % do nms
    if nms_overlap_thres > 0 && nms_overlap_thres < 1
        if use_gpu
            for i = 1:length(aboxes)
                tic_toc_print('nms: %d / %d \n', i, length(aboxes));
                aboxes{i} = aboxes{i}(nms(aboxes{i}, nms_overlap_thres, use_gpu), :);
            end
        else
            parfor i = 1:length(aboxes)
                aboxes{i} = aboxes{i}(nms(aboxes{i}, nms_overlap_thres), :);
            end
        end
    end

    aver_boxes_num = mean(cellfun(@(x) size(x, 1), aboxes, 'UniformOutput', true));
    fprintf('aver_boxes_num = %d, select top %d\n', round(aver_boxes_num), after_nms_topN);
    if after_nms_topN > 0
        aboxes = cellfun(@(x) keep_top(x, after_nms_topN), aboxes, 'UniformOutput', false);
    end
end
% | ||
% function regions = make_roidb_regions(aboxes, images) | ||
% regions.boxes = aboxes; | ||
% regions.images = images; | ||
% end |
15 changes: 15 additions & 0 deletions
15
experiments/+Faster_RCNN_Train/do_proposal_train_caltech.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
function model_stage = do_proposal_train_caltech(conf, dataset, model_stage, do_val)
% model_stage = do_proposal_train_caltech(conf, dataset, model_stage, do_val)
%   Trains the proposal network via proposal_train_caltech and stores its
%   return value in model_stage.output_model_file.
%
%   conf         proposal config; conf.exp_name is forwarded as 'exp_name'
%   dataset      struct with imdb_train / roidb_train; imdb_test /
%                roidb_test are forwarded as validation data when do_val
%                is set
%   model_stage  struct with solver_def_file, init_net_file and cache_name
%   do_val       when false, empty structs are passed as validation data

    if ~do_val
        imdb_val  = struct();
        roidb_val = struct();
    else
        imdb_val  = dataset.imdb_test;
        roidb_val = dataset.roidb_test;
    end

    train_opts = { ...
        'do_val',           do_val, ...
        'imdb_val',         imdb_val, ...
        'roidb_val',        roidb_val, ...
        'solver_def_file',  model_stage.solver_def_file, ...
        'net_file',         model_stage.init_net_file, ...
        'cache_name',       model_stage.cache_name, ...
        'exp_name',         conf.exp_name};

    model_stage.output_model_file = proposal_train_caltech( ...
        conf, dataset.imdb_train, dataset.roidb_train, train_opts{:});
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
function model = set_cache_folder_caltech(cache_base_proposal, model)
% model = set_cache_folder_caltech(cache_base_proposal, model)
%   Sets model.stage1_rpn.cache_name to cache_base_proposal with the
%   suffix '_stage1_rpn' appended; all other fields of MODEL are kept.
% --------------------------------------------------------
% RPN_BF
% Copyright (c) 2016, Liliang Zhang
% Licensed under The MIT License [see LICENSE for details]
% --------------------------------------------------------

    stage_suffix = '_stage1_rpn';
    model.stage1_rpn.cache_name = horzcat(cache_base_proposal, stage_suffix);

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
function model = VGG16_for_rpn_pedestrian_caltech(exp_name, model)
% model = VGG16_for_rpn_pedestrian_caltech(exp_name, model)
%   Fills MODEL with the file locations and test-time NMS settings of the
%   VGG16 stage-one RPN. All paths are rooted at
%   <pwd>/models/<exp_name>. MODEL may be omitted by the caller; it is only
%   ever assigned into, never read.

model_root   = fullfile(pwd, 'models', exp_name);
vgg_dir      = fullfile(model_root, 'pre_trained_models', 'vgg_16layers');
prototxt_dir = fullfile(model_root, 'rpn_prototxts', 'vgg_16layers_conv3_1');

model.mean_image           = fullfile(vgg_dir, 'mean_image');
model.pre_trained_net_file = fullfile(vgg_dir, 'vgg16.caffemodel');
% Stride in input image pixels at the last conv layer
model.feat_stride          = 16;

%% stage 1 rpn, inited from pre-trained network
model.stage1_rpn.solver_def_file   = fullfile(prototxt_dir, 'solver_60k80k.prototxt');
model.stage1_rpn.test_net_def_file = fullfile(prototxt_dir, 'test.prototxt');
model.stage1_rpn.init_net_file     = model.pre_trained_net_file;

% rpn test setting: top-N before nms, nms overlap threshold, top-N after nms
model.stage1_rpn.nms.per_nms_topN      = 10000;
model.stage1_rpn.nms.nms_overlap_thres = 0.5;
model.stage1_rpn.nms.after_nms_topN    = 40;
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
function script_rpn_pedestrian_VGG16_caltech()
% script_rpn_pedestrian_VGG16_caltech()
%   End-to-end driver: selects a GPU and activates the caffe mex, builds
%   the VGG16 model description and the Caltech datasets, trains the
%   stage-one RPN and then runs proposal testing / evaluation on the test
%   set.
% --------------------------------------------------------
% RPN_BF
% Copyright (c) 2016, Liliang Zhang
% Licensed under The MIT License [see LICENSE for details]
% --------------------------------------------------------

clc;
clear mex;
clear is_valid_handle; % to clear init_key
run(fullfile(fileparts(fileparts(mfilename('fullpath'))), 'startup'));
%% -------------------- CONFIG --------------------
opts.caffe_version          = 'caffe_faster_rcnn';
opts.gpu_id                 = auto_select_gpu;
active_caffe_mex(opts.gpu_id, opts.caffe_version);

exp_name = 'VGG16_caltech';

% do validation, or not
opts.do_val                 = true;
% model (the second argument of the model function is intentionally
% omitted; it is created by assignment inside that function)
model                       = Model.VGG16_for_rpn_pedestrian_caltech(exp_name);
% cache base
cache_base_proposal         = 'rpn_caltech_vgg_16layers';
% train/test data
dataset                     = [];
% use_flipped                 = true;
% dataset                     = Dataset.caltech_trainval(dataset, 'train', use_flipped);
dataset                     = Dataset.caltech_trainval(dataset, 'train');
% dataset                     = Dataset.caltech_test(dataset, 'test', false);
dataset                     = Dataset.caltech_test(dataset, 'test');

% %% -------------------- TRAIN --------------------
% conf: feat_stride comes from the model description, so it is not
% re-assigned from a second hard-coded constant later on
conf_proposal               = proposal_config_caltech('image_means', model.mean_image, 'feat_stride', model.feat_stride);
% set cache folder for each stage
model                       = Faster_RCNN_Train.set_cache_folder_caltech(cache_base_proposal, model);

% generate anchors and pre-calculate output size of rpn network
conf_proposal.exp_name = exp_name;
[conf_proposal.anchors, conf_proposal.output_width_map, conf_proposal.output_height_map] ...
                            = proposal_prepare_anchors(conf_proposal, model.stage1_rpn.cache_name, model.stage1_rpn.test_net_def_file);

%%  train
fprintf('\n***************\nstage one RPN \n***************\n');
model.stage1_rpn            = Faster_RCNN_Train.do_proposal_train_caltech(conf_proposal, dataset, model.stage1_rpn, opts.do_val);

%% test
cache_name = 'caltech';
method_name = 'RPN-ped';
Faster_RCNN_Train.do_proposal_test_caltech(conf_proposal, model.stage1_rpn, dataset.imdb_test, dataset.roidb_test, cache_name, method_name);

end
|
||
function [anchors, output_width_map, output_height_map] = proposal_prepare_anchors(conf, cache_name, test_net_def_file)
% [anchors, output_width_map, output_height_map] = proposal_prepare_anchors(conf, cache_name, test_net_def_file)
%   Computes the RPN output-size maps for the test network and generates
%   the anchors: nine scales growing geometrically from 2.6 by a factor of
%   1.3, and a single ratio of 1/0.41.

    % the output-size maps are computed first, then the anchors
    [output_width_map, output_height_map] = ...
        proposal_calc_output_size_caltech(conf, test_net_def_file);

    anchor_scales = 2.6*(1.3.^(0:8));
    anchor_ratios = 1 / 0.41;
    anchors = proposal_generate_anchors_caltech(cache_name, ...
        'scales',   anchor_scales, ...
        'ratios',   anchor_ratios, ...
        'exp_name', conf.exp_name);
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
function [output_width_map, output_height_map] = proposal_calc_output_size_caltech(conf, test_net_def_file)
% [output_width_map, output_height_map] = proposal_calc_output_size_caltech(conf, test_net_def_file)
%   Measures the size of the 'proposal_cls_score' blob for square, all-zero
%   inputs whose side lengths are conf.max_size and conf.scales, and
%   returns two containers.Map objects keyed by side length: one built
%   from dimension 1 of the blob (width map), one from dimension 2
%   (height map).
% --------------------------------------------------------
% RPN_BF
% Copyright (c) 2016, Liliang
% Licensed under The MIT License [see LICENSE for details]
% --------------------------------------------------------

%     caffe.init_log(fullfile(pwd, 'caffe_log'));
    caffe_net = caffe.Net(test_net_def_file, 'test');

    % set gpu/cpu
    if conf.use_gpu
        caffe.set_mode_gpu();
    else
        caffe.set_mode_cpu();
    end

    % Only the configured sizes are probed; the earlier full-range probe
    % (conf.scales:conf.max_size) is disabled.
    side_lengths = [conf.max_size conf.scales];

    output_w = nan(size(side_lengths));
    output_h = nan(size(side_lengths));
    for k = 1:length(side_lengths)
        side = side_lengths(k);
        % single-precision all-zero blob of size side x side x 3 x 1
        net_inputs = {zeros(side, side, 3, 1, 'single')};

        % Reshape net's input blobs
        caffe_net.reshape_as_input(net_inputs);
        caffe_net.forward(net_inputs);

        cls_score = caffe_net.blobs('proposal_cls_score').get_data();
        output_w(k) = size(cls_score, 1);
        output_h(k) = size(cls_score, 2);
    end

    output_width_map = containers.Map(side_lengths, output_w);
    output_height_map = containers.Map(side_lengths, output_h);

    caffe.reset_all();
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
function conf = proposal_config_caltech(varargin)
% conf = proposal_config_caltech(varargin)
%   Builds the proposal (RPN) configuration struct from name/value pairs;
%   unspecified options take the defaults listed below.
%
%   Errors if the training and testing scales / max sizes differ, or if
%   'image_means' names a file that does not hold exactly one variable.
% --------------------------------------------------------
% RPN_BF
% Copyright (c) 2016, Liliang Zhang
% Licensed under The MIT License [see LICENSE for details]
% --------------------------------------------------------

    ip = inputParser;

    %% training
    ip.addParamValue('use_gpu',         gpuDeviceCount > 0, ...
                                                        @islogical);

%     % whether drop the anchors that has edges outside of the image boundary
%     ip.addParamValue('drop_boxes_runoff_image', ...
%                                         true,           @islogical);

    ip.addParamValue('drop_fg_boxes_runoff_image', ...
                                        true,           @islogical);

    % Image scales -- the short edge of input image
    ip.addParamValue('scales',          720,            @ismatrix);
    % Max pixel size of a scaled input image
    ip.addParamValue('max_size',        960,            @isscalar);
    % Images per batch, only supports ims_per_batch = 1 currently
    ip.addParamValue('ims_per_batch',   1,              @isscalar);
    % Minibatch size
    ip.addParamValue('batch_size',      128,            @isscalar);
    % Fraction of minibatch that is foreground labeled (class > 0)
    ip.addParamValue('fg_fraction',     1/6,           @isscalar);
    % weight of background samples, when weight of foreground samples is
    % 1.0
    ip.addParamValue('bg_weight',       1.0,            @isscalar);
    % Overlap threshold for a ROI to be considered foreground (if >= fg_thresh)
    ip.addParamValue('fg_thresh',       0.5,            @isscalar);
    % Overlap threshold for a ROI to be considered background (class = 0 if
    % overlap in [bg_thresh_lo, bg_thresh_hi))
    ip.addParamValue('bg_thresh_hi',    0.5,            @isscalar);
    ip.addParamValue('bg_thresh_lo',    0,              @isscalar);
    % mean image, in RGB order
    ip.addParamValue('image_means',     256,            @ismatrix);
    % Use horizontally-flipped images during training?
    ip.addParamValue('use_flipped',     false,           @islogical);
    % Stride in input image pixels at ROI pooling level (network specific)
    % 16 is true for {Alex,Caffe}Net, VGG_CNN_M_1024, and VGG16
    ip.addParamValue('feat_stride',     16,             @isscalar);
    % train proposal target only to labeled ground-truths or also include
    % other proposal results (selective search, etc.)
    ip.addParamValue('target_only_gt',  true,           @islogical);

    % random seed
    ip.addParamValue('rng_seed',        6,              @isscalar);

    %% testing
    ip.addParamValue('test_scales',     720,            @isscalar);
    ip.addParamValue('test_max_size',   960,            @isscalar);
    ip.addParamValue('test_nms',        0.5,            @isscalar);
    ip.addParamValue('test_binary',     false,          @islogical);
    ip.addParamValue('test_min_box_size',16,            @isscalar);
    ip.addParamValue('test_min_box_height',50,            @isscalar);
    ip.addParamValue('test_drop_boxes_runoff_image', ...
                                        false,          @islogical);

    ip.parse(varargin{:});
    conf = ip.Results;

    %assert(conf.ims_per_batch == 1, 'currently rpn only supports ims_per_batch == 1');

    % 'scales' is validated with @ismatrix and may therefore be non-scalar;
    % isequal always yields a scalar logical, whereas == would hand assert
    % an array condition.
    assert(isequal(conf.scales, conf.test_scales), ...
        'proposal_config_caltech: scales must equal test_scales');
    assert(isequal(conf.max_size, conf.test_max_size), ...
        'proposal_config_caltech: max_size must equal test_max_size');

    % if image_means is a file, load it
    if ischar(conf.image_means)
        s = load(conf.image_means);
        s_fieldnames = fieldnames(s);
        assert(length(s_fieldnames) == 1, ...
            'proposal_config_caltech: image_means file must contain exactly one variable');
        conf.image_means = s.(s_fieldnames{1});
    end
end
Oops, something went wrong.