[reland][dynamo] Better support for nn.Module (pytorch#88959)
Relanding pytorch#88629

Pull Request resolved: pytorch#88959
Approved by: https://github.com/msaroufim
anijain2305 authored and pytorchmergebot committed Nov 13, 2022
1 parent 06ce133 commit e950afc
Showing 5 changed files with 205 additions and 20 deletions.
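In short, calling torch._dynamo.optimize on an nn.Module now returns a torch._dynamo.OptimizedModule wrapping the original module, instead of the previous ad-hoc TorchDynamoNNModuleWrapper. A minimal sketch of the resulting behavior, based on the tests added below (MyModule and the "eager" backend are illustrative choices, not part of the commit):

import torch
import torch._dynamo


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))


mod = MyModule()
opt_mod = torch._dynamo.optimize("eager")(mod)

# The wrapper is itself an nn.Module and is now exported publicly.
assert isinstance(opt_mod, torch._dynamo.OptimizedModule)

# Parameters are shared with the wrapped module, and outputs match eager mode.
assert all(p1 is p2 for p1, p2 in zip(mod.parameters(), opt_mod.parameters()))
x = torch.randn(10, 10)
assert torch.allclose(mod(x), opt_mod(x))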
127 changes: 127 additions & 0 deletions test/dynamo/test_modules.py
@@ -904,6 +904,133 @@ def forward(self, x):
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))


class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))

def forward(self, x):
return self.relu(self.linear(x) + self.buf0)


class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
def test_nn_module(self):
mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt)(mod)
self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)

x = torch.randn(10, 10)
self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
self.assertEqual(cnt.frame_count, 1)

def test_to(self):
mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt)(mod)
x = torch.randn(10, 10)
self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
self.assertEqual(cnt.frame_count, 1)

# Ensure that there is no recompilation
opt_mod(x)
self.assertEqual(cnt.frame_count, 1)

opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
x = torch.randn(10, 10).to(dtype=torch.float64)
opt_mod(x)
# Ensure that there is a recompilation
self.assertEqual(cnt.frame_count, 2)

def test_attr(self):
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))

def forward(self, x):
return self.linear(torch.sin(x)) + self.buf0

mod = MockModule()
opt_mod = torch._dynamo.optimize("eager")(mod)

# Check parameters and buffers
for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()):
self.assertTrue(id(p1) == id(p2))

def test_recursion(self):
mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch._dynamo.optimize(cnt)(mod)

for _ in range(5):
opt_mod = torch._dynamo.optimize(cnt)(opt_mod)
opt_mod(torch.randn(10, 10))
self.assertEqual(cnt.frame_count, 1)

def test_composition(self):
class InnerModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(torch.sin(x))

opt_inner_mod = InnerModule()

class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod = opt_inner_mod

def forward(self, x):
return self.mod(torch.cos(x))

outer_mod = OuterModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)

x = torch.randn(4)
self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
self.assertEqual(cnt.frame_count, 1)

def test_composition_with_opt_mod(self):
class InnerModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(torch.sin(x))

inner_mod = InnerModule()
cnt = torch._dynamo.testing.CompileCounter()
opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod)

class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod = opt_inner_mod

def forward(self, x):
return self.mod(torch.cos(x))

outer_mod = OuterModule()
opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)

x = torch.randn(4)
self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
# There will be a graph break for the inner mod being OptimizedModule
self.assertEqual(cnt.frame_count, 2)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

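To exercise just the new tests, something like the following should work from a PyTorch source checkout (assuming pytest is installed; the file can also be run on its own via the __main__ block above):

python -m pytest test/dynamo/test_modules.py -k OptimizedModuleTest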
2 changes: 2 additions & 0 deletions torch/_dynamo/__init__.py
@@ -7,6 +7,7 @@
export,
optimize,
optimize_assert,
OptimizedModule,
reset_code,
run,
skip,
@@ -25,6 +26,7 @@
"reset",
"list_backends",
"skip",
"OptimizedModule",
]


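With OptimizedModule added to __all__ above, user code can import the wrapper type directly from torch._dynamo; a small sketch, again using the "eager" backend as an illustrative choice:

import torch
import torch._dynamo
from torch._dynamo import OptimizedModule

opt_mod = torch._dynamo.optimize("eager")(torch.nn.Linear(4, 4))
assert isinstance(opt_mod, OptimizedModule)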
8 changes: 8 additions & 0 deletions torch/_dynamo/debug_utils.py
@@ -486,8 +486,16 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
"""
Check two models have same accuracy.
"""
from .eval_frame import OptimizedModule
from .testing import named_parameters_for_optimized_module
from .utils import same

if isinstance(gm, OptimizedModule):
gm.named_parameters = named_parameters_for_optimized_module(gm)

if isinstance(opt_gm, OptimizedModule):
opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)

ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

try:
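The patching above is needed because OptimizedModule registers the wrapped module as a submodule named _orig_mod, so the wrapper's own named_parameters() reports names with an "_orig_mod." prefix. A hedged illustration of the mismatch that same_two_models works around (the Linear module and "eager" backend are just examples):

import torch
import torch._dynamo

mod = torch.nn.Linear(10, 10)
opt_mod = torch._dynamo.optimize("eager")(mod)

print([name for name, _ in mod.named_parameters()])      # ['weight', 'bias']
print([name for name, _ in opt_mod.named_parameters()])  # ['_orig_mod.weight', '_orig_mod.bias']

# Pointing named_parameters back at the original module keeps the parameter
# names of the two models aligned when their results are compared.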
74 changes: 54 additions & 20 deletions torch/_dynamo/eval_frame.py
@@ -5,6 +5,7 @@
import logging
import os
import sys
import textwrap
import threading
import traceback
import types
@@ -44,6 +45,27 @@
most_recent_backend = None


class OptimizedModule(torch.nn.Module):
"""
Wraps the original nn.Module object and later patches its
forward method with the optimized self.forward method.
"""

def __init__(self, mod):
super().__init__()
# Installs the params/buffer
self._orig_mod = mod

def __getattr__(self, name):
if name == "_orig_mod":
return self._modules["_orig_mod"]
return getattr(self._orig_mod, name)

def forward(self, *args, **kwargs):
# This will be monkey patched later
raise RuntimeError("Should not be here")


def remove_from_cache(f):
"""
Make sure f.__code__ is not cached to force a recompile
@@ -118,31 +140,15 @@ def __call__(self, fn):
# Optimize the forward method of torch.nn.Module object
if isinstance(fn, torch.nn.Module):
mod = fn
-            optimized_forward = self(mod.forward)
-
-            class TorchDynamoNNModuleWrapper:
-                """
-                A wrapper that redirects the forward call to the optimized
-                forward, while for rest it redirects the calls to the original
-                module.
-                """
-
-                def __getattr__(self, name):
-                    return getattr(mod, name)
-
-                def forward(self, *args, **kwargs):
-                    return optimized_forward(*args, **kwargs)
-
-                def __call__(self, *args, **kwargs):
-                    return self.forward(*args, **kwargs)
-
-            new_mod = TorchDynamoNNModuleWrapper()
+            new_mod = OptimizedModule(mod)
+            new_mod.forward = self(mod.forward)
            # Save the function pointer to find the original callable while nesting
            # of decorators.
-            new_mod._torchdynamo_orig_callable = mod
+            new_mod._torchdynamo_orig_callable = mod.forward
            return new_mod

assert callable(fn)

callback = self.callback
on_enter = self.on_enter
backend_ctx_ctor = self.extra_ctx_ctor
@@ -184,6 +190,34 @@ def _fn(*args, **kwargs):
# If the function is called using torch._dynamo.optimize decorator, we
# should prevent any type of skipping.
if callback not in (None, False):
if not hasattr(fn, "__code__"):
raise RuntimeError(
textwrap.dedent(
"""
torch._dynamo.optimize is called on a non-function object.
If this is a callable class, please optimize the individual methods that you are interested in optimizing.
>> class CallableClass:
>> def __init__(self):
>> super().__init__()
>> self.relu = torch.nn.ReLU()
>>
>> def __call__(self, x):
>> return self.relu(torch.sin(x))
>>
>> def print_hello(self):
>> print("Hello world")
>>
>> mod = CallableClass()
If you want to optimize the __call__ function
>> mod.__call__ = torch._dynamo.optimize(mod.__call__)
"""
)
)
always_optimize_code_objects[fn.__code__] = True

return _fn
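To make the new wrapper concrete, here is a small sketch (not part of the diff; MyModule and the "eager" backend are illustrative): the original module is installed as the _orig_mod submodule, other attribute reads fall through via __getattr__, and forward is monkey patched with the compiled entry point.

import torch
import torch._dynamo


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.register_buffer("buf0", torch.randn(8, 8))

    def forward(self, x):
        return torch.relu(self.linear(x) + self.buf0)


mod = MyModule()
opt_mod = torch._dynamo.optimize("eager")(mod)

assert opt_mod._orig_mod is mod        # stored as the "_orig_mod" submodule
assert opt_mod.buf0 is mod.buf0        # attribute access falls through
assert opt_mod.linear is mod.linear

out = opt_mod(torch.randn(8, 8))       # runs the optimized forward installed on the instance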
14 changes: 14 additions & 0 deletions torch/_dynamo/testing.py
@@ -32,6 +32,18 @@ def clone_me(x):
return x.detach().clone().requires_grad_(x.requires_grad)


def named_parameters_for_optimized_module(mod):
assert isinstance(mod, eval_frame.OptimizedModule)
return mod._orig_mod.named_parameters


def remove_optimized_module_prefix(name):
prefix = "_orig_mod."
assert name.startswith(prefix)
name = name[len(prefix) :]
return torch.distributed.fsdp._common_utils.clean_tensor_name(name)


def collect_results(model, prediction, loss, example_inputs):
results = []
results.append(prediction)
@@ -44,6 +56,8 @@ def collect_results(model, prediction, loss, example_inputs):
grads = dict()
params = dict()
for name, param in model.named_parameters():
if isinstance(model, eval_frame.OptimizedModule):
name = remove_optimized_module_prefix(name)
param_copy = param
grad = param.grad
# Treat None and zero grad as same
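A quick sketch of the renaming that collect_results performs for an OptimizedModule, mirroring remove_optimized_module_prefix above (the parameter name is illustrative):

prefix = "_orig_mod."
wrapped_name = "_orig_mod.linear.weight"   # name as reported by the wrapper
assert wrapped_name.startswith(prefix)
plain_name = wrapped_name[len(prefix):]
print(plain_name)                          # "linear.weight", matching the unwrapped model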
