Skip to content

Commit

Permalink
fix hardmax test cases make output dtype same as input (onnx#705)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumihwh authored and bddppq committed Apr 4, 2018
1 parent c970f0c commit f0f6b3d
Show file tree
Hide file tree
Showing 10 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -3014,7 +3014,7 @@ expect(node, inputs=[x], outputs=[y],

```python
def hardmax_2d(x):
return np.eye(x.shape[1])[np.argmax(x, axis=1)]
return np.eye(x.shape[1], dtype=x.dtype)[np.argmax(x, axis=1)]

x = np.random.randn(3, 4, 5).astype(np.float32)
node = onnx.helper.make_node(
Expand Down
2 changes: 1 addition & 1 deletion onnx/backend/test/case/node/hardmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def export():
@staticmethod
def export_hardmax_axis():
def hardmax_2d(x):
return np.eye(x.shape[1])[np.argmax(x, axis=1)]
return np.eye(x.shape[1], dtype=x.dtype)[np.argmax(x, axis=1)]

x = np.random.randn(3, 4, 5).astype(np.float32)
node = onnx.helper.make_node(
Expand Down
Binary file modified onnx/backend/test/data/node/test_hardmax_axis_0/model.onnx
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion onnx/backend/test/data/node/test_hardmax_axis_1/model.onnx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

b
y
 



B
Binary file not shown.
2 changes: 1 addition & 1 deletion onnx/backend/test/data/node/test_hardmax_axis_2/model.onnx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

b
y
 



B
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

b
y
 



B
Binary file not shown.

0 comments on commit f0f6b3d

Please sign in to comment.