Add scuba logging for TorchScript usage (pytorch#121936)
Summary: Adds infrastructure for logging live usage of TorchScript internally; the hook added here is a no-op in the open-source build.

Test Plan: manually tested

Differential Revision: D54923510

Pull Request resolved: pytorch#121936
Approved by: https://github.com/zhxchen17
gmagogsfm authored and pytorchmergebot committed Mar 19, 2024
1 parent 4819da6 commit ba9a1d9
Showing 5 changed files with 21 additions and 0 deletions.
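
The five files below instrument the main TorchScript entry points: torch.jit.script, torch.jit.trace, torch.jit.save, torch.jit.load, and the @torch.jit.export / @torch.jit.unused / @torch.jit.ignore decorators. As a quick orientation (a sketch based on the hunks below, not part of the diff itself), each of these calls now fires one log_torchscript_usage event; in the open-source build the hook does nothing:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

scripted = torch.jit.script(M())              # logs "script"
traced = torch.jit.trace(M(), torch.ones(2))  # logs "trace"
torch.jit.save(scripted, "m.pt")              # logs "save"
loaded = torch.jit.load("m.pt")               # logs "load"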
torch/_jit_internal.py (4 additions, 0 deletions)

@@ -46,6 +46,7 @@
 from torch._awaits import _Await
 from torch._C import _Await as CAwait, Future as CFuture
 from torch._sources import fake_range, get_source_lines_and_file, parse_def
+from torch._utils_internal import log_torchscript_usage
 from torch.futures import Future

 IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)

@@ -582,6 +583,7 @@ def unused_method(self, x):
         # any compiled methods and wasn't decorated with `@torch.jit.export`
         m = torch.jit.script(MyModule())
     """
+    log_torchscript_usage("export")
     fn._torchscript_modifier = FunctionModifiers.EXPORT
     return fn

@@ -623,6 +625,7 @@ def forward(self, x):
         # exception raised
         m(torch.rand(100))
     """
+    log_torchscript_usage("unused")
     if isinstance(fn, property):
         prop = fn
         setattr(  # noqa: B010

@@ -710,6 +713,7 @@ def forward(self, x):
         import os
         os.remove('m.pt')
     """
+    log_torchscript_usage("ignore")

     if callable(drop):
         # used without any args, so drop is actually a function
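A timing subtlety in the three decorator hooks above: each fires when the decorator is applied, i.e. while the class body executes, not when the decorated method is compiled or called. A minimal sketch, assuming this commit's hooks:

import torch

class M(torch.nn.Module):
    @torch.jit.export  # logs "export" here, at class-definition time
    def helper(self, x):
        return x + 1

    @torch.jit.ignore  # logs "ignore" likewise
    def debug_only(self, x):
        return x

    def forward(self, x):
        return self.helper(x)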
torch/_utils_internal.py (5 additions, 0 deletions)

@@ -95,6 +95,11 @@ def log_export_usage(**kwargs):
     pass


+def log_torchscript_usage(api: str):
+    _ = api
+    return
+
+
 def justknobs_check(name: str) -> bool:
     """
     This function can be used to killswitch functionality in FB prod,
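The stub above is the open-source seam: it accepts the API name and deliberately discards it, so OSS users pay essentially nothing. Per the commit summary, the Scuba-backed implementation lives in the internal build. Purely as an illustration of the shape such an override might take (the logger below is a stand-in; the real Scuba client is internal and not shown in this PR):

import json
import logging

# Stand-in for an internal Scuba writer; hypothetical, not part of this PR.
_usage_logger = logging.getLogger("torchscript_usage")

def log_torchscript_usage(api: str):
    # One event per call site: api is "script", "trace", "save", "load",
    # "export", "unused", or "ignore".
    _usage_logger.info(json.dumps({"event": "torchscript_usage", "api": api}))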
torch/jit/_script.py (3 additions, 0 deletions)

@@ -19,6 +19,7 @@
 import torch._jit_internal as _jit_internal
 from torch._classes import classes
 from torch._jit_internal import _qualified_name
+from torch._utils_internal import log_torchscript_usage
 from torch.jit._builtins import _register_builtin
 from torch.jit._fuser import _graph_for, _script_method_graph_for

@@ -1287,6 +1288,8 @@ def forward(self, a) -> MyModule:
     if not _enabled:
         return obj

+    log_torchscript_usage("script")
+
     if optimize is not None:
         warnings.warn(
             "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
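Given the call site above, one way to confirm the hook fires is a standard-library mock (a test sketch, not from this PR). Because _script.py binds the name via a from-import, the patch has to target torch.jit._script, not torch._utils_internal:

from unittest import mock

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

with mock.patch("torch.jit._script.log_torchscript_usage") as hook:
    torch.jit.script(M())

hook.assert_called_once_with("script")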
torch/jit/_serialization.py (4 additions, 0 deletions)

@@ -7,9 +7,11 @@
 This is not intended to be imported directly; please use the exposed
 functionalities in `torch.jit`.
 """
+
 import os

 import torch
+from torch._utils_internal import log_torchscript_usage
 from torch.jit._recursive import wrap_cpp_module
 from torch.serialization import validate_cuda_device

@@ -73,6 +75,7 @@ def forward(self, x):
         extra_files = {'foo.txt': b'bar'}
         torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
     """
+    log_torchscript_usage("save")
     if _extra_files is None:
         _extra_files = {}
     if isinstance(f, (str, os.PathLike)):

@@ -143,6 +146,7 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
         import os
         os.remove("scriptmodule.pt")
     """
+    log_torchscript_usage("load")
     if isinstance(f, (str, os.PathLike)):
         if not os.path.exists(f):  # type: ignore[type-var]
             raise ValueError(f"The provided filename {f} does not exist")  # type: ignore[str-bytes-safe]
torch/jit/_trace.py (5 additions, 0 deletions)

@@ -7,6 +7,7 @@
 This is not intended to be imported directly; please use the exposed
 functionalities in `torch.jit`.
 """
+
 import contextlib

 import copy

@@ -25,6 +26,8 @@
     get_callable_argument_names,
     is_scripting,
 )
+
+from torch._utils_internal import log_torchscript_usage
 from torch.autograd import function
 from torch.jit._script import _CachedForward, script, ScriptModule

@@ -803,6 +806,8 @@ def forward(self, x):
         "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
     )

+    log_torchscript_usage("trace")
+
     if isinstance(func, torch.jit.ScriptModule):
         # it is hard to trace it because the forward method on ScriptModule is already defined, so it
         # would result in an error.
