Skip to content

Commit

Permalink
examples: imagenet: fix statistics computation
Browse files Browse the repository at this point in the history
  • Loading branch information
vedaldi committed Jul 5, 2016
1 parent f78b5e4 commit 998f9d9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
5 changes: 4 additions & 1 deletion examples/imagenet/cnn_imagenet.m
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@
else
train = find(imdb.images.set == 1) ;
images = fullfile(imdb.imageDir, imdb.images.name(train(1:100:end))) ;
[averageImage, rgbMean, rgbCovariance] = getImageStats(images, 'imageSize', [256 256]) ;
[averageImage, rgbMean, rgbCovariance] = getImageStats(images, ...
'imageSize', [256 256], ...
'numThreads', opts.numFetchThreads, ...
'gpus', opts.train.gpus) ;
save(imageStatsPath, 'averageImage', 'rgbMean', 'rgbCovariance') ;
end
[v,d] = eig(rgbCovariance) ;
Expand Down
17 changes: 9 additions & 8 deletions examples/imagenet/getImageStats.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
numGpus = numel(opts.gpus) ;
if numGpus > 0
fprintf('%s: resetting GPU device\n', mfilename) ;
clear mex ;
gpuDevice(opts.gpus(1))
end

Expand All @@ -27,23 +28,23 @@
'imageSize', opts.imageSize, ...
'useGpu', numGpus > 0) ;

z = reshape(permute(data,[3 1 2 4]),3,[]) ;
n = size(z,2) ;
z = reshape(shiftdim(data,2),3,[]) ;
rgbm1{end+1} = mean(z,2) ;
rgbm2{end+1} = z*z'/size(z,2) ;
avg{end+1} = mean(data, 4) ;
rgbm1{end+1} = sum(z,2)/n ;
rgbm2{end+1} = z*z'/n ;
time = toc(time) ;
fprintf(' %.1f Hz\n', numel(batch) / time) ;
end

averageImage = gather(mean(cat(4,avg{:}),4)) ;
rgbm1 = mean(cat(2,rgbm1{:}),2) ;
rgbm2 = mean(cat(3,rgbm2{:}),3) ;
rgbMean = gather(rgbm1) ;
rgbCovariance = gather(rgbm2 - rgbm1*rgbm1') ;
rgbm1 = gather(mean(cat(2,rgbm1{:}),2)) ;
rgbm2 = gather(mean(cat(3,rgbm2{:}),3)) ;
rgbMean = rgbm1 ;
rgbCovariance = rgbm2 - rgbm1*rgbm1' ;

if numGpus > 0
fprintf('%s: finished with GPU device, resetting again\n', mfilename) ;
clear mex ;
gpuDevice(opts.gpus(1)) ;
end
fprintf('%s: all done\n', mfilename) ;

0 comments on commit 998f9d9

Please sign in to comment.