[export] allow complex guards as runtime asserts (pytorch#127129)
With the current state of export's dynamic shapes, we struggle with guards and constraints that fall outside the dynamic shapes language expressed with dims and derived dims. While we can compile and guarantee correctness for guards within that language (e.g. min/max ranges, linear relationships, integer divisibility), we struggle to dynamically compile guards that extend beyond it.

For these "complex" guards, we typically do one of the following: 1) raise a constraint violation error, along the lines of "not all values of <symbol> in the specified range satisfy <guard>", with or without suggested fixes; 2) specialize to the provided static values and suggest removing dynamism; or 3) fail compilation due to some arbitrary unsupported case. Previous [work](pytorch#124949) went towards resolving this by disabling forced specializations, instead allowing the user to fail at runtime with incorrect inputs.

In this PR, relying on [hybrid backed-unbacked symints](pytorch#121749), [deferred runtime asserts](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/runtime_assert.py), and the function [_is_supported_equivalence()](https://github.com/pytorch/pytorch/blob/d7de4c9d809697b36ae0fd9e16815f6e3b4d985b/torch/fx/experimental/symbolic_shapes.py#L1824), we add a flag `_allow_complex_guards_as_runtime_asserts` that lets the user compile exported programs containing these guards while maintaining dynamism, with correctness checks added to the graph as runtime assertions.
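For reference, a minimal usage sketch follows, mirroring the tests added below; the flag is currently exposed on the private `torch.export._trace._export` entry point, and the module, dims, and shape values here are illustrative:

```
import torch
from torch.export import Dim

class Reshape(torch.nn.Module):
    def forward(self, x, y):  # x: [s0, s1], y: [s2]
        return x.reshape([-1]) + y  # emits the guard s0*s1 = s2

inputs = (torch.randn(4, 3), torch.randn(12))
dynamic_shapes = {
    "x": (Dim("dx0", min=2), Dim("dx1", min=2)),
    "y": (Dim("dy", min=4),),
}
# The flag lives on the private export entry point for now.
ep = torch.export._trace._export(
    Reshape(),
    inputs,
    dynamic_shapes=dynamic_shapes,
    _allow_complex_guards_as_runtime_asserts=True,
)
ep.module()(torch.randn(6, 2), torch.randn(12))  # passes: 6 * 2 == 12
ep.module()(torch.randn(6, 2), torch.randn(11))  # raises: Eq(s0*s1 - s2, 0) fails
```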

Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation, i.e. guards that we generally expect to be true. Popular examples revolve around reshapes:
```
# reshape
def forward(self, x, y):  # x: [s0, s1], y: [s2]
    return x.reshape([-1]) + y  # guard s0 * s1 = s2
```

This leads to the following exported program:

```
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0, s1]", y: "f32[s2]"):
        sym_size_int: "Sym(s2)" = torch.ops.aten.sym_size.int(y, 0)
        mul: "Sym(-s2)" = -1 * sym_size_int;  sym_size_int = None
        sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
        mul_1: "Sym(s0*s1)" = sym_size_int_1 * sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
        add: "Sym(s0*s1 - s2)" = mul + mul_1;  mul = mul_1 = None
        eq: "Sym(Eq(s0*s1 - s2, 0))" = add == 0;  add = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s0*s1 - s2, 0) on node 'eq'");  eq = None

        view: "f32[s0*s1]" = torch.ops.aten.view.default(x, [-1]);  x = None
        add_1: "f32[s0*s1]" = torch.ops.aten.add.Tensor(view, y);  view = y = None
        return (add_1,)
```
Another case is symbol divisibility:
```
def forward(self, x):  # x: [s0, s1]
    return x.reshape([-1, x.shape[0] - 1])  # Eq(Mod(s0 * s1, s0 - 1), 0)
```
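Under the flag, this divisibility requirement likewise becomes a graph-level runtime assertion rather than a forced specialization. A sketch of the resulting behavior, continuing the imports from the snippet above (shape values chosen for illustration):

```
class DivReshape(torch.nn.Module):
    def forward(self, x):  # x: [s0, s1]
        return x.reshape([-1, x.shape[0] - 1])  # needs Eq(Mod(s0*s1, s0 - 1), 0)

ep = torch.export._trace._export(
    DivReshape(),
    (torch.randn(5, 4),),
    dynamic_shapes={"x": (Dim("dx0", min=3), Dim("dx1", min=3))},
    _allow_complex_guards_as_runtime_asserts=True,
)
ep.module()(torch.randn(5, 4))  # ok: 5*4 is divisible by 5-1
ep.module()(torch.randn(5, 3))  # raises: Eq(Mod(s0*s1, s0 - 1), 0) fails
```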

Applying deferred runtime asserts also helps dynamic compilation for "explicit" complex guards that typically cause problems for export. For example, we can generate runtime asserts for not-equal guards and for complex conditions like the following:
```
class Foo(torch.nn.Module):
    def forward(self, x, y):
        # check that negation of first guard also shows up as runtime assertion
        if x.shape[0] == y.shape[0]:  # False
            return x + y
        elif x.shape[0] == y.shape[0] ** 3:  # False
            return x + 2, y + 3
        elif x.shape[0] ** 2 == y.shape[0] * 3:  # True
            return x * 2.0, y * 3.0
```
For the above graph we will generate 3 runtime assertions: the negations of the first two conditions, and the third condition itself.
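Per the test added in this PR, the asserted expressions (with x.shape[0] = s0, y.shape[0] = s1) and the resulting behavior look roughly like the sketch below:

```
# Expressions asserted in the exported graph:
#   Ne(s0 - s1, 0)        negation of the first (False) branch
#   Ne(s0 - s1**3, 0)     negation of the second (False) branch
#   Eq(s0**2 - 3*s1, 0)   the taken (True) branch
ep = torch.export._trace._export(
    Foo(),
    (torch.randn(6), torch.randn(12)),
    dynamic_shapes={"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]},
    _allow_complex_guards_as_runtime_asserts=True,
)
ep.module()(torch.randn(9), torch.randn(27))  # ok: 9 != 27, 9 != 27**3, 9**2 == 3*27
ep.module()(torch.randn(4), torch.randn(4))   # raises: Ne(s0 - s1, 0) fails
```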

An additional benefit over the current state of exported programs is the extra correctness guarantee: previously, even if compilation succeeded with explicit complex guards, those guards were treated as given and ignored at runtime.

As shown above, the runtime asserts appear as math ops in the graph, generated by the sympy interpreter and culminating in an `_assert_scalar` call. There is an option to avoid adding these asserts to the graph by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`. This produces the "original" computation graph, still with dynamism, and incorrect inputs will instead fail on the underlying ops at runtime. Further work could go into prettifying the printer so that the majority of the graph isn't guard-related.
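A sketch of opting out, assuming the environment variable is consulted when the assert pass runs (setting it in the shell before launching Python is the safer route):

```
import os
import torch

# Opt out of emitting _assert_scalar nodes into the graph
# (exactly when this variable is read is an assumption in this sketch).
os.environ["TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS"] = "1"

ep = torch.export._trace._export(
    Reshape(),  # module from the earlier sketch
    (torch.randn(4, 3), torch.randn(12)),
    dynamic_shapes=dynamic_shapes,  # as defined in the earlier sketch
    _allow_complex_guards_as_runtime_asserts=True,
)
# The graph stays dynamic but carries no assert nodes; bad inputs fail inside the ops themselves.
```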

Ideally this PR would subsume and remove the recently added [_disable_forced_specializations](pytorch#124949) flag, but that flag still handles one additional case of specialization: single-variable equalities where the symbol is solvable for a concrete value (see this [PR](pytorch#126925)).

This PR doesn't change any behavior around data-dependent errors/unbacked symints yet; that could be follow-up work.

NOTE: will take naming change suggestions for the flag :)

Pull Request resolved: pytorch#127129
Approved by: https://github.com/avikchaudhuri
pianpwk authored and pytorchmergebot committed May 29, 2024
1 parent cc6e72d commit 8a31c2a
Showing 10 changed files with 326 additions and 73 deletions.
4 changes: 2 additions & 2 deletions test/dynamo/test_misc.py
@@ -9195,8 +9195,8 @@ def test_shape_env_equal_constructor(self):
ShapeEnv not equal: field values don't match:
==> settings: values don't match.
> Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False)
> Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False)
> Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False)
> Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False)
""",
)
self._replay_and_check(main)
177 changes: 157 additions & 20 deletions test/export/test_export.py
@@ -5039,8 +5039,9 @@ def forward(self, x):
export(f, (inputs,), dynamic_shapes=dynamic_shapes)

def test_disable_forced_specializations(self):
# case 1
# check disable_forced_specializations flag behaves correctly
# check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags
# both behave correctly, avoiding forced specializations and deferring to runtime.
# case 1: modulo guards
from torch.export import dims

class Mod4Reshape(torch.nn.Module):
@@ -5055,31 +5056,36 @@ def forward(self, x):
r".*dx = .* must be specialized to 10 because the guards generated for it are too complex(.*\n)*"
r".*dy = .* must be specialized to 72 because the guards generated for it are too complex(.*\n)*",
):
torch.export._trace._export(
export(
Mod4Reshape(),
inputs,
dynamic_shapes={"x": (dx, dy)},
strict=False,
_disable_forced_specializations=False,
)
ep = torch.export._trace._export(

torch.export._trace._export( # just check this successfully compiles
Mod4Reshape(),
inputs,
dynamic_shapes={"x": (dx, dy)},
strict=False,
_disable_forced_specializations=True,
)
ep = torch.export._trace._export(
Mod4Reshape(),
inputs,
dynamic_shapes={"x": (dx, dy)},
_allow_complex_guards_as_runtime_asserts=True,
)
out1 = ep.module()(torch.randn(8, 7))
self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
out2 = ep.module()(torch.randn(4, 3))
self.assertEqual(out2.shape, torch.ones(3, 4, 1).shape)
out2 = ep.module()(torch.randn(12, 11))
self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape)
with self.assertRaisesRegex(
RuntimeError,
r"shape .*7, 4, -1.* is invalid for input of size 64",
r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'",
):
ep.module()(torch.randn(8, 8)) # fail

# case 2
# case 2: 2d reshape
class FreeReshape(torch.nn.Module):
def forward(self, x, y, z):
return x.reshape([-1]) + y.reshape([-1]) + z # s0*s1 = s2*s3 = s4
@@ -5090,42 +5096,95 @@ def forward(self, x, y, z):
torch.randn(48),
)
dynamic_shapes = {
"x": [Dim(f"dx{i}") for i in range(2)],
"y": [Dim(f"dy{i}") for i in range(2)],
"z": [Dim(f"dz{i}") for i in range(1)],
"x": [Dim(f"dx{i}", min=2) for i in range(2)],
"y": [Dim(f"dy{i}", min=2) for i in range(2)],
"z": [Dim(f"dz{i}", min=4) for i in range(1)],
}
with self.assertRaisesRegex( # this will force specialize
torch._dynamo.exc.UserError,
r".*Specializations unexpectedly required(.*\n)*"
r".*dx0 = .* must be specialized to 6 because the guards generated for it are too complex(.*\n)*"
r".*dx1 = .* must be specialized to 8 because the guards generated for it are too complex(.*\n)*",
):
torch.export._trace._export(
export(
FreeReshape(),
inputs,
dynamic_shapes=dynamic_shapes,
strict=False,
_disable_forced_specializations=False,
)
ep = torch.export._trace._export(
torch.export._trace._export(
FreeReshape(),
inputs,
dynamic_shapes=dynamic_shapes,
strict=False,
_disable_forced_specializations=True,
)
ep = torch.export._trace._export(
FreeReshape(),
inputs,
dynamic_shapes=dynamic_shapes,
_allow_complex_guards_as_runtime_asserts=True,
)
out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48))
self.assertEqual(out1.shape, torch.ones(48).shape)
out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40))
self.assertEqual(out2.shape, torch.ones(40).shape)
with self.assertRaisesRegex(
RuntimeError,
r"The size of tensor a .* must match the size of tensor b .* at non-singleton dimension 0",
r"Runtime assertion failed for expression Eq\(s0\*s1 \- s2\*s3, 0\) on node 'eq.*'",
): # fail only at runtime
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail

# case 3: 3d reshape (previously failing with different issue)
class Reshape3d(torch.nn.Module):
def forward(self, x, y):
return x.reshape([-1]) + y # s0*s1*s2 = s3

inputs = (
torch.randn(4, 3, 2),
torch.randn(24),
)
dynamic_shapes = {
"x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
"y": (Dim("dy", min=8),),
}
with self.assertRaisesRegex( # this will force specialize
torch._dynamo.exc.UserError,
r".*Specializations unexpectedly required(.*\n)*"
r"Suggested fixes:(.*\n)*"
r".*dx0 = 4(.*\n)*"
r".*dx1 = 3(.*\n)*"
r".*dx2 = 2(.*\n)*"
r".*dy = 24(.*\n)*",
):
export(
Reshape3d(),
inputs,
dynamic_shapes=dynamic_shapes,
)

torch.export._trace._export(
Reshape3d(),
inputs,
dynamic_shapes=dynamic_shapes,
strict=False,
_disable_forced_specializations=True,
)
ep = torch.export._trace._export(
Reshape3d(),
inputs,
dynamic_shapes=dynamic_shapes,
_allow_complex_guards_as_runtime_asserts=True,
)
out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
self.assertEqual(out1.shape, torch.ones(126).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(s0\*s1\*s2 \- s3, 0\) on node 'eq.*'",
): # fail only at runtime
ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail

def test_disable_forced_specializations_errors(self):
# check error messages with disable_forced_specializations=False/True
# check error messages with disable_forced_specializations = False/True
class Foo(torch.nn.Module):
def forward(self, w, x, y, z):
return w.reshape([-1]) + x, y + z # simple: s0*s1 = s2, s3 = s4
@@ -5142,7 +5201,7 @@ def forward(self, w, x, y, z):
"y": [Dim("dy")], # y & z incorrect, export is supposed to fail.
"z": [Dim("dz")], # suggested fix should be to match these up.
}
with self.assertRaisesRegex( # if disable=False, suggested fixes should specialize 3, 4, 12.
with self.assertRaisesRegex( # if allow = False, suggested fixes should specialize 3, 4, 12.
torch._dynamo.exc.UserError,
r".*Specializations unexpectedly required(.*\n)*"
r"Suggested fixes:(.*\n)*"
@@ -5172,6 +5231,84 @@ def forward(self, w, x, y, z):
_disable_forced_specializations=True,
)

def test_reshape_view_helper(self):
# see: https://github.com/pytorch/pytorch/issues/126607
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
x = x.view(x.size(1), -1)
# torch/_refs/__init__/_reshape_view_helper() will generate guards on reshape kernel(?)
# Ne(s0, 20), so that reshape isn't no-op
# Ne(Mod(s0, 20), 0), so that reshape needs to first flatten [s0, 20, 16] -> [s0*20, 16]
# then split_dim -> [20, s0, 16]
# check that these show up in graph
return torch.nn.functional.softmax(
x, dim=0
) # don't think softmax actually creates any issues, just part of original test

model = Model()
x = torch.rand(1024, 20, 16)
dynamic_shapes = {"x": {0: Dim("batch")}}
ep = torch.export._trace._export(
model,
(x,),
dynamic_shapes=dynamic_shapes,
_allow_complex_guards_as_runtime_asserts=True,
)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(s0, 20\)",
):
ep.module()(torch.randn(20, 20, 16))
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)",
):
ep.module()(torch.randn(400, 20, 16))
ep.module()(torch.randn(42, 20, 16))

def test_allow_explicit_guards_as_runtime_asserts(self):
# check that explicit guards are treated as runtime assertions
class Foo(torch.nn.Module):
def forward(self, x, y):
# check that negation of first guard also shows up as runtime assertion
if x.shape[0] == y.shape[0]: # False
return x + y
elif x.shape[0] == y.shape[0] ** 3: # False
return x + 2, y + 3
elif x.shape[0] ** 2 == y.shape[0] * 3: # True
return x * 2.0, y * 3.0

inputs = (torch.randn(6), torch.randn(12))
dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
ep = torch.export._trace._export(
Foo(),
inputs,
dynamic_shapes=dynamic_shapes,
_allow_complex_guards_as_runtime_asserts=True,
)
# check forward pass
out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
self.assertEqual(out0.shape, torch.ones(9).shape)
self.assertEqual(out1.shape, torch.ones(27).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(s0 \- s1, 0\)",
): # fail only at runtime
ep.module()(torch.randn(4), torch.randn(4)) # fail
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Ne\(s0 \- s1\**3, 0\)",
):
ep.module()(torch.randn(64), torch.randn(4)) # fail
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(s0\**2 \- 3\*s1, 0\)",
):
ep.module()(torch.randn(10), torch.randn(9)) # fail

def test_constant_aliasing(self):
class M1(torch.nn.Module):
def __init__(self, m2, foo):
9 changes: 9 additions & 0 deletions torch/_dynamo/config.py
@@ -225,6 +225,15 @@ def is_fbcode():
os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1"
)

# hybrid backed unbacked symints
prefer_deferred_runtime_asserts_over_guards = False

# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
# range constraints + dims + derived dims language, we raise constraint violation
# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
# and allows complex guards as runtime assertions in the graph.
_allow_complex_guards_as_runtime_asserts = False

# By default, dynamo will treat all ints as backed SymInts, which means (1) it
# will wait to see the int change over multiple runs before generalizing and
# (2) it will still always 0/1 specialize an int. When true, this knob
4 changes: 4 additions & 0 deletions torch/_dynamo/eval_frame.py
@@ -1129,6 +1129,8 @@ def export(
assume_static_by_default: bool = False,
same_signature: bool = True,
disable_constraint_solver: bool = False,
prefer_deferred_runtime_asserts_over_guards: bool = False,
_allow_complex_guards_as_runtime_asserts: bool = False,
_log_export_usage: bool = True,
**extra_kwargs,
) -> Callable[..., ExportResult]:
Expand Down Expand Up @@ -1304,6 +1306,8 @@ def result_capturing_wrapper(*graph_inputs):
automatic_dynamic_shapes=False,
capture_dynamic_output_shape_ops=True,
capture_scalar_outputs=True,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts,
):
opt_f = optimize_assert(
dynamo_normalization_capturing_compiler,
2 changes: 2 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -292,6 +292,8 @@ def __init__(
tracked_fakes=self.tracked_fakes,
allow_scalar_outputs=config.capture_scalar_outputs,
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
_allow_complex_guards_as_runtime_asserts=config._allow_complex_guards_as_runtime_asserts,
co_fields=self.co_fields,
)

20 changes: 17 additions & 3 deletions torch/_export/non_strict_utils.py
@@ -111,7 +111,12 @@ def make_fake_params_buffers(


def make_fake_inputs(
nn_module, args, kwargs, dynamic_shapes, _is_torch_jit_trace=False
nn_module,
args,
kwargs,
dynamic_shapes,
_is_torch_jit_trace=False,
_allow_complex_guards_as_runtime_asserts=False,
):
"""
Given an nn module, example inputs, and constraints, return a new fake mode,
@@ -156,13 +161,22 @@ def make_fake_inputs(
"co_firstlineno": code.co_firstlineno,
}
fake_mode = FakeTensorMode(
shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields),
shape_env=ShapeEnv(
tracked_fakes=[],
co_fields=co_fields,
prefer_deferred_runtime_asserts_over_guards=_allow_complex_guards_as_runtime_asserts,
_allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts,
),
allow_non_fake_inputs=True,
export=True,
)
else:
fake_mode = FakeTensorMode(
shape_env=ShapeEnv(tracked_fakes=[]),
shape_env=ShapeEnv(
tracked_fakes=[],
prefer_deferred_runtime_asserts_over_guards=_allow_complex_guards_as_runtime_asserts,
_allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts,
),
allow_non_fake_inputs=True,
)
if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: