Revert "Add kwargs support to torch.export() API (pytorch#92013)"
This reverts commit 890b682.

Reverted pytorch#92013 on behalf of https://github.com/DanilBaibak because it broke an internal build
pytorchmergebot committed Jan 16, 2023
1 parent 76c8836 commit 1a98c3e
Showing 3 changed files with 9 additions and 119 deletions.
78 changes: 0 additions & 78 deletions test/dynamo/test_export.py
@@ -1562,84 +1562,6 @@ def f(x):
         inp = torch.randn(6, 7)
         self.assertEqual(gm(inp), f(inp))
 
-    def test_export_with_kwargs(self):
-        def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
-            out = pos0
-            for arg in tuple0:
-                out *= arg
-            for arg in myargs:
-                out *= arg
-            out *= mykw0
-            out *= mykwargs["input0"] * mykwargs["input1"]
-            return out
-
-        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
-        tuple0 = (torch.randn(4), torch.randn(4))
-        mykw0 = torch.randn(4)
-        pos0 = torch.randn(4)
-        myargs = [torch.randn(4), torch.randn(4)]
-
-        torch._dynamo.reset()
-        exported = torch._dynamo.export(
-            fn_with_kwargs,
-            pos0,
-            tuple0,
-            *myargs,
-            aten_graph=False,
-            mykw0=mykw0,
-            **mykwargs,
-        )
-
-        out_graph = exported[0]
-        dynamo_result = out_graph(pos0, tuple0, *myargs, mykw0=mykw0, **mykwargs)
-        real_result = fn_with_kwargs(pos0, tuple0, *myargs, mykw0=mykw0, **mykwargs)
-        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
-
-    def test_export_with_kwargs_and_empty_args(self):
-        def fn_with_kwargs(mykw0=None, **mykwargs):
-            out = mykw0
-            out *= mykwargs["input0"] * mykwargs["input1"]
-            return out
-
-        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
-        mykw0 = torch.randn(4)
-
-        torch._dynamo.reset()
-        exported = torch._dynamo.export(
-            fn_with_kwargs,
-            aten_graph=False,
-            mykw0=mykw0,
-            **mykwargs,
-        )
-
-        out_graph = exported[0]
-        dynamo_result = out_graph(mykw0=mykw0, **mykwargs)
-        real_result = fn_with_kwargs(mykw0=mykw0, **mykwargs)
-        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
-
-    def test_export_with_args_and_empty_kwargs(self):
-        def fn_with_kwargs(pos0, tuple0, *myargs):
-            out = pos0
-            for arg in tuple0:
-                out *= arg
-            for arg in myargs:
-                out *= arg
-            return out
-
-        tuple0 = (torch.randn(4), torch.randn(4))
-        pos0 = torch.randn(4)
-        myargs = [torch.randn(4), torch.randn(4)]
-
-        torch._dynamo.reset()
-        exported = torch._dynamo.export(
-            fn_with_kwargs, pos0, tuple0, *myargs, aten_graph=False
-        )
-
-        out_graph = exported[0]
-        dynamo_result = out_graph(pos0, tuple0, *myargs)
-        real_result = fn_with_kwargs(pos0, tuple0, *myargs)
-        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
-
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
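Note: the three deleted tests above were the only coverage for keyword arguments in `torch._dynamo.export`. After this revert, only positional inputs are traced. A minimal sketch of a workaround under that constraint (the wrapper name and inputs are invented for illustration, not part of this commit):

```python
import torch
import torch._dynamo

def fn_with_kwargs(pos0, mykw0=None, **mykwargs):
    return pos0 * mykw0 * mykwargs["input0"]

pos0 = torch.randn(4)
mykw0 = torch.randn(4)
input0 = torch.randn(4)

def positional_wrapper(pos0, mykw0, input0):
    # Bind the keyword arguments explicitly so export sees only
    # positional inputs, which the post-revert code path supports.
    return fn_with_kwargs(pos0, mykw0=mykw0, input0=input0)

torch._dynamo.reset()
exported = torch._dynamo.export(positional_wrapper, pos0, mykw0, input0, aten_graph=False)
out_graph = exported[0]
assert torch._dynamo.utils.same(
    out_graph(pos0, mykw0, input0), positional_wrapper(pos0, mykw0, input0)
)
```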
6 changes: 3 additions & 3 deletions torch/_dynamo/eval_frame.py
@@ -608,7 +608,8 @@ def result_capturing_wrapper(*graph_inputs):
 
         return result_capturing_wrapper
 
-    flat_args, in_spec = pytree.tree_flatten((args, kwargs))
+    # TODO(voz): Handle kwargs properly?
+    flat_args, in_spec = pytree.tree_flatten(args)
 
     remove_from_cache(f)
     with patch(f"{__name__}.most_recent_backend", None):
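
The one-line change above is the core of the revert: flattening only `args` means keyword inputs never reach the captured graph. A minimal sketch of what the two calls produce, using illustrative inputs (not taken from the test suite):

```python
import torch
import torch.utils._pytree as pytree

args = (torch.randn(2), (torch.randn(2), torch.randn(2)))
kwargs = {"mykw0": torch.randn(2)}

# Reverted behavior: args and kwargs flattened together; the returned
# spec records a (tuple, dict) structure so kwargs can be rebuilt by key.
flat_both, spec_both = pytree.tree_flatten((args, kwargs))
print(len(flat_both))  # 4 tensor leaves

# Behavior restored by this commit: only positional args are flattened.
flat_args, spec_args = pytree.tree_flatten(args)
print(len(flat_args))  # 3 tensor leaves
```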
@@ -682,10 +682,9 @@ def graph_with_interpreter(*args):
     ).transform()
 
     # Make dynamo graph to have same input/output spec as user code
-    input_strs = [f"orig_arg_{i}" for i in range(len(args))] + list(kwargs.keys())
     new_graph.graph._codegen = _PyTreeCodeGen(
         _PyTreeInfo(
-            input_strs,
+            [f"orig_arg_{i}" for i in range(len(args))],
            in_spec,
            out_spec_traced,
        )
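The deleted `input_strs` line named graph placeholders after both positional and keyword inputs; the revert keeps only the synthetic positional names. A standalone sketch of the list the deleted line built (the values stand in for real tensors):

```python
# Positional inputs get synthetic orig_arg_{i} names; keyword inputs
# keep their keys, in dict order.
args = ("tensor_a", "tensor_b")
kwargs = {"mykw0": None, "input0": None}

input_strs = [f"orig_arg_{i}" for i in range(len(args))] + list(kwargs.keys())
print(input_strs)  # ['orig_arg_0', 'orig_arg_1', 'mykw0', 'input0']
```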
44 changes: 6 additions & 38 deletions torch/fx/graph.py
@@ -622,49 +622,17 @@ def process_outputs(self, out: Any) -> Any:
         return pytree.tree_unflatten(out, self.pytree_info.out_spec)
 
     def gen_fn_def(self, free_vars, maybe_return_annotation):
-        # Given a user function/model:
-        #   myargs = [myargs0, myargs1]
-        #   mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
-        #   def forward(self, mypos, *myargs, mykey=None, **mykwargs):
-        #
-        # The generated code flattens all keywords into positional arguments for `forward()`,
-        # e.g. forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1):
-        #
-        # Within `forward`, `tree_flatten_spec` still parses args and kwargs separately,
-        # e.g. tree_flatten_spec(([mypos, myargs0, myargs1],
-        #                         {'mykey': mykey, 'mykwargs0': mykwargs0, 'mykwargs1': mykwargs1}),
-        #                        self._in_spec)
-        #
-        # If the user function/model has no keywords, the dict is omitted from tree_flatten_spec,
-        # e.g. tree_flatten_spec([mypos, myargs0, myargs1], self._in_spec)
         if self.pytree_info is None:
             return super().gen_fn_def(free_vars, maybe_return_annotation)
-
-        fn_args = self.pytree_info.orig_args
-        has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False
+        function_args = self.pytree_info.orig_args
+        has_orig_self = (function_args[0] == 'self') if len(function_args) > 0 else False
         if has_orig_self:
             free_vars.insert(0, 'self')
-        fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation)
-
+        function_definition = super().gen_fn_def(function_args[:], maybe_return_annotation)
         if len(free_vars) > 0:  # pytree has placeholders in it
-            # when kwargs is present, in_spec is tuple(args, kwargs)
-            has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \
-                len(self.pytree_info.in_spec.children_specs) == 2 and \
-                self.pytree_info.in_spec.children_specs[0].type == tuple and \
-                self.pytree_info.in_spec.children_specs[1].type == dict
-            fn_kwargs = '{}'
-            fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
-            if has_args_kwargs_tuple:
-                count_args = len(self.pytree_info.in_spec.children_specs[0].children_specs)
-                fn_args = self.pytree_info.orig_args[:count_args]
-                fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip(
-                    self.pytree_info.in_spec.children_specs[1].context,
-                    self.pytree_info.orig_args[count_args:])) + '}'
-                fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
-
-            fn_definition += f"""
-    {', '.join(free_vars)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
-        return fn_definition
+            function_definition += f"""
+    {', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(function_args)}], self._in_spec)"""
+        return function_definition
 
     def generate_output(self, output_args):
         if self.pytree_info:
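The deleted comment block documents what the kwargs branch generated. To make it concrete, a standalone sketch of the `fn_signature` string that branch built, using the argument names from the comment (pure string manipulation, runnable on its own):

```python
# Mirrors the deleted branch: positional names plus a dict literal that
# maps each kwarg key to its flattened placeholder name.
fn_args = ["mypos", "myargs0", "myargs1"]
kwarg_keys = ["mykey", "mykwargs0", "mykwargs1"]

fn_kwargs = "{" + ", ".join(f"'{k}':{k}" for k in kwarg_keys) + "}"
fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
print(fn_signature)
# ([mypos, myargs0, myargs1], {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}), self._in_spec
```

The generated `forward()` then unpacked its flat positional arguments through `fx_pytree.tree_flatten_spec(...)` using this signature; after the revert only the positional-list form `[args...], self._in_spec` is emitted.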
