Skip to content

Commit

Permalink
Enable nvfuser tests for refs.broadcast_to and refs.broadcast_tensors (
Browse files Browse the repository at this point in the history
…pytorch#84337)

Previously these tests were failing because they required some other op alongside prims.broadcast_in_dim to be executed. Now it works standalone.
Pull Request resolved: pytorch#84337
Approved by: https://github.com/mruberry, https://github.com/ngimel
  • Loading branch information
IvanYashchuk authored and pytorchmergebot committed Sep 6, 2022
1 parent aec76e3 commit 752c3bc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
5 changes: 1 addition & 4 deletions test/test_prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@ class TestPrims(TestCase):
@skipCUDAIfRocm
@dtypes(torch.float32)
def test_broadcast_in_dim(self, device, dtype):
# nvfuser is not currently capable of realizing a broadcasted tensor
# when the broadcast is the only operation. Another op is needed.
def _wrapper(a, b, broadcast_dimensions):
a_bc = prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)
return prims.add(a_bc, b)
return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)

traced = make_traced(_wrapper)
make_arg = partial(make_tensor, device=device, dtype=dtype)
Expand Down
4 changes: 2 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16866,12 +16866,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
PythonRefInfo(
"_refs.broadcast_tensors",
torch_opinfo_name="broadcast_tensors",
supports_nvfuser=False,
validate_view_consistency=False,
),
PythonRefInfo(
"_refs.broadcast_to",
torch_opinfo_name="broadcast_to",
supports_nvfuser=False,
validate_view_consistency=False,
),
PythonRefInfo(
"_refs.cat",
Expand Down

0 comments on commit 752c3bc

Please sign in to comment.