Skip to content

Commit

Permalink
Fix Batchnorm type inference for mean and variance and add a test (on…
Browse files Browse the repository at this point in the history
  • Loading branch information
jcwchen authored Apr 19, 2021
1 parent 8ea843e commit 1089b9e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 2 additions & 2 deletions onnx/defs/nn/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1765,11 +1765,11 @@ ONNX_OPERATOR_SET_SCHEMA(
TensorShapeProto outputs_shape;
*outputs_shape.add_dim() = num_channels; // channel

propagateElemTypeFromInputToOutput(ctx, 0, 1);
propagateElemTypeFromInputToOutput(ctx, 3, 1);
updateOutputShape(ctx, 1, outputs_shape);

if (ctx.getNumOutputs() > 2) {
propagateElemTypeFromInputToOutput(ctx, 0, 2);
propagateElemTypeFromInputToOutput(ctx, 4, 2);
updateOutputShape(ctx, 2, outputs_shape);
}
}
Expand Down
15 changes: 15 additions & 0 deletions onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3485,6 +3485,21 @@ def test_batch_norm_train_dim_param(self): # type: () -> None
make_tensor_value_info('output_var', TensorProto.FLOAT, ('C',)), # type: ignore
])

def test_batch_norm_train_with_diff_type(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.FLOAT16, (3, 4, 5, 6, 7)),
('scale', TensorProto.FLOAT16, (4,)),
('b', TensorProto.FLOAT16, (4,)),
('input_mean', TensorProto.FLOAT, (4,)),
('input_var', TensorProto.FLOAT, (4,))],
[make_node('BatchNormalization', ['x', 'scale', 'b', 'input_mean', 'input_var'],
['out', 'output_mean', 'output_var'], training_mode=1)],
[])
self._assert_inferred(graph, [make_tensor_value_info('out', TensorProto.FLOAT16, (3, 4, 5, 6, 7)), # type: ignore
make_tensor_value_info('output_mean', TensorProto.FLOAT, (4,)), # type: ignore
make_tensor_value_info('output_var', TensorProto.FLOAT, (4,)), # type: ignore
])

def test_batch_norm_test(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.FLOAT, (3, 4, 5, 6, 7)),
Expand Down

0 comments on commit 1089b9e

Please sign in to comment.