Skip to content

Commit

Permalink
[torch] Use __prepare_scriptable__ for closures (pytorch#121553)
Browse files Browse the repository at this point in the history
Summary:
This fixes a case left incomplete by pytorch#106229
The object is using __prepare_scriptable__ correctly inside of torch.jit.script()
but the closure that is obtained below is using the non-prepared version.
This causes issues when the prepared and non-prepared versions are in different python modules.

Test Plan:
```
buck2 run mode/opt caffe2/test:jit -- -r test_decorator
```

Differential Revision: D54308741

Re-exporting, as pytorch#120806 pytorch#121307 were not properly merged.

Co-authored-by: Daniel Herrera <[email protected]>
Pull Request resolved: pytorch#121553
Approved by: https://github.com/huydhn, https://github.com/seemethere
  • Loading branch information
dherrera-meta authored and pytorchmergebot committed Mar 11, 2024
1 parent b4160fd commit dccc1ca
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 0 deletions.
20 changes: 20 additions & 0 deletions test/jit/mydecorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
r"""
Decorator used in test_decorator.py. We define it in a
separate file on purpose to test that the names in different modules
are resolved correctly.
"""

import functools


def my_decorator(func):
    """Dummy decorator that removes itself when torchscripting.

    Wraps *func* in a pass-through wrapper and exposes the original
    function via ``__prepare_scriptable__`` so torch.jit.script() can
    strip the decorator away.
    """

    @functools.wraps(func)
    def pass_through(*args, **kwargs):
        return func(*args, **kwargs)

    # torch.jit.script() calls __prepare_scriptable__ to recover the
    # undecorated function before compiling it.
    def _unwrap():
        return func

    pass_through.__prepare_scriptable__ = _unwrap

    return pass_through
13 changes: 13 additions & 0 deletions test/jit/myfunction_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Helper function used in test_decorator.py. We define it in a
separate file on purpose to test that the names in different modules
are resolved correctly.
"""

from jit.mydecorator import my_decorator
from jit.myfunction_b import my_function_b


@my_decorator
def my_function_a(x: float) -> float:
    """Return ``my_function_b(x) + 1`` (decorated, for scripting tests)."""
    inner_result = my_function_b(x)
    return inner_result + 1
16 changes: 16 additions & 0 deletions test/jit/myfunction_b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
r"""
Helper function used in test_decorator.py. We define it in a
separate file on purpose to test that the names in different modules
are resolved correctly.
"""

from jit.mydecorator import my_decorator


@my_decorator
def my_function_b(x: float) -> float:
    """Return ``my_function_c(x) + 2`` (decorated, for scripting tests)."""
    inner_result = my_function_c(x)
    return inner_result + 2


def my_function_c(x: float) -> float:
    """Return the input shifted up by 3 (plain, undecorated helper)."""
    return 3 + x
27 changes: 27 additions & 0 deletions test/jit/test_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Owner(s): ["oncall: jit"]
# flake8: noqa

import sys
import unittest
from enum import Enum
from typing import List, Optional

import torch
from torch.testing._internal.jit_utils import JitTestCase

from jit.myfunction_a import my_function_a


class TestDecorator(JitTestCase):
    def test_decorator(self):
        # JitTestCase.checkScript() cannot handle decorated functions;
        # e.g. self.checkScript(my_function_a, (1.0,)) fails with:
        #   RuntimeError: expected def but found '@' here:
        #   @my_decorator
        #   ~ <--- HERE
        #   def my_function_a(x: float) -> float:
        # So instead, script the function directly and compare eager vs
        # scripted outputs.
        eager_fn = my_function_a
        scripted_fn = torch.jit.script(eager_fn)
        self.assertEqual(eager_fn(1.0), scripted_fn(1.0))
4 changes: 4 additions & 0 deletions torch/jit/_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,10 @@ def try_compile_fn(fn, loc):
f"Consider manually annotating `{fn}` with @torch.jit.script."
)

# The object returned by __prepare_scriptable__ might have a different closure.
# Resolve it here to get the right resolution callback.
fn = fn.__prepare_scriptable__() if hasattr(fn, "__prepare_scriptable__") else fn # type: ignore[operator]

# We don't have the actual scope where the function was defined, but we can
# extract the necessary info from the closed over variables on the function
# object
Expand Down

0 comments on commit dccc1ca

Please sign in to comment.