[export] allow complex guards as runtime asserts (pytorch#127129)
With the current state of export's dynamic shapes, we struggle with guards and constraints that are beyond the current dynamic shapes language, expressed with dims and derived dims. While we can compile and guarantee correctness for guards within the current language (e.g. min/max ranges, linear relationships, integer divisibility), we struggle to dynamically compile guards which extend beyond that.

For these "complex" guards, we typically do one of the following: 1) raise a constraint violation error, along the lines of "not all values of <symbol> in the specified range satisfy <guard>", with or without suggested fixes, 2) specialize to the provided static values and suggest removing dynamism, or 3) fail compilation due to some arbitrary unsupported case. Previous [work](pytorch#124949) went towards resolving this by disabling forced specializations, instead allowing the user to fail at runtime with incorrect inputs.

In this PR, relying on [hybrid backed-unbacked symints](pytorch#121749), [deferred runtime asserts](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/runtime_assert.py), and the function [_is_supported_equivalence()](https://github.com/pytorch/pytorch/blob/d7de4c9d809697b36ae0fd9e16815f6e3b4d985b/torch/fx/experimental/symbolic_shapes.py#L1824), we add a flag `_allow_complex_guards_as_runtime_asserts` which allows the user to compile exported programs containing these guards and maintain dynamism, while adding correctness checks as runtime assertions in the graph.

Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation, i.e. guards that we ~expect to be true. Popular examples revolve around reshapes:

```
# reshape
def forward(self, x, y):  # x: [s0, s1], y: [s2]
    return x.reshape([-1]) + y  # guard s0 * s1 = s2
```

This leads to the following exported program:

```
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0, s1]", y: "f32[s2]"):
        sym_size_int: "Sym(s2)" = torch.ops.aten.sym_size.int(y, 0)
        mul: "Sym(-s2)" = -1 * sym_size_int;  sym_size_int = None
        sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
        mul_1: "Sym(s0*s1)" = sym_size_int_1 * sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
        add: "Sym(s0*s1 - s2)" = mul + mul_1;  mul = mul_1 = None
        eq: "Sym(Eq(s0*s1 - s2, 0))" = add == 0;  add = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s0*s1 - s2, 0) on node 'eq'");  eq = None
        view: "f32[s0*s1]" = torch.ops.aten.view.default(x, [-1]);  x = None
        add_1: "f32[s0*s1]" = torch.ops.aten.add.Tensor(view, y);  view = y = None
        return (add_1,)
```

Another case is symbol divisibility:

```
def forward(self, x):  # x: [s0, s1]
    return x.reshape([-1, x.shape[0] - 1])  # guard Eq(Mod(s0 * s1, s0 - 1), 0)
```

Applying deferred runtime asserts also helps dynamic compilation for "explicit" complex guards that typically cause problems for export. For example, we can generate runtime asserts for not-equal guards, and for complex conditions like the following:

```
class Foo(torch.nn.Module):
    def forward(self, x, y):
        # check that the negation of the first guard also shows up as a runtime assertion
        if x.shape[0] == y.shape[0]:  # False
            return x + y
        elif x.shape[0] == y.shape[0] ** 3:  # False
            return x + 2, y + 3
        elif x.shape[0] ** 2 == y.shape[0] * 3:  # True
            return x * 2.0, y * 3.0
```

For the above module we generate 3 runtime assertions: the negations of the first two conditions, and the third condition itself as a guard.
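As a usage sketch (not from the PR text itself): the reshape example above can be exported with all dims dynamic and the new flag enabled, so that the `s0 * s1 = s2` guard becomes an `_assert_scalar` node instead of a constraint violation. The entry point used here, `torch.export._trace._export`, is an assumption about where the private flag is exposed; the exact signature may differ across versions:

```
import torch
from torch.export import Dim
from torch.export._trace import _export  # assumption: private entry point exposing the flag

class Reshape(torch.nn.Module):
    def forward(self, x, y):  # x: [s0, s1], y: [s2]
        return x.reshape([-1]) + y  # implicit guard: s0 * s1 = s2

# Mark every dim dynamic; the s0*s1 == s2 relationship is not expressible
# with dims/derived dims alone.
dynamic_shapes = {
    "x": (Dim("s0"), Dim("s1")),
    "y": (Dim("s2"),),
}

ep = _export(
    Reshape(),
    (torch.randn(4, 8), torch.randn(32)),  # 4 * 8 == 32 satisfies the guard at trace time
    dynamic_shapes=dynamic_shapes,
    _allow_complex_guards_as_runtime_asserts=True,  # the flag added by this PR
)
print(ep)  # the graph should contain the Eq(s0*s1 - s2, 0) _assert_scalar shown above
```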
One additional benefit here over the current state of exported programs is that this adds further correctness guarantees: previously, with explicit complex guards, if compilation succeeded the guards would be ignored at runtime, treated as given. As shown above, the runtime asserts appear as math ops in the graph, generated by the sympy interpreter, resulting in an `_assert_scalar` call. There is an option to avoid adding these asserts into the graph, by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`. This results in the "original" computation graph, with dynamism, and any incorrect inputs will fail on ops at runtime. Further work could go into prettifying the printer, so the majority of the graph isn't guard-related.

Ideally this PR would subsume and remove the recently added [_disable_forced_specializations](pytorch#124949) flag, but that flag still handles one additional case of specialization: single-variable equalities where the symbol is solvable for a concrete value; see this [PR](pytorch#126925).

This PR doesn't change any behavior around data-dependent errors/unbacked symints yet; that could be further work.

NOTE: will take naming change suggestions for the flag :)

Pull Request resolved: pytorch#127129
Approved by: https://github.com/avikchaudhuri
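To make the runtime behavior concrete, here is a hedged sketch continuing the hypothetical `_export` call above: with the asserts emitted, inputs violating `s0 * s1 = s2` are expected to fail at the `_assert_scalar` node; if `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1` had been set during export, the same inputs would instead fail on the downstream ops themselves:

```
# Continuing the hypothetical `ep` from the earlier sketch.
module = ep.module()

# Valid inputs: 3 * 5 == 15, so the runtime assert passes.
module(torch.randn(3, 5), torch.randn(15))

# Invalid inputs: 3 * 5 != 16; the _assert_scalar node is expected to raise,
# with a message along the lines of the Eq(s0*s1 - s2, 0) assertion above.
try:
    module(torch.randn(3, 5), torch.randn(16))
except RuntimeError as e:
    print("runtime assert fired:", e)

# Exporting with TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 in the environment
# omits the asserts; the same bad inputs would then fail inside the ops
# (e.g. the add between mismatched sizes) rather than at an assert node.
```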
1 parent: cc6e72d · commit: 8a31c2a · 10 changed files with 326 additions and 73 deletions