Skip to content

Commit

Permalink
[matlab] support feature extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
mli committed Dec 25, 2015
1 parent bbcfeba commit 737e117
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 72 deletions.
163 changes: 107 additions & 56 deletions matlab/+mxnet/model.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
predictor
% the previous input size
prev_input_size
% the previous device id, -1 means cpu
% the previous device id
prev_dev_id
% the previous device type (cpu or gpu)
prev_dev_type
% the previous output layers
prev_out_layers
end

methods
Expand All @@ -26,6 +30,7 @@
obj.prev_input_size = zeros(1,4);
obj.verbose = 1;
obj.prev_dev_id = -1;
obj.prev_dev_type = -1;
end

function delete(obj)
Expand Down Expand Up @@ -64,21 +69,11 @@ function load(obj, model_prefix, num_epoch)
end


function outputs = forward(obj, imgs, varargin)
function outputs = forward(obj, input, varargin)
%FORWARD perform forward
%
% OUT = MODEL.FORWARD(imgs) returns the forward (prediction) outputs of a list
% of images, where imgs can be either a single image with the format
%\
% width x height x channel
%
% which is return format of `imread` or a list of images with format
%
% width x height x channel x num_images
%
% MODEL.FORWARD(imgs, 'gpu', [0, 1]) uses GPU 0 and 1 for prediction
%
% MODEL.FORWARD(imgs, {'conv4', 'conv5'}) extract outputs for two internal layers
% OUT = MODEL.FORWARD(input) returns the forward (prediction) outputs of a list
% of input examples
%
% Examples
%
Expand All @@ -96,79 +91,135 @@ function load(obj, model_prefix, num_epoch)
% imgs(:,:,:,2) = img2
% out = model.forward(imgs, 'gpu', [0,1])

% check arguments
assert(length(varargin) == 0, 'sorry, not implemented yet..');
% parse arguments
dev_type = 1; % cpu in default
dev_id = 0;
out_layers = {};
while length(varargin) > 0
if ischar(varargin{1}) && strcmp(varargin{1}, 'gpu')
assert(length(varargin) > 1, 'arg error: no gpu id')
assert(isnumeric(varargin{2}))
dev_type = 2;
dev_id = varargin{2};
varargin = varargin(3:end);
continue
end

% convert from matlab order (col-major) into c order (row major):
siz = size(imgs);
if length(siz) == 2
imgs = permute(imgs, [2, 1]);
siz = [siz, 1, 1];
elseif length(siz) == 3
imgs = permute(imgs, [2, 1, 3]);
siz = [siz, 1];
elseif length(siz) == 4
imgs = permute(imgs, [2, 1, 3, 4]);
else
error('imgs shape error')
if ischar(varargin{1})
out_layers{end+1} = varargin{1};
varargin = varargin(2:end);
continue
end

if iscell(varargin{1})
out_layers = varargin{1};
varargin = varargin(2:end);
continue
end
end

if any(siz ~= obj.prev_input_size)
siz = size(input);
assert(length(siz) >= 2);

% convert from matlab order (col-major) into c order (row major):
input = obj.convert_ndarray(input);

if obj.changed(siz, dev_type, dev_id, out_layers)
obj.free_predictor()
end
obj.prev_input_size = siz;

dev_type = 1;
if obj.predictor.Value == 0
if obj.verbose
fprintf('create predictor with input size ');
fprintf('%d ', siz);
fprintf('\n');
end
callmxnet('MXPredCreate', obj.symbol, ...
fprintf('create predictor with input size ');
fprintf('%d ', siz);
fprintf('\n');
csize = [ones(1, 4-length(siz)), siz(end:-1:1)];
callmxnet('MXPredCreatePartialOut', obj.symbol, ...
libpointer('voidPtr', obj.params), ...
length(obj.params), ...
dev_type, 0, ...
int32(dev_type), int32(dev_id), ...
1, {'data'}, ...
uint32([0, 4]), ...
uint32(siz(end:-1:1)), ...
uint32(csize), ...
uint32(length(out_layers)), out_layers, ...
obj.predictor);
end

% feed input
callmxnet('MXPredSetInput', obj.predictor, 'data', single(imgs(:)), uint32(numel(imgs)));
callmxnet('MXPredSetInput', obj.predictor, 'data', single(input(:)), uint32(numel(input)));
% forward
callmxnet('MXPredForward', obj.predictor);

% get output size
out_dim = libpointer('uint32Ptr', 0);
out_shape = libpointer('uint32PtrPtr', ones(4,1));
callmxnet('MXPredGetOutputShape', obj.predictor, 0, out_shape, out_dim);
assert(out_dim.Value <= 4);
out_siz = out_shape.Value(1:out_dim.Value);
out_siz = double(out_siz(:)');

% get output
out = libpointer('singlePtr', single(ones(out_siz)));
num_out = 1;
if ~isempty(out_layers), num_out = length(out_layers); end

callmxnet('MXPredGetOutput', obj.predictor, 0, ...
out, uint32(prod(out_siz)));
if num_out == 1
outputs = obj.get_output(0);
else
outputs = cell(num_out,1);
for i = 1 : num_out
outputs{i} = obj.get_output(i-1);
end
end

% TODO convert from c order to matlab order...
outputs = out.Value;
end
end

methods (Access = private)
function free_predictor(obj)
% free the predictor
if obj.predictor.Value ~= 0
if obj.verbose
fprintf('destroy predictor\n')
end
callmxnet('MXPredFree', obj.predictor);
obj.predictor = libpointer('voidPtr', 0);
end
end

function Y = convert_ndarray(obj, X)
% convert between matlab's col major and c's row major
siz = size(X);
Y = permute(X, [2 1 3:length(siz)]);
end

function ret = changed(obj, input_size, dev_type, dev_id, out_layers)
% check if arguments changed since last call
ret = 0;
if length(input_size) ~= length(obj.prev_input_size) || ...
any(input_size ~= obj.prev_input_size) || ...
dev_type ~= obj.prev_dev_type || ...
length(dev_id) ~= length(obj.prev_dev_id) || ...
any(dev_id ~= obj.prev_dev_id) || ...
length(out_layers) ~= length(obj.prev_out_layers) || ...
~all(cellfun(@strcmp, out_layers, obj.prev_out_layers))
ret = 1;
end
obj.prev_input_size = input_size;
obj.prev_dev_type = dev_type;
obj.prev_dev_id = dev_id;
obj.prev_out_layers = out_layers;
end

function out = get_output(obj, index)
% get the i-th output
out_dim = libpointer('uint32Ptr', 0);
out_shape = libpointer('uint32PtrPtr', ones(4,1));
callmxnet('MXPredGetOutputShape', obj.predictor, index, out_shape, out_dim);
assert(out_dim.Value <= 4);
out_siz = out_shape.Value(1:out_dim.Value);
out_siz = double(out_siz(end:-1:1))';

% get output
out = libpointer('singlePtr', single(zeros(out_siz)));

callmxnet('MXPredGetOutput', obj.predictor, index, ...
out, uint32(prod(out_siz)));

% TODO convert from c order to matlab order...
out = reshape(out.Value, out_siz);
if length(out_siz) > 2
out = obj.convert_ndarray(out);
end
end

end

end
52 changes: 50 additions & 2 deletions matlab/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,56 @@

### How to use

The only requirment is build mxnet to get `lib/libmxnet.so`. Then run `demo` in
matlab.
The only requirment is build mxnet to get `lib/libmxnet.so`. Sample usage

- Load model and data:

```matlab
img = single(imresize(imread('cat.png'), [224 224])) - 120;
model = mxnet.model;
model.load('model/Inception_BN', 39);
```

- Get prediction:

```matlab
pred = model.forward(img);
```

- Do feature extraction on GPU 0:

```matlab
feas = model.forward(img, 'gpu', 0, {'max_pool_5b_pool', 'global_pool', 'fc'});
```

- See [demo.m](demo.m) for more examples

### Note on Implementation

We use `loadlibrary` to load mxnet library directly into Matlab and `calllib` to
call MXNet functions. Note that Matlab uses the column-major to store N-dim
arraies while and MXNet uses the row-major. So assume we create an array in
matlab with

```matlab
X = zeros([2,3,4,5]);
```

If we pass the memory of `X` into MXNet, then the correct shape will be
`[5,4,3,2]` in MXNet. When processing images, MXNet assumes the data layout is

```c++
example x channel x width x height
```

while in matlab we often store images by

```matlab
width x height x channel x example
```

So we should permuate the dimensions by `X = permute(X, [2, 1, 3, 4])` before
passing `X` into MXNet.

### FAQ

Expand Down
31 changes: 17 additions & 14 deletions matlab/demo.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

%% Load and resize the image
img = imresize(imread('cat.png'), [224 224]);

img = single(img) - 120;
%% Run prediction
pred = model.forward(img);

Expand All @@ -32,21 +32,24 @@

%% find the predict label
[p, i] = max(pred);

fprintf('the best result is %s, with probability %f\n', labels{i}, p)

%% Print all layers in the symbol
%% Print the last 10 layers in the symbol

sym = model.parse_symbol();
layers = {};
for i = 1 : length(sym.nodes)
if ~strcmp(sym.nodes{i}.op, 'null')
layers{end+1} = sym.nodes{i}.name;
end
end
fprintf('layer name: %s\n', layers{end-10:end})

% sym = model.parse_symbol();
% layers = {};
% for i = 1 : length(sym.nodes)
% if ~strcmp(sym.nodes{i}.op, 'null')
% layers{end+1} = sym.nodes{i}.name;
% end
% end
% layers
%% Extract feature from internal layers

%% Extract feature
feas = model.forward(img, {'max_pool_5b_pool', 'global_pool', 'fc'});
feas(:)

% TODO
%%
%% If GPU is available
% feas = model.forward(img, 'gpu', 0, {'max_pool_5b_pool', 'global_pool', 'fc'});
% feas(:)
36 changes: 36 additions & 0 deletions matlab/tests/prepare_data.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
%% download cifar10 dataset
system('wget https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz')
system('tar -xzvf cifar-10-matlab.tar.gz')
load cifar-10-batches-mat/test_batch.mat

%% convert test dataset of cifar10, and save
X = reshape(data', [32, 32, 3, 10000]);
X = permute(X, [2 1 3 4]);
Y = labels + 1;


save cifar10-test X Y
%% preview one picture
imshow(imresize(X(:,:,:,2), [128, 128]))

%%

!wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
!gunzip t10k-images-idx3-ubyte.gz
!gunzip t10k-labels-idx1-ubyte.gz

%%

fid = fopen('t10k-images-idx3-ubyte', 'r');
d = fread(fid, inf, '*uint8');
fclose(fid);
X = reshape(d(17:end), [28 28 1 10000]);
X = permute(X, [2 1 3 4]);

fid = fopen('t10k-labels-idx1-ubyte', 'r');
d = fread(fid, inf, '*uint8');
fclose(fid);
Y = d(9:end) + 1;

save mnist-test X Y
Loading

0 comments on commit 737e117

Please sign in to comment.