Skip to content

Commit

Permalink
make shape inference of BatchNorm layout neutral (dmlc#301)
Browse files Browse the repository at this point in the history
* make shape inference of BatchNorm layout neutral

* refactor to use the axis variable to do BatchNorm shape inference

* refactor to use the axis variable to do BatchNorm shape inference

* add unittest to the axis param for batch norm shape inference
  • Loading branch information
guopinglong authored and tqchen committed Jan 15, 2018
1 parent 00a88d4 commit e478ef5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/top/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,15 @@ DMLC_REGISTER_PARAMETER(BatchNormParam);
inline bool BatchNormInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 5U)
<< "Input:[data, gamma, beta, moving_mean, moving_var]";
CHECK_EQ(out_shape->size(), 3U);
const TShape &dshape = in_shape->at(0);
if (dshape.ndim() == 0) return false;
TShape bshape({dshape[1]});
CHECK((size_t)param.axis < dshape.Size());

TShape bshape({dshape[param.axis]});
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, bshape);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 2, bshape);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 3, bshape);
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ def test_batchnorm():
sdict = infer_shape(y)
assert(sdict["bn_gamma"][0] == [20])

x = sym.Variable("x", shape=(10, 20, 30, 40))
y = sym.batch_norm(data=x, axis=0, epsilon=2e-5, name='bn')
sdict = infer_shape(y)
assert(sdict['bn_moving_var'][0] == [10])

y = sym.batch_norm(data=x, axis=1, epsilon=2e-5, name='bn')
sdict = infer_shape(y)
assert(sdict['bn_gamma'][0] == [20])

y = sym.batch_norm(data=x, axis=2, epsilon=2e-5, name='bn')
sdict = infer_shape(y)
assert(sdict['bn_beta'][0] == [30])

y = sym.batch_norm(data=x, axis=3, epsilon=2e-5, name='bn')
sdict = infer_shape(y)
assert(sdict['bn_moving_mean'][0] == [40])

def test_flatten():
x = sym.Variable("x", shape=(10, 20, 10))
Expand Down

0 comments on commit e478ef5

Please sign in to comment.