Skip to content

Commit

Permalink
Add data_dependent_output tag; generalize proxy tensor to test it (pytorch#83312)
Browse files Browse the repository at this point in the history

Fixes pytorch#83251

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#83312
Approved by: https://github.com/albanD
  • Loading branch information
ezyang authored and pytorchmergebot committed Aug 12, 2022
1 parent d07a9ba commit d423722
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 29 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@

- func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool
variants: function, method
tags: data_dependent_output
dispatch:
CompositeExplicitAutograd: allclose

Expand Down Expand Up @@ -6382,6 +6383,7 @@
variants: function

- func: item(Tensor self) -> Scalar
tags: data_dependent_output
variants: method

- func: result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType
Expand All @@ -6403,6 +6405,7 @@

# NB: Does NOT check precondition that numel == 1
- func: _local_scalar_dense(Tensor self) -> Scalar
tags: data_dependent_output
dispatch:
CPU: _local_scalar_dense_cpu
CUDA: _local_scalar_dense_cuda
Expand Down Expand Up @@ -8567,6 +8570,7 @@
CPU, CUDA: unfold_backward

- func: equal(Tensor self, Tensor other) -> bool
tags: data_dependent_output
variants: method, function
dispatch:
CPU: cpu_equal
Expand Down
6 changes: 6 additions & 0 deletions aten/src/ATen/native/tags.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
desc: |
This tag indicates if an operator's output's shape depends on input Tensor
data.
- tag: data_dependent_output
desc: |
Operator has a non-Tensor output whose value is dependent on the data
of Tensor inputs. Among other things, this implies that this operator
cannot be run with meta tensor (since data is not available), nor
can it be symbolically traced.
- tag: generated
desc: |
This tag indicates that the operator doesn't have an explicit entry in
Expand Down
17 changes: 11 additions & 6 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,16 @@ def f():

self._test(f, [])

def test_constant_proxy_tensor(self):
def f():
val = torch.tensor(float('inf'))
return torch.full((100, 100), val)
def test_allclose(self):
def f(a, b):
return torch.allclose(a, b)

g = make_fx(f, tracing_mode=self.tracing_mode)()
self.assertEqual(g(), f())
self.assertRaisesRegex(
RuntimeError, "data-dependent",
lambda: make_fx(f, tracing_mode=self.tracing_mode)(
torch.zeros(3), torch.zeros(3)
)
)

def test_constant_proxy_tensor_mut(self):
def f():
Expand Down Expand Up @@ -701,6 +704,8 @@ def f(a, b):
xfail('nn.functional.gaussian_nll_loss'),
xfail('tensor_split'),
xfail('corrcoef'),
xfail('quantile'),
xfail('nanquantile'),

# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
Expand Down
6 changes: 0 additions & 6 deletions test/test_public_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,12 +243,6 @@ def test_no_new_bindings(self):

"wait",
"Tag",
"inplace_view",
"view_copy",
"generated",
"dynamic_output_shape",
"nondeterministic_bitwise",
"nondeterministic_seeded",
}
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}

Expand Down
3 changes: 1 addition & 2 deletions tools/autograd/templates/python_enum_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ namespace torch {
void initEnumTag(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::enum_<at::Tag>(m, "Tag")
${enum_of_valid_tags}
.export_values();
${enum_of_valid_tags};
m.doc() = "An Enum that contains tags that can be assigned to an operator registered in C++.";
}
}}
39 changes: 24 additions & 15 deletions torch/fx/experimental/proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .symbolic_shapes import ShapeEnv, magic_methods, reflectable_magic_methods
import torch.fx.experimental.symbolic_shapes as symbolic_shapes

__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict", "DecompositionInterpreter"]
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter"]
aten = torch.ops.aten

CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
Expand Down Expand Up @@ -90,12 +90,6 @@ def decompose(decomposition_table):
finally:
CURRENT_DECOMPOSITION_TABLE = old_decomposition_table

# Module-level flag: when True, attempting to convert a tracing tensor into a
# plain Python scalar (e.g. via .item()) raises an error instead of silently
# producing a baked-in constant.
IS_STRICT = True
def enable_strict(val):
    """Enable or disable strict tensor-to-scalar conversion checking.

    Args:
        val: Truthy to make scalar extraction from a tracing tensor an
            error; falsy to permit it.
    """
    global IS_STRICT
    IS_STRICT = val

def wrap_output(inner_res, proxy_res, *, constant, proxy_mode):
def wrap_with_proxy(e, proxy, constant):
if isinstance(e, torch.Tensor):
Expand Down Expand Up @@ -157,15 +151,30 @@ def proxy_call(proxy_mode, func_overload, args, kwargs=None):
r = func_overload.decompose(*args, **kwargs)
if r is not NotImplemented:
return r
if func_overload == aten._local_scalar_dense.default:
t, = args
assert not kwargs
if t.constant is not None:
if torch.Tag.data_dependent_output in func_overload.tags: # type: ignore[attr-defined]
# Check if all of the Tensor inputs are constants
all_constant = True

def try_unwrap_constant(t):
nonlocal all_constant
if isinstance(t, ProxyTensor):
if t.constant is not None:
return t.constant
else:
all_constant = False
return NotImplemented
else:
return t

const_args, const_kwargs = pytree.tree_map(try_unwrap_constant, (args, kwargs))

if all_constant:
with maybe_disable_fake_tensor_mode():
return t.constant.item()
raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
"It's likely that this is caused by data-dependent control flow or similar."
"Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")
return func_overload(*const_args, **const_kwargs)
raise RuntimeError(
"It appears that you're trying to get value out of a tracing tensor - erroring out! "
"It's likely that this is caused by data-dependent control flow or similar."
)

def unwrap_proxy(e):
    """Return the FX proxy held by a ProxyTensor; pass other values through unchanged."""
    if isinstance(e, ProxyTensor):
        return e.proxy
    return e
Expand Down

0 comments on commit d423722

Please sign in to comment.