Skip to content

Commit

Permalink
use scalar for OneHot's depth to prevent confusion (onnx#3774)
Browse files Browse the repository at this point in the history
Signed-off-by: Chun-Wei Chen <[email protected]>
  • Loading branch information
jcwchen authored Oct 22, 2021
1 parent 955f4ba commit d0151d7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -12978,7 +12978,7 @@ node = onnx.helper.make_node(
)
indices = np.array([[1, 9],
[2, 4]], dtype=np.float32)
depth = np.array([10], dtype=np.float32)
depth = np.float32(10)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
Expand All @@ -13004,7 +13004,7 @@ node = onnx.helper.make_node(
)
indices = np.array([[1, 9],
[2, 4]], dtype=np.float32)
depth = np.array([10], dtype=np.float32)
depth = np.float32(10)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
Expand Down Expand Up @@ -13035,7 +13035,7 @@ indices = np.array([0, -7, -8], dtype=np.int64)
# [1. 1. 1. 3. 1. 1. 1. 1. 1. 1.]
# [1. 1. 3. 1. 1. 1. 1. 1. 1. 1.]]

depth = np.array([10], dtype=np.float32)
depth = np.float32(10)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
Expand Down
6 changes: 3 additions & 3 deletions docs/TestCoverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -8164,7 +8164,7 @@ node = onnx.helper.make_node(
)
indices = np.array([[1, 9],
[2, 4]], dtype=np.float32)
depth = np.array([10], dtype=np.float32)
depth = np.float32(10)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
Expand All @@ -8188,7 +8188,7 @@ node = onnx.helper.make_node(
)
indices = np.array([[1, 9],
[2, 4]], dtype=np.float32)
depth = np.array([10], dtype=np.float32)
depth = np.float32(10)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
Expand Down Expand Up @@ -8217,7 +8217,7 @@ indices = np.array([0, -7, -8], dtype=np.int64)
# [1. 1. 1. 3. 1. 1. 1. 1. 1. 1.]
# [1. 1. 3. 1. 1. 1. 1. 1. 1. 1.]]

depth = np.array([10], dtype=np.float32)
depth = np.float32(10)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
Expand Down
6 changes: 3 additions & 3 deletions onnx/backend/test/case/node/onehot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def export_with_axis(): # type: () -> None
)
indices = np.array([[1, 9],
[2, 4]], dtype=np.float32)
depth = np.array([10], dtype=np.float32)
depth = np.float32(10)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
Expand All @@ -84,7 +84,7 @@ def export_with_negative_indices(): # type: () -> None
# [1. 1. 1. 3. 1. 1. 1. 1. 1. 1.]
# [1. 1. 3. 1. 1. 1. 1. 1. 1. 1.]]

depth = np.array([10], dtype=np.float32)
depth = np.float32(10)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
Expand All @@ -104,7 +104,7 @@ def export_with_negative_axis(): # type: () -> None
)
indices = np.array([[1, 9],
[2, 4]], dtype=np.float32)
depth = np.array([10], dtype=np.float32)
depth = np.float32(10)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
Expand Down

0 comments on commit d0151d7

Please sign in to comment.