Merge pull request apache#1070 from mli/master
mli committed Dec 25, 2015
2 parents ac8f5c8 + 737e117 commit a5e51e4
Showing 8 changed files with 456 additions and 90 deletions.
54 changes: 40 additions & 14 deletions example/image-classification/README.md
@@ -180,8 +180,9 @@ recommend to use CUDNN.

| name | hardware | software |
| --- | --- | --- |
| GTX980 | dual Xeon E5-2680 v2, dual GTX 980, 1G Ethernet | GCC 4.8, CUDA 7.5, CUDNN v3 |
| EC2-g2.8x | Xeon E5-2670, dual GRID K520, 10G Ethernet | GCC 4.8, CUDA 7.5, CUDNN v3 |
| GTX980 | Xeon E5-1650 v3, 4 x GTX 980 | GCC 4.8, CUDA 7.5, CUDNN 3 |
| TitanX | dual Xeon E5-2630 v3, 4 x GTX Titan X | GCC 4.8, CUDA 7.5, CUDNN 3 |
| EC2-g2.8x | Xeon E5-2670, 2 x GRID K520, 10G Ethernet | GCC 4.8, CUDA 7.5, CUDNN 3 |

- Datasets

@@ -210,24 +211,48 @@ python train_cifar10.py --batch-size 128 --lr 0.1 --lr-factor .94 --num-epoch 50

### ILSVRC 12

#### `train_imagenet.py` with `--network alexnet`
<!-- #### Alexnet -->

- time for one epoch:
<!-- `train_imagenet.py` with `--network alexnet` -->

| 1 x GTX 980 | 2 x GTX 980 | 4 x GTX 980 |
| ----------- | ------------ | ------------ |
| 2,413 sec | 1,244 sec | 906 sec |
<!-- - time for one epoch: -->

#### `train_imagenet.py` with `--network inception-bn`
<!-- | 1 x GTX 980 | 2 x GTX 980 | 4 x GTX 980 | -->
<!-- | ----------- | ------------ | ------------ | -->
<!-- | 2,413 sec | 1,244 sec | 906 sec | -->

#### VGG

`train_imagenet.py` with `--network vgg`

- Performance

| Cluster | # machines | # GPUs | batch size | kvstore | epoch time (sec) |
| --- | --- | --- | --- | --- | ---: |
| TitanX | 1 | 1 | 96 | `none` | 14,545 |
| - | - | 2 | - | `local` | 19,692 |
| - | - | 4 | - | - | 20,014 |
| - | - | 2 | - | `local_allreduce_device` | 9,142 |
| - | - | 4 | - | - | 8,533 |
| - | - | - | 384 | - | 5,161 |

#### Inception with Batch Normalization

`train_imagenet.py` with `--network inception-bn`

- Performance

| Cluster | # machines | # GPUs | batch size | kvstore | epoch time (sec) |
| --- | --- | --- | --- | --- | ---: |
| GTX980 | 1 | 1 | 32 | `local` | 13,210 |
| - | 1 | 2 | 64 | `local` | 7,198 |
| - | 1 | 3 | 128 | `local` | 4,952 |
| - | 1 | 4 | 128 | `local` | 3,589 |
| - | - | 2 | 64 | - | 7,198 |
| - | - | 3 | 128 | - | 4,952 |
| - | - | 4 | - | - | 3,589 |
| TitanX | 1 | 1 | 128 | `none` | 10,666 |
| - | - | 2 | - | `local` | 5,161 |
| - | - | 3 | - | - | 3,460 |
| - | - | 4 | - | - | 2,844 |
| - | - | - | 512 | - | 2,495 |
| EC2-g2.8x | 1 | 4 | 144 | `local` | 14,203 |
| - | 10 | 40 | 144 | `dist_sync` | 1,422 |

@@ -236,8 +261,8 @@ python train_cifar10.py --batch-size 128 --lr 0.1 --lr-factor .94 --num-epoch 50
- `single machine`:

```bash
python train_imagenet.py --network inception-bn \
--batch-size 128 --lr 0.05 --num-epoch 60 --lr-factor .94 \
python train_imagenet.py --batch-size 144 --lr 0.05 --lr-factor .94 \
--gpus 0,1,2,3 --num-epoch 60 --network inception-bn \
--data-dir ilsvrc12/ --model-prefix model/ilsvrc12
```

@@ -251,7 +276,8 @@ python train_cifar10.py --batch-size 128 --lr 0.1 --lr-factor .94 --num-epoch 50
--data-dir s3://dmlc/ilsvrc12/ --model-prefix s3://dmlc/model/ilsvrc12
```

*Note: S3 is unstable sometimes, before fixing this problem, we recommend to download data to `/mnt` first*
*Note: S3 is sometimes unstable. If your training hangs or errors frequently,
you can download the data to `/mnt` first.*

Accuracy vs epoch ([the interactive figure](https://docs.google.com/spreadsheets/d/1AEesHjWUZOzCN0Gp_PYI1Cw4U1kZMKot360p9Fowmjw/pubchart?oid=1740787404&format=interactive)):

36 changes: 36 additions & 0 deletions include/mxnet/c_predict_api.h
@@ -37,6 +37,7 @@ typedef void *NDListHandle;
* \return The last error happened at the predictor.
*/
MXNET_DLL const char* MXGetLastError();

/*!
* \brief create a predictor
* \param symbol_json_str The JSON string of the symbol.
@@ -65,6 +66,41 @@ MXNET_DLL int MXPredCreate(const char* symbol_json_str,
const mx_uint* input_shape_indptr,
const mx_uint* input_shape_data,
PredictorHandle* out);

/*!
 * \brief create a predictor with customized outputs
* \param symbol_json_str The JSON string of the symbol.
* \param param_bytes The in-memory raw bytes of parameter ndarray file.
* \param param_size The size of parameter ndarray file.
* \param dev_type The device type, 1: cpu, 2:gpu
* \param dev_id The device id of the predictor.
* \param num_input_nodes Number of input nodes to the net,
* For feedforward net, this is 1.
* \param input_keys The name of input argument.
* For feedforward net, this is {"data"}
* \param input_shape_indptr Index pointer of shapes of each input node.
* The length of this array = num_input_nodes + 1.
* For feedforward net that takes 4 dimensional input, this is {0, 4}.
 * \param input_shape_data Flattened shape data of each input node.
* For feedforward net that takes 4 dimensional input, this is the shape data.
 * \param num_output_nodes Number of output nodes to the net.
 * \param output_keys The names of the output arguments.
 *    For example {"global_pool"}.
* \param out The created predictor handle.
* \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredCreatePartialOut(const char* symbol_json_str,
const void* param_bytes,
int param_size,
int dev_type, int dev_id,
mx_uint num_input_nodes,
const char** input_keys,
const mx_uint* input_shape_indptr,
const mx_uint* input_shape_data,
mx_uint num_output_nodes,
const char** output_keys,
PredictorHandle* out);
/*!
* \brief Get the shape of output node.
* The returned shape_data and shape_ndim is only valid before next call to MXPred function.
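The new `MXPredCreatePartialOut` entry point lets a predictor return the outputs of named internal nodes (for example `global_pool`) in addition to, or instead of, the final output. Below is a minimal C sketch of the intended call sequence, assuming the caller has already loaded the symbol JSON and parameter bytes from disk, the network exposes a `global_pool` output, and the input is a single 1 x 3 x 224 x 224 image; names and shapes are illustrative, not taken from this commit.

```c
/* Minimal sketch: fetch an internal layer ("global_pool") via
 * MXPredCreatePartialOut. symbol_json/param_bytes are loaded by the
 * caller; the layer name and input shape are illustrative assumptions. */
#include <stdio.h>
#include <stdlib.h>
#include <mxnet/c_predict_api.h>

int predict_global_pool(const char *symbol_json,
                        const void *param_bytes, int param_size,
                        const float *image_nchw /* 1 x 3 x 224 x 224, row-major */) {
  const char *input_keys[]  = {"data"};
  const char *output_keys[] = {"global_pool"};   /* assumed output node name */
  const mx_uint input_shape_indptr[] = {0, 4};   /* one input with a 4-d shape */
  const mx_uint input_shape_data[]   = {1, 3, 224, 224};
  PredictorHandle pred = NULL;

  if (MXPredCreatePartialOut(symbol_json, param_bytes, param_size,
                             1 /* cpu */, 0 /* dev_id */,
                             1, input_keys, input_shape_indptr, input_shape_data,
                             1, output_keys, &pred) != 0) {
    fprintf(stderr, "create failed: %s\n", MXGetLastError());
    return -1;
  }

  /* Feed the input and run the forward pass. */
  MXPredSetInput(pred, "data", image_nchw, 1 * 3 * 224 * 224);
  MXPredForward(pred);

  /* Query the shape of output 0, i.e. the requested partial output. */
  mx_uint *shape = NULL, ndim = 0, size = 1, i;
  MXPredGetOutputShape(pred, 0, &shape, &ndim);
  for (i = 0; i < ndim; ++i) size *= shape[i];

  /* Copy the output back and release the predictor. */
  float *out = (float *)malloc(size * sizeof(float));
  MXPredGetOutput(pred, 0, out, size);
  printf("got %u values from global_pool\n", size);
  free(out);
  MXPredFree(pred);
  return 0;
}
```

The companion calls (`MXPredSetInput`, `MXPredForward`, `MXPredGetOutputShape`, `MXPredGetOutput`, `MXPredFree`) are the existing predict API, used the same way as with `MXPredCreate`; note that the shape buffer returned by `MXPredGetOutputShape` is owned by the predictor and is only valid until the next `MXPred*` call.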
163 changes: 107 additions & 56 deletions matlab/+mxnet/model.m
@@ -15,8 +15,12 @@
predictor
% the previous input size
prev_input_size
% the previous device id, -1 means cpu
% the previous device id
prev_dev_id
% the previous device type (cpu or gpu)
prev_dev_type
% the previous output layers
prev_out_layers
end

methods
Expand All @@ -26,6 +30,7 @@
obj.prev_input_size = zeros(1,4);
obj.verbose = 1;
obj.prev_dev_id = -1;
obj.prev_dev_type = -1;
end

function delete(obj)
@@ -64,21 +69,11 @@ function load(obj, model_prefix, num_epoch)
end


function outputs = forward(obj, imgs, varargin)
function outputs = forward(obj, input, varargin)
%FORWARD perform forward
%
% OUT = MODEL.FORWARD(imgs) returns the forward (prediction) outputs of a list
% of images, where imgs can be either a single image with the format
%
% width x height x channel
%
% which is return format of `imread` or a list of images with format
%
% width x height x channel x num_images
%
% MODEL.FORWARD(imgs, 'gpu', [0, 1]) uses GPU 0 and 1 for prediction
%
% MODEL.FORWARD(imgs, {'conv4', 'conv5'}) extract outputs for two internal layers
% OUT = MODEL.FORWARD(input) returns the forward (prediction) outputs for a
% list of input examples. Optional arguments select the device, e.g.
% MODEL.FORWARD(input, 'gpu', [0, 1]), and request internal layer outputs,
% e.g. MODEL.FORWARD(input, {'conv4', 'conv5'})
%
% Examples
%
Expand All @@ -96,79 +91,135 @@ function load(obj, model_prefix, num_epoch)
% imgs(:,:,:,2) = img2
% out = model.forward(imgs, 'gpu', [0,1])

% check arguments
assert(length(varargin) == 0, 'sorry, not implemented yet..');
% parse arguments
dev_type = 1; % cpu by default
dev_id = 0;
out_layers = {};
while length(varargin) > 0
if ischar(varargin{1}) && strcmp(varargin{1}, 'gpu')
assert(length(varargin) > 1, 'arg error: no gpu id')
assert(isnumeric(varargin{2}))
dev_type = 2;
dev_id = varargin{2};
varargin = varargin(3:end);
continue
end

% convert from matlab order (col-major) into c order (row major):
siz = size(imgs);
if length(siz) == 2
imgs = permute(imgs, [2, 1]);
siz = [siz, 1, 1];
elseif length(siz) == 3
imgs = permute(imgs, [2, 1, 3]);
siz = [siz, 1];
elseif length(siz) == 4
imgs = permute(imgs, [2, 1, 3, 4]);
else
error('imgs shape error')
if ischar(varargin{1})
out_layers{end+1} = varargin{1};
varargin = varargin(2:end);
continue
end

if iscell(varargin{1})
out_layers = varargin{1};
varargin = varargin(2:end);
continue
end
end

if any(siz ~= obj.prev_input_size)
siz = size(input);
assert(length(siz) >= 2);

% convert from matlab order (col-major) into c order (row major):
input = obj.convert_ndarray(input);

if obj.changed(siz, dev_type, dev_id, out_layers)
obj.free_predictor()
end
obj.prev_input_size = siz;

dev_type = 1;
if obj.predictor.Value == 0
if obj.verbose
fprintf('create predictor with input size ');
fprintf('%d ', siz);
fprintf('\n');
end
callmxnet('MXPredCreate', obj.symbol, ...
fprintf('create predictor with input size ');
fprintf('%d ', siz);
fprintf('\n');
csize = [ones(1, 4-length(siz)), siz(end:-1:1)];
callmxnet('MXPredCreatePartialOut', obj.symbol, ...
libpointer('voidPtr', obj.params), ...
length(obj.params), ...
dev_type, 0, ...
int32(dev_type), int32(dev_id), ...
1, {'data'}, ...
uint32([0, 4]), ...
uint32(siz(end:-1:1)), ...
uint32(csize), ...
uint32(length(out_layers)), out_layers, ...
obj.predictor);
end

% feed input
callmxnet('MXPredSetInput', obj.predictor, 'data', single(imgs(:)), uint32(numel(imgs)));
callmxnet('MXPredSetInput', obj.predictor, 'data', single(input(:)), uint32(numel(input)));
% forward
callmxnet('MXPredForward', obj.predictor);

% get output size
out_dim = libpointer('uint32Ptr', 0);
out_shape = libpointer('uint32PtrPtr', ones(4,1));
callmxnet('MXPredGetOutputShape', obj.predictor, 0, out_shape, out_dim);
assert(out_dim.Value <= 4);
out_siz = out_shape.Value(1:out_dim.Value);
out_siz = double(out_siz(:)');

% get output
out = libpointer('singlePtr', single(ones(out_siz)));
num_out = 1;
if ~isempty(out_layers), num_out = length(out_layers); end

callmxnet('MXPredGetOutput', obj.predictor, 0, ...
out, uint32(prod(out_siz)));
if num_out == 1
outputs = obj.get_output(0);
else
outputs = cell(num_out,1);
for i = 1 : num_out
outputs{i} = obj.get_output(i-1);
end
end

% TODO convert from c order to matlab order...
outputs = out.Value;
end
end

methods (Access = private)
function free_predictor(obj)
% free the predictor
if obj.predictor.Value ~= 0
if obj.verbose
fprintf('destroy predictor\n')
end
callmxnet('MXPredFree', obj.predictor);
obj.predictor = libpointer('voidPtr', 0);
end
end

function Y = convert_ndarray(obj, X)
% convert between matlab's col major and c's row major
siz = size(X);
Y = permute(X, [2 1 3:length(siz)]);
end

function ret = changed(obj, input_size, dev_type, dev_id, out_layers)
% check if arguments changed since last call
ret = 0;
if length(input_size) ~= length(obj.prev_input_size) || ...
any(input_size ~= obj.prev_input_size) || ...
dev_type ~= obj.prev_dev_type || ...
length(dev_id) ~= length(obj.prev_dev_id) || ...
any(dev_id ~= obj.prev_dev_id) || ...
length(out_layers) ~= length(obj.prev_out_layers) || ...
~all(cellfun(@strcmp, out_layers, obj.prev_out_layers))
ret = 1;
end
obj.prev_input_size = input_size;
obj.prev_dev_type = dev_type;
obj.prev_dev_id = dev_id;
obj.prev_out_layers = out_layers;
end

function out = get_output(obj, index)
% get the i-th output
out_dim = libpointer('uint32Ptr', 0);
out_shape = libpointer('uint32PtrPtr', ones(4,1));
callmxnet('MXPredGetOutputShape', obj.predictor, index, out_shape, out_dim);
assert(out_dim.Value <= 4);
out_siz = out_shape.Value(1:out_dim.Value);
out_siz = double(out_siz(end:-1:1))';

% get output
out = libpointer('singlePtr', single(zeros(out_siz)));

callmxnet('MXPredGetOutput', obj.predictor, index, ...
out, uint32(prod(out_siz)));

% TODO convert from c order to matlab order...
out = reshape(out.Value, out_siz);
if length(out_siz) > 2
out = obj.convert_ndarray(out);
end
end

end

end