Skip to content

Commit

Permalink
[1/N] Non-Tensor: Scalar Support: Enable aot compile to support aten …
Browse files Browse the repository at this point in the history
…operations with scalar input like alpha (pytorch#124177)

Some operations have a scalar input parameter, like `torch.add(a, b, alpha=2.0)`.  Currently, the aot compile does not support such a case because it requires the signature of the captured graph to align with the operation's signature. This means that some inputs in the captured graph may be scalar(float, int, bool, etc.). It breaks the assumption of `compile_fx_aot` as it assumes all the example inputs are tensor - https://github.com/pytorch/pytorch/blob/0f6ce45bcbd7026c00da43db0317ede10830378b/torch/_inductor/compile_fx.py#L1048

This PR intends to support such cases by allowing not-aligned signature and filtering out the non-Tensor parameters.

Captured graph for `torch.add(a, b, alpha=2.0)`

```
opcode         name      target           args              kwargs
-------------  --------  ---------------  ----------------  --------------
placeholder    arg0_1    arg0_1           ()                {}
placeholder    arg1_1    arg1_1           ()                {}
call_function  add       aten.add.Tensor  (arg0_1, arg1_1)  {'alpha': 2.0}
output         output_1  output           ((add,),)         {}
```

Pull Request resolved: pytorch#124177
Approved by: https://github.com/jansel, https://github.com/desertfire, https://github.com/jgong5
  • Loading branch information
EikanWang authored and pytorchmergebot committed May 16, 2024
1 parent 5fa1f4c commit 08aa704
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 1 deletion.
19 changes: 19 additions & 0 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from unittest import skip

import torch
import torch._export
import torch._inductor
import torch.nn as nn
from torch._dynamo.testing import rand_strided, same
Expand Down Expand Up @@ -1210,6 +1211,24 @@ def forward(self, x):
torch._export.aot_compile(Model(self.device), example_inputs)
self.check_model(Model(self.device), example_inputs)

def test_non_tensor_input(self):
def fn(a, b, alpha=1.0):
return torch.add(a, b, alpha=alpha)

a = torch.randn(10, device=self.device)
b = torch.randn(10, device=self.device)
with self.assertRaises(RuntimeError):
torch._export.aot_compile(fn, args=(a, b), kwargs={"alpha": 2.0})

so_path = torch._export.aot_compile(
torch.ops.aten.add, args=(a, b), kwargs={"alpha": 2.0}, same_signature=False
)
kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path)
res = kernel_runner.run([a, b])
self.assertTrue(isinstance(res, list))
self.assertTrue(len(res) == 1)
self.assertEqual(fn(a, b, alpha=2.0), res[0])

def test_buffer_mutation_2(self):
class Model(torch.nn.Module):
def __init__(self, device):
Expand Down
2 changes: 2 additions & 0 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def aot_compile(
options: Optional[Dict[str, Any]] = None,
remove_runtime_assertions: bool = False,
disable_constraint_solver: bool = False,
same_signature: bool = True,
) -> str:
"""
Note: this function is not stable yet
Expand Down Expand Up @@ -393,6 +394,7 @@ def aot_compile(
kwargs,
dynamic_shapes,
disable_constraint_solver=disable_constraint_solver,
same_signature=same_signature,
# Disabling this flag, because instead we can rely on the mapping
# dynamo_flat_name_to_original_fqn which is coming from Dynamo.
restore_fqn=False,
Expand Down
4 changes: 3 additions & 1 deletion torch/_inductor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def aot_compile(
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
(args, kwargs or {})
)
flat_example_inputs = tuple(x[1] for x in flat_args_with_path)
flat_example_inputs = tuple(
x[1] for x in flat_args_with_path if isinstance(x[1], torch.Tensor)
)

if in_spec is not None and received_spec != in_spec:
raise ValueError( # noqa: TRY200
Expand Down
4 changes: 4 additions & 0 deletions torch/_inductor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,10 @@ def aoti_compile_with_persistent_cache(
options=options,
remove_runtime_assertions=remove_runtime_assertions,
disable_constraint_solver=disable_constraint_solver,
# Some operations may have non-Tensor parameters like int, float, bool. These
# non-Tensor parameters will not be the input of the graph. Therefore, we do
# need to keep the same signature.
same_signature=False,
)

kernel_metadata_items = []
Expand Down
2 changes: 2 additions & 0 deletions torch/export/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def _export_to_torch_ir(
disable_constraint_solver: bool = False,
restore_fqn: bool = True,
_log_export_usage: bool = True,
same_signature: bool = True,
) -> torch.fx.GraphModule:
"""
Traces either an nn.Module's forward function or just a callable with PyTorch
Expand Down Expand Up @@ -445,6 +446,7 @@ def _export_to_torch_ir(
tracing_mode="symbolic",
disable_constraint_solver=disable_constraint_solver,
_log_export_usage=_log_export_usage,
same_signature=same_signature,
)(
*args,
**kwargs,
Expand Down

0 comments on commit 08aa704

Please sign in to comment.