forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[torch] Use __prepare_scriptable__ for closures (pytorch#121553)
Summary: This fixes a case left incomplete by pytorch#106229 The object is using __prepare_scriptable__ correctly inside of torch.jit.script() but the clousre that is obtained below is using the non-prepared version. This causes issues when the prepared and non-prepared versions are in different python modules. Test Plan: ``` buck2 run mode/opt caffe2/test:jit -- -r test_decorator ``` Differential Revision: D54308741 Re-exporting, as pytorch#120806 pytorch#121307 were not properly merged. Co-authored-by: Daniel Herrera <[email protected]> Pull Request resolved: pytorch#121553 Approved by: https://github.com/huydhn, https://github.com/seemethere
- Loading branch information
1 parent
b4160fd
commit dccc1ca
Showing
5 changed files
with
80 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
r""" | ||
Decorator used in test_decorator.py. We define it in a | ||
separate file on purpose to test that the names in different modules | ||
are resolved correctly. | ||
""" | ||
|
||
import functools | ||
|
||
|
||
def my_decorator(func): | ||
"""Dummy decorator that removes itself when torchscripting""" | ||
|
||
@functools.wraps(func) | ||
def wrapped_func(*args, **kwargs): | ||
return func(*args, **kwargs) | ||
|
||
# torch.jit.script() uses __prepare_scriptable__ to remove the decorator | ||
wrapped_func.__prepare_scriptable__ = lambda: func | ||
|
||
return wrapped_func |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
""" | ||
Helper function used in test_decorator.py. We define it in a | ||
separate file on purpose to test that the names in different modules | ||
are resolved correctly. | ||
""" | ||
|
||
from jit.mydecorator import my_decorator | ||
from jit.myfunction_b import my_function_b | ||
|
||
|
||
@my_decorator | ||
def my_function_a(x: float) -> float: | ||
return my_function_b(x) + 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
r""" | ||
Helper function used in test_decorator.py. We define it in a | ||
separate file on purpose to test that the names in different modules | ||
are resolved correctly. | ||
""" | ||
|
||
from jit.mydecorator import my_decorator | ||
|
||
|
||
@my_decorator | ||
def my_function_b(x: float) -> float: | ||
return my_function_c(x) + 2 | ||
|
||
|
||
def my_function_c(x: float) -> float: | ||
return x + 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Owner(s): ["oncall: jit"] | ||
# flake8: noqa | ||
|
||
import sys | ||
import unittest | ||
from enum import Enum | ||
from typing import List, Optional | ||
|
||
import torch | ||
from torch.testing._internal.jit_utils import JitTestCase | ||
|
||
from jit.myfunction_a import my_function_a | ||
|
||
|
||
class TestDecorator(JitTestCase): | ||
def test_decorator(self): | ||
# Note: JitTestCase.checkScript() does not work with decorators | ||
# self.checkScript(my_function_a, (1.0,)) | ||
# Error: | ||
# RuntimeError: expected def but found '@' here: | ||
# @my_decorator | ||
# ~ <--- HERE | ||
# def my_function_a(x: float) -> float: | ||
# Do a simple torch.jit.script() test instead | ||
fn = my_function_a | ||
fx = torch.jit.script(fn) | ||
self.assertEqual(fn(1.0), fx(1.0)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters