Skip to content

Commit

Permalink
a few clarification
Browse files Browse the repository at this point in the history
  • Loading branch information
jwyang committed Jan 11, 2015
1 parent bb4b93d commit b028561
Showing 1 changed file with 86 additions and 84 deletions.
170 changes: 86 additions & 84 deletions src/train_model.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
function LBFRegModel = train_model(dbnames)
%TRAIN_MODEL Summary of this function goes here
% Function: train face alignment model
% Detailed explanation goes here
Expand All @@ -15,23 +14,32 @@
end

if sum(strcmp(dbnames, 'COFW')) > 0
load('..\initial_shape\InitialShape_29.mat');
load('../initial_shape/InitialShape_29.mat');
params.meanshape = S0;
else
load('..\initial_shape\InitialShape_68.mat');
params.meanshape = S0(params.ind_usedpts, :);
load('../initial_shape/InitialShape_68.mat');
params.meanshape = S0;
end


if params.isparallel
if isempty(gcp('nocreate')) %判断并行计算环境是否已然启动
parpool(4); %若尚未启动,则启动并行环境
disp('Attention, if error occurs, plese ensure you used the correct version of parallel initialization.');

if isempty(gcp('nocreate'))
parpool(4);
else
disp('Already initialized'); %说明并行环境已经启动。
disp('Already initialized');
end

%{
if matlabpool('size') <= 0
matlabpool('open','local',4);
else
disp('Already initialized');
end
%}
end


% load trainning data from hardware
Tr_Data = [];
% Tr_Bboxes = [];
Expand Down Expand Up @@ -61,18 +69,25 @@
Data_flip{i}.bbox_gt = Data{i}.bbox_gt;
Data_flip{i}.bbox_gt(1) = Data_flip{i}.width - Data_flip{i}.bbox_gt(1) - Data_flip{i}.bbox_gt(3);

% Data_flip{i}.bbox_facedet = Data{i}.bbox_facedet;
% Data_flip{i}.bbox_facedet(1) = Data_flip{i}.width - Data_flip{i}.bbox_facedet(1) - Data_flip{i}.bbox_facedet(3);
Data_flip{i}.bbox_facedet = Data{i}.bbox_facedet;
Data_flip{i}.bbox_facedet(1) = Data_flip{i}.width - Data_flip{i}.bbox_facedet(1) - Data_flip{i}.bbox_facedet(3);
end
Data = [Data; Data_flip];
end

% choose corresponding points for training
for i = 1:length(Data)
Data{i}.shape_gt = Data{i}.shape_gt(params.ind_usedpts, :);
Data{i}.shape_gt = Data{i}.shape_gt(Param.ind_usedpts, :);
Data{i}.bbox_gt = getbbox(Data{i}.shape_gt);

% modify detection boxes
shape_facedet = resetshape(Data{i}.bbox_facedet, Param.meanshape);
shape_facedet = shape_facedet(Param.ind_usedpts, :);
Data{i}.bbox_facedet = getbbox(shape_facedet);
end

Param.meanshape = S0(Param.ind_usedpts, :);

dbsize = length(Data);

% load('Ts_bbox.mat');
Expand All @@ -86,7 +101,7 @@

indice_rotate = ceil(dbsize*rand(1, augnumber));
indice_shift = ceil(dbsize*rand(1, augnumber));
scales = 1 + 0.3*(rand([1 augnumber]) - 0.5);
scales = 1 + 0.2*(rand([1 augnumber]) - 0.5);

Data{i}.intermediate_shapes = cell(1, Param.max_numstage);
Data{i}.intermediate_bboxes = cell(1, Param.max_numstage);
Expand All @@ -97,70 +112,63 @@
Data{i}.shapes_residual = zeros([size(Param.meanshape), augnumber]);
Data{i}.tf2meanshape = cell(augnumber, 1);
Data{i}.meanshape2tf = cell(augnumber, 1);



% if Data{i}.isdet == 1
% Data{i}.bbox_facedet = Data{i}.bbox_facedet*ts_bbox;
% end



for s = 1:Param.augnumber_shift+1
for r = 1:Param.augnumber_rotate+1
for e = 1:Param.augnumber_scale+1
sr = (s-1)*(Param.augnumber_rotate+1)*(Param.augnumber_scale+1) + (r-1)*((Param.augnumber_scale+1)) + e;
if s == 1 && r == 1 && e == 1% initialize as meanshape
% estimate the similarity transformation from initial shape to mean shape
Data{i}.intermediate_shapes{1}(:,:, sr) = resetshape(Data{i}.bbox_gt, Param.meanshape);
Data{i}.intermediate_bboxes{1}(sr, :) = Data{i}.bbox_gt;

meanshape_resize = resetshape(Data{i}.intermediate_bboxes{1}(sr, :), Param.meanshape);


Data{i}.tf2meanshape{1} = cp2tform(bsxfun(@minus, Data{i}.intermediate_shapes{1}(1:end,:, 1), mean(Data{i}.intermediate_shapes{1}(1:end,:, 1))), ...
bsxfun(@minus, meanshape_resize(1:end, :), mean(meanshape_resize(1:end, :))), 'nonreflective similarity');
Data{i}.meanshape2tf{1} = cp2tform(bsxfun(@minus, meanshape_resize(1:end, :), mean(meanshape_resize(1:end, :))), ...
bsxfun(@minus, Data{i}.intermediate_shapes{1}(1:end,:, 1), mean(Data{i}.intermediate_shapes{1}(1:end,:, 1))), 'nonreflective similarity');


% calculate the residual shape from initial shape to groundtruth shape under normalization scale
shape_residual = bsxfun(@rdivide, Data{i}.shape_gt - Data{i}.intermediate_shapes{1}(:,:, 1), [Data{i}.intermediate_bboxes{1}(1, 3) Data{i}.intermediate_bboxes{1}(1, 4)]);
% transform the shape residual in the image coordinate to the mean shape coordinate
Data{i}.shapes_residual(:, :, 1) = tformfwd(Data{i}.tf2meanshape{1}, shape_residual(:, 1), shape_residual(:, 2));
else % randomly shift and rotate the meanshape (or groundtruth of other ssubjects)
% randomly rotate the shape

shape = resetshape(Data{i}.bbox_gt, Param.meanshape); % Data{indice_rotate(sr)}.shape_gt

% shape = scaleshape(shape, scales(sr));

% shape = rotateshape(shape);

% randomly shift the shape
shape = translateshape(shape, Data{indice_shift(sr)}.shape_gt);

Data{i}.intermediate_shapes{1}(:, :, sr) = shape;
Data{i}.intermediate_bboxes{1}(sr, :) = getbbox(shape);

meanshape_resize = resetshape(Data{i}.intermediate_bboxes{1}(sr, :), Param.meanshape);


Data{i}.tf2meanshape{sr} = cp2tform(bsxfun(@minus, Data{i}.intermediate_shapes{1}(1:end,:, sr), mean(Data{i}.intermediate_shapes{1}(1:end,:, sr))), ...
bsxfun(@minus, meanshape_resize(1:end, :), mean(meanshape_resize(1:end, :))), 'nonreflective similarity');
Data{i}.meanshape2tf{sr} = cp2tform(bsxfun(@minus, meanshape_resize(1:end, :), mean(meanshape_resize(1:end, :))), ...
bsxfun(@minus, Data{i}.intermediate_shapes{1}(1:end,:, sr), mean(Data{i}.intermediate_shapes{1}(1:end,:, sr))), 'nonreflective similarity');


shape_residual = bsxfun(@rdivide, Data{i}.shape_gt - Data{i}.intermediate_shapes{1}(:,:, sr), [Data{i}.intermediate_bboxes{1}(sr, 3) Data{i}.intermediate_bboxes{1}(sr, 4)]);

Data{i}.shapes_residual(:, :, sr) = tformfwd(Data{i}.tf2meanshape{sr}, shape_residual(:, 1), shape_residual(:, 2));
for sr = 1:params.augnumber
if sr == 1
% estimate the similarity transformation from initial shape to mean shape
% Data{i}.intermediate_shapes{1}(:,:, sr) = resetshape(Data{i}.bbox_gt, Param.meanshape);
% Data{i}.intermediate_bboxes{1}(sr, :) = Data{i}.bbox_gt;
Data{i}.intermediate_shapes{1}(:,:, sr) = resetshape(Data{i}.bbox_facedet, Param.meanshape);
Data{i}.intermediate_bboxes{1}(sr, :) = Data{i}.bbox_facedet;

meanshape_resize = resetshape(Data{i}.intermediate_bboxes{1}(sr, :), Param.meanshape);

Data{i}.tf2meanshape{1} = fitgeotrans(bsxfun(@minus, Data{i}.intermediate_shapes{1}(1:end,:, 1), mean(Data{i}.intermediate_shapes{1}(1:end,:, 1))), ...
(bsxfun(@minus, meanshape_resize(1:end, :), mean(meanshape_resize(1:end, :)))), 'NonreflectiveSimilarity');
Data{i}.meanshape2tf{1} = fitgeotrans((bsxfun(@minus, meanshape_resize(1:end, :), mean(meanshape_resize(1:end, :)))), ...
bsxfun(@minus, Data{i}.intermediate_shapes{1}(1:end,:, 1), mean(Data{i}.intermediate_shapes{1}(1:end,:, 1))), 'NonreflectiveSimilarity');

% calculate the residual shape from initial shape to groundtruth shape under normalization scale
shape_residual = bsxfun(@rdivide, Data{i}.shape_gt - Data{i}.intermediate_shapes{1}(:,:, 1), [Data{i}.intermediate_bboxes{1}(1, 3) Data{i}.intermediate_bboxes{1}(1, 4)]);
% transform the shape residual in the image coordinate to the mean shape coordinate
[u, v] = transformPointsForward(Data{i}.tf2meanshape{1}, shape_residual(:, 1)', shape_residual(:, 2)');
Data{i}.shapes_residual(:, 1, 1) = u';
Data{i}.shapes_residual(:, 2, 1) = v';
else
% randomly rotate the shape
% shape = resetshape(Data{i}.bbox_gt, Param.meanshape); % Data{indice_rotate(sr)}.shape_gt
shape = resetshape(Data{i}.bbox_facedet, Param.meanshape); % Data{indice_rotate(sr)}.shape_gt

if params.augnumber_scale ~= 0
shape = scaleshape(shape, scales(sr));
end

%{
drawshapes(Data{i}.img_gray, [Data{i}.shape_gt Data{i}.intermediate_shapes{1}(:, :, sr)]);
hold off;
%}
if params.augnumber_rotate ~= 0
shape = rotateshape(shape);
end

if params.augnumber_shift ~= 0
shape = translateshape(shape, Data{indice_shift(sr)}.shape_gt);
end

Data{i}.intermediate_shapes{1}(:, :, sr) = shape;
Data{i}.intermediate_bboxes{1}(sr, :) = getbbox(shape);

meanshape_resize = resetshape(Data{i}.intermediate_bboxes{1}(sr, :), Param.meanshape);

Data{i}.tf2meanshape{sr} = fitgeotrans(bsxfun(@minus, Data{i}.intermediate_shapes{1}(1:end,:, sr), mean(Data{i}.intermediate_shapes{1}(1:end,:, sr))), ...
bsxfun(@minus, meanshape_resize(1:end, :), mean(meanshape_resize(1:end, :))), 'NonreflectiveSimilarity');
Data{i}.meanshape2tf{sr} = fitgeotrans(bsxfun(@minus, meanshape_resize(1:end, :), mean(meanshape_resize(1:end, :))), ...
bsxfun(@minus, Data{i}.intermediate_shapes{1}(1:end,:, sr), mean(Data{i}.intermediate_shapes{1}(1:end,:, sr))), 'NonreflectiveSimilarity');

shape_residual = bsxfun(@rdivide, Data{i}.shape_gt - Data{i}.intermediate_shapes{1}(:,:, sr), [Data{i}.intermediate_bboxes{1}(sr, 3) Data{i}.intermediate_bboxes{1}(sr, 4)]);
[u, v] = transformPointsForward(Data{i}.tf2meanshape{1}, shape_residual(:, 1)', shape_residual(:, 2)');
Data{i}.shapes_residual(:, 1, sr) = u';
Data{i}.shapes_residual(:, 2, sr) = v';
% Data{i}.shapes_residual(:, :, sr) = tformfwd(Data{i}.tf2meanshape{sr}, shape_residual(:, 1), shape_residual(:, 2));
end
end
end
Expand All @@ -173,15 +181,15 @@
if nargin > 2
n = size(LBFRegModel_initial.ranf, 1);
for i = 1:n
randf(1:n, :) = LBFRegModel_initial.ranf;
randf(1:n, :) = LBFRegModel_initial.ranf;
end
end
%}

for s = 1:Param.max_numstage
% learn random forest for s-th stage
disp('train random forests for landmarks...');

%{
if isempty(randf{s})
if exist(strcat('randfs\randf', num2str(s), '.mat'))
Expand All @@ -200,15 +208,11 @@

% derive binary codes given learned random forest in current stage
disp('extract local binary features...');

if exist(strcat('LBFeats\LBFeats', num2str(s), '.mat'))
load(strcat('LBFeats\LBFeats', num2str(s), '.mat'));
else
tic;
binfeatures = derivebinaryfeat(randf{s}, Data, Param, s);
% save(strcat('LBFeats\LBFeats', num2str(s), '.mat'), 'binfeatures', '-v7.3');
toc;
end

tic;
binfeatures = derivebinaryfeat(randf{s}, Data, Param, s);
% save(strcat('LBFeats\LBFeats', num2str(s), '.mat'), 'binfeatures', '-v7.3');
toc;

% learn global linear regrassion given binary feature
disp('learn global regressors...');
Expand All @@ -221,6 +225,4 @@
end

LBFRegModel.ranf = randf;
LBFRegModel.Ws = Ws;

end
LBFRegModel.Ws = Ws;

0 comments on commit b028561

Please sign in to comment.