Skip to content

Commit

Permalink
Log export result of torch.jit.trace to scuba (pytorch#126900)
Browse files Browse the repository at this point in the history
Summary: We want to track, at large scale, how well the result of torch.jit.trace can be converted to torch.export. As a first step, for every torch.jit.trace unit test we log whether we can convert the traced module to an exported module, or whether we can export the model directly.

Test Plan: CI

Differential Revision: D57629682

Pull Request resolved: pytorch#126900
Approved by: https://github.com/SherlockNoMad
  • Loading branch information
tugsbayasgalan authored and pytorchmergebot committed May 28, 2024
1 parent 3f79e09 commit 9521528
Show file tree
Hide file tree
Showing 3 changed files with 363 additions and 145 deletions.
11 changes: 11 additions & 0 deletions torch/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ def log_torchscript_usage(api: str):
return


def check_if_torch_exportable() -> bool:
    """OSS stub: always returns False.

    Internal (fbcode) builds override this to decide whether the current
    environment should attempt torch.export conversion logging.
    """
    return False


def log_torch_jit_trace_exportability(
    api: str, type_of_export: str, export_outcome: str, result: str
) -> None:
    """OSS stub: no-op.

    Internal builds log whether a torch.jit.trace'd module could be
    converted to (or directly run through) torch.export, for large-scale
    tracking. Arguments are intentionally consumed-but-unused here so the
    signature matches the internal implementation.
    """
    _, _, _, _ = api, type_of_export, export_outcome, result
    return


def export_api_rollout_check() -> bool:
    """OSS stub: always returns False.

    Internal builds use this as a gradual-rollout gate for the new export
    API path.
    """
    return False

Expand Down
115 changes: 83 additions & 32 deletions torch/export/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,46 +943,95 @@ def wrapper(*args, **kwargs):
return wrapper


def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs):
if not isinstance(example_inputs, (tuple, list, dict)):
example_inputs = (example_inputs,)

elif isinstance(example_inputs, list):
example_inputs = tuple(example_inputs)

elif (
isinstance(example_inputs, (torch.Tensor, dict))
and example_kwarg_inputs is None
):
example_inputs = (example_inputs,)

if example_kwarg_inputs is None:
example_kwarg_inputs = {}
return example_inputs, example_kwarg_inputs


@contextmanager
def patch_forward(obj: torch.nn.Module, new_method):
    """Temporarily swap ``obj.forward`` for *new_method* (bound to *obj*).

    Lets torch.export() cleanly export a method other than ``forward``;
    the original forward is restored on exit even if the body raises.
    """
    saved_forward = obj.forward
    # Bind new_method to the instance so it behaves like a normal method.
    obj.forward = new_method.__get__(obj, obj.__class__)
    try:
        yield
    finally:
        obj.forward = saved_forward


@contextmanager
def _temp_disable_texpr_fuser():
original_state = torch._C._jit_texpr_fuser_enabled()
torch._C._jit_set_texpr_fuser_enabled(False)
try:
yield
finally:
torch._C._jit_set_texpr_fuser_enabled(original_state)

def process_trace_inputs_for_export(example_inputs, example_kwarg_inputs):
    """Coerce example inputs into an ``(args_tuple, kwargs_dict)`` pair.

    Any non-tuple positional input is wrapped in a 1-tuple; ``None``
    kwargs are replaced with an empty dict.
    """
    args = (
        example_inputs
        if isinstance(example_inputs, tuple)
        else (example_inputs,)
    )
    kwargs = {} if example_kwarg_inputs is None else example_kwarg_inputs
    return args, kwargs
def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
    """Convert a torch.jit.trace'd/scripted callable to an exported module.

    NOTE(review): the pasted span interleaved pre- and post-commit diff
    lines and was not valid Python; this is the coherent post-commit
    (added-lines) version of the function.

    Dispatch:
    - traced/scripted modules are exported directly;
    - a ScriptMethod whose owner is a module is exported by temporarily
      patching the owner's ``forward`` to that method;
    - any other callable is wrapped in a throwaway nn.Module first.

    The TensorExpr fuser is disabled for the duration, as it can interfere
    with re-tracing for export.
    """
    with _temp_disable_texpr_fuser():

        class _WrapperModule(torch.nn.Module):
            # Minimal module wrapper so a bare callable can be exported.
            def __init__(self, f):
                super().__init__()
                self.f = f

            def forward(self, *args, **kwargs):
                return self.f(*args, **kwargs)

        from torch.jit._trace import TopLevelTracedModule

        export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs)

        if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)):  # type: ignore[operator]
            return _export(
                traced_callable,
                export_args,
                export_kwargs,
                strict=False,
                _is_torch_jit_trace=True,
            ).module()

        elif isinstance(traced_callable, torch.ScriptMethod) and isinstance(
            traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module)  # type: ignore[operator]
        ):
            with patch_forward(traced_callable.owner(), traced_callable):  # type: ignore[operator]
                return _export(
                    traced_callable.owner(),  # type: ignore[operator]
                    export_args,
                    export_kwargs,
                    strict=False,
                    _is_torch_jit_trace=True,
                ).module()

        else:
            return _export(
                _WrapperModule(traced_callable),
                export_args,
                export_kwargs,
                strict=False,
                _is_torch_jit_trace=True,
            ).module()


def _strict_export(
Expand Down Expand Up @@ -1412,7 +1461,9 @@ def _export(
),
len(export_graph_signature.input_specs),
)
combined_args = _combine_args(mod, args, kwargs)
combined_args = _combine_args(
mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace
)
range_constraints = make_constraints(
fake_mode,
gm,
Expand Down
Loading

0 comments on commit 9521528

Please sign in to comment.