Remove unnecessary decomposition_table= from test/test_prims.py (pytorch#84188)

Follow-up to 83782

Pull Request resolved: pytorch#84188
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
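
For context, a minimal sketch (hypothetical func and input; assuming a CUDA
build with nvFuser, PyTorch ~1.13 era) of the pattern this commit simplifies:
TorchRefsNvfuserCapabilityMode already intercepts torch.* calls during tracing
and rewrites them to refs/nvprims where it can, so an explicit
decomposition_table argument to make_fx is redundant.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsNvfuserCapabilityMode

    def func(a):
        return torch.sigmoid(a)

    a = torch.randn(3, device="cuda")

    # The mode alone is enough; no decomposition_table= is needed.
    with TorchRefsNvfuserCapabilityMode():
        gm = make_fx(func)(a)

    # Inspect which ops the traced graph actually calls.
    print([n.target for n in gm.graph.nodes if n.op == "call_function"])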
IvanYashchuk authored and pytorchmergebot committed Sep 6, 2022
1 parent 88b1cc8 commit e20f217
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions test/test_prims.py
@@ -266,7 +266,7 @@ def func(a):
             return torch.digamma(a)
 
         with TorchRefsNvfuserCapabilityMode():
-            gm = make_fx(func, decomposition_table=torch._prims.context.nvfuser_decomp_table())(a)
+            gm = make_fx(func)(a)
 
         # Check that the torch.digamma is not replaced with torch.ops.prims.digamma
         call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
@@ -286,7 +286,7 @@ def func(a):
             return torch.sigmoid(torch.digamma(a))
 
         with TorchRefsNvfuserCapabilityMode():
-            gm = make_fx(func, decomposition_table=torch._prims.context.nvfuser_decomp_table())(a)
+            gm = make_fx(func)(a)
 
         call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
         includes_aten_sigmoid = any(
@@ -627,8 +627,16 @@ def fn1(x):
         with TorchRefsNvfuserCapabilityMode() as mode:
             self.assertFalse(fn0(x, y, 0.3, False))
 
+        # The autocast context makes C++-level ATen calls that are hidden from
+        # TorchRefsNvfuserCapabilityMode, which works only at the Python level.
+        # The first call to make_fx records those autocast C++ calls directly
+        # and has no chance to translate them to nvprims. After the first
+        # call, "gm" contains explicit calls to torch.ops.aten and nothing
+        # is hidden, so the second call to make_fx actually translates the
+        # recorded autocast dtype conversions to nvprims.
         with torch.autocast("cuda"):
-            gm = make_fx(fn1, decomposition_table=torch._prims.context.nvfuser_decomp_table())(x)
+            gm = make_fx(fn1)(x)
+            gm = make_fx(gm)(x)
         call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
         includes_aten_to_copy = any(
             torch.ops.aten._to_copy.default == node.target
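
A sketch of the double-tracing trick the new comment describes (hypothetical
fn1 body and input; the test's real fn1 and the exact nesting of the two
context managers are not shown in this diff): the first make_fx records
autocast's C++-side dtype casts as explicit torch.ops.aten nodes, and
retracing the resulting GraphModule lets the Python-level mode see and
translate them.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsNvfuserCapabilityMode

    def fn1(x):  # hypothetical stand-in for the test's fn1
        return torch.matmul(x, x).sum()

    x = torch.randn(4, 4, device="cuda")

    with TorchRefsNvfuserCapabilityMode():
        with torch.autocast("cuda"):
            # First trace: autocast casts appear as aten._to_copy nodes.
            gm = make_fx(fn1)(x)
            # Second trace: the casts are now ordinary Python-level aten
            # calls, so the mode can translate them to nvprims.
            gm = make_fx(gm)(x)

    print([n.target for n in gm.graph.nodes if n.op == "call_function"])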
