forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_functionalize.py
45 lines (32 loc) · 1.26 KB
/
test_functionalize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Owner(s): ["module: functorch"]
import functorch
from unittest.mock import patch
import functools
from torch.testing._internal.common_utils import run_tests, skipIfRocm
import test_aotdispatch
def make_functionalize_fn(fn):
    """Wrap a test function so it runs with functionalization turned on.

    The returned callable forwards all positional and keyword arguments to
    *fn* while ``functorch.compile.config.use_functionalize`` is temporarily
    patched to ``True``. ``patch.object`` restores the previous value when the
    call finishes, whether *fn* returns or raises.

    ``functools.wraps`` copies *fn*'s metadata (``__name__``, ``__qualname__``,
    ``__doc__``, ``__dict__``) onto the wrapper, so attributes stored on the
    function, such as unittest skip markers, carry over.

    Args:
        fn: the callable to wrap (here, a ``test_*`` attribute of a test class).

    Returns:
        The wrapper; calling it returns whatever *fn* returns.
    """

    @functools.wraps(fn)
    def run_with_functionalize(*args, **kwargs):
        config = functorch.compile.config
        with patch.object(config, "use_functionalize", True):
            outcome = fn(*args, **kwargs)
        return outcome

    return run_with_functionalize
def make_functionalize_test(cls):
    """Return a subclass of *cls* whose tests run with functionalization enabled.

    Every callable ``test_*`` attribute visible on *cls* is re-exposed on the
    subclass as ``<name>_functionalize``. This includes inherited attributes,
    since ``dir`` is used. Each one is wrapped by ``make_functionalize_fn`` so
    it runs with ``functorch.compile.config.use_functionalize`` patched to
    ``True``. The original ``test_*`` name is shadowed with ``None`` on the
    subclass. ``None`` is not callable, so unittest does not collect it, and
    only the functionalized variant runs.

    Args:
        cls: the test-case class to derive from (e.g. ``TestAOTAutograd``).

    Returns:
        The generated class, named ``Functionalize<cls.__name__>``, passed
        through ``skipIfRocm``.
    """

    class FunctionalizeTest(cls):
        pass

    new_cls_name = f"Functionalize{cls.__name__}"
    FunctionalizeTest.__name__ = new_cls_name
    # Setting ``__name__`` alone is not enough. Class reprs and unittest test
    # ids (unittest.util.strclass) are built from ``__qualname__``, which would
    # otherwise still read ``make_functionalize_test.<locals>.FunctionalizeTest``.
    FunctionalizeTest.__qualname__ = new_cls_name

    for name in dir(cls):
        if not name.startswith("test_"):
            continue
        fn = getattr(cls, name)
        if not callable(fn):
            continue
        new_name = f"{name}_functionalize"
        fn = make_functionalize_fn(fn)
        # functools.wraps copied the base test's name/qualname; point both at
        # the renamed method on the generated class.
        fn.__name__ = new_name
        fn.__qualname__ = f"{new_cls_name}.{new_name}"
        # Mask the inherited, non-functionalized test so it is not collected here.
        setattr(FunctionalizeTest, name, None)
        setattr(FunctionalizeTest, new_name, fn)

    # https://github.com/pytorch/pytorch/issues/96560
    # NOTE(review): skipIfRocm is applied to a class here. Confirm it supports
    # class targets. A decorator that only wraps functions would return a plain
    # function and hide this suite from unittest discovery.
    return skipIfRocm(FunctionalizeTest)
# Functionalized counterparts of two suites from test_aotdispatch. Only these
# generated classes are bound at module level. The base classes are reached
# through the ``test_aotdispatch`` module object instead of being imported by
# name. Presumably this is because binding them here would make standard
# unittest module loading collect the non-functionalized suites as well.
FunctionalizeTestPythonKeyAOT = make_functionalize_test(test_aotdispatch.TestAOTAutograd)
FunctionalizeTestPythonKeyPartitioning = make_functionalize_test(test_aotdispatch.TestPartitioning)

# Entry point when this file is executed directly: hand off to PyTorch's
# run_tests() from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()