Make allreduce compatible with fx ProxyTensor (pytorch#84126)
land after pytorch#83122

This PR explores solutions for 2 issues:

1. Collective comm ops are inplace ops and do not return a tensor.
   As a result, `make_fx` cannot include comm ops in the traced graph.
   The current solution is to make comm ops return a tuple of
   `(output_tensors, work_handle)`, so that
   [`proxy_call`](https://github.com/pytorch/pytorch/blob/90821aab100a436424113e2306eac63f5e247ee5/torch/fx/experimental/proxy_tensor.py#L170-L172)
   can handle that. It won't change the behavior of existing c10d
   Python/C++ APIs, so I directly added the code to `Ops.cpp`.
2. `make_fx` does not recognize `ProcessGroup::Work` and will ignore
   the `wait()` call on the work handle when tracing the graph. However, this
   might break correctness, as the traced function could consume a tensor
   before it is ready. The current solution is to create a `CommTensor`
   tensor subclass that explicitly calls `wait()`. In this PR, I am only
   doing this in the test (see the sketch below), as we will need more
   discussion to decide whether we can add this to the c10d Python
   implementation. Kudos to @Chillee @wanchaol
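
A minimal sketch of how the two pieces fit together (it mirrors the new test
added below; it assumes an already-initialized process group `pg`, and
`CommTensor` is the test-only subclass from this PR, not a c10d API):

```python
import torch
import torch.distributed as dist
from torch.fx.experimental.proxy_tensor import make_fx

def fn(x: torch.Tensor) -> torch.Tensor:
    # wrap the comm input so __torch_dispatch__ can stash the work handle
    y = CommTensor(x + x)
    # with this PR, the underlying c10d::allreduce_ op returns ([tensor], work),
    # so proxy_call can record it as a call_function node
    work = dist.all_reduce(y, group=pg, async_op=True)
    # make_fx ignores this wait(); CommTensor re-inserts a wait_comm node
    # right before the first op that reads y (the mul below)
    work.wait()
    return y * 2

# the traced graph contains: add -> allreduce_ -> comm_result -> wait_comm -> mul
traced_fn = make_fx(fn)(torch.ones(2, 2))
```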
Pull Request resolved: pytorch#84126
Approved by: https://github.com/wanchaol
mrshenli authored and pytorchmergebot committed Aug 26, 2022
1 parent f93446a commit ec5b83f
Showing 5 changed files with 329 additions and 8 deletions.
283 changes: 280 additions & 3 deletions test/distributed/test_c10d_common.py
@@ -6,10 +6,11 @@
import tempfile
import threading
import time
from contextlib import suppress
from dataclasses import dataclass
from datetime import timedelta
from itertools import product
from sys import platform
from contextlib import suppress

import torch
import torch.distributed as dist
@@ -19,17 +20,24 @@
sys.exit(0)

import torch.distributed.distributed_c10d as c10d
from torch.utils.checkpoint import checkpoint
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from torch import nn
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import (
_ProxyTensor,
fetch_tensor_proxy,
get_proxy_slots,
make_fx,
set_proxy_slot,
track_tensor_tree,
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
)

from torch.testing._internal.common_utils import (
TestCase,
load_tests,
Expand All @@ -38,6 +46,10 @@
instantiate_parametrized_tests,
parametrize
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import tree_map, tree_map_only
from torch.utils.checkpoint import checkpoint


if TEST_WITH_DEV_DBG_ASAN:
print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
@@ -1326,6 +1338,271 @@ def test_send_recv(self):
instantiate_parametrized_tests(CommonDistributedDataParallelTest)


def wait_comm(comm_result):
# This function is only used by tracing mode as a call_function node right
# before consuming a collective result tensor.
comm_result._work.wait()
return comm_result._tensor


@dataclass
class CommResult:
# a custom type wrapping both inplace output tensor and work handle
_tensor: torch.Tensor
_work: torch.classes.c10d.Work


def wrap_comm_result(result):
# allreduce_ returns ([tensor], work)
tensor = result[0][0]
work = result[1]
return ([CommResult(tensor, work)], work)


class CommTensor(torch.Tensor):
r"""
A Tensor subclass to wrap input tensors for collective communications. This
Tensor subclass works for both eager and tracing mode.
In eager mode, it will record whether the inplace collective communication
has been launched using this Tensor and remember the corresponding work
handle. If yes, it will explicitly call wait() in the ``__torch_dispatch__``
function before subsequent operations consume the value of the Tensor.
In tracing mode, ``CommTensor`` inserts two nodes into the graph using the
``__torch_dispatch__`` function.
1. The first node is inserted right after the
communication, wrapping both the inplace output tensor and the returned
work handle into a custom CommResult type. We have to do this because
``ProxyTorchDispatchMode`` only handles ``torch.Tensor``, ``_ProxyTensor``,
and ``torch.nn.Parameter`` objects and will treat the work handle
as a constant and embed that into the graph. As a result, during execution,
it would use the work handle created during tracing, leading to wrong
results. The solution in this test is to manually create a proxy on the
return value of ``allreduce_``, which is ``([tensor], work)``, and wrap that
into ``([CommResult(tensor, work)], work)``. In this way, subsequent nodes can
directly consume ``CommResult``.
2. The second node is inserted right before any subsequent node reads from
``CommResult``. It will call ``wait()`` on the stashed work handle to ensure
that computation waits for communication.
It is specifically tailored for allreduce_ at the moment.
"""
@staticmethod
def __new__(cls, tensor: torch.Tensor):
r = torch.Tensor._make_subclass( # type: ignore[attr-defined]
cls,
tensor,
require_grad=tensor.requires_grad,
)
# The tensor object wrapped by this CommTensor
r._tensor: torch.Tensor = tensor
# Record whether communication has launched on this tensor.
r._after_comm: bool = False
return r

def __repr__(self):
return f"CommTensor({self._tensor}, after_comm={self._after_comm})"

# disable __torch_function__ so that CommTensor can recursively dispatch
# with ProxyTorchDispatchMode in make_fx
__torch_function__ = _disabled_torch_function_impl

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# shared states when unwrapping args
tracer = None
after_comm = False

def get_tracer(obj):
slots = get_proxy_slots(obj)
if slots is None:
return None
keys = tuple(slots.keys())
assert len(keys) == 1
return keys[0]

def get_proxy(obj):
slots = get_proxy_slots(obj)
if slots is None:
return None
vals = tuple(slots.values())
assert len(vals) == 1
return vals[0]

# unwrap ._tensor if this is a CommTensor, and insert/call wait()
# if communication has been launched on this tensor.
def unwrap(e):
if isinstance(e, CommTensor):
nonlocal tracer, after_comm

after_comm = e._after_comm
tracer = get_tracer(e._tensor)

if after_comm:
if tracer is not None:
# insert a node to the traced graph.
proxy_res = tracer.create_proxy(
'call_function',
wait_comm,
(get_proxy(e._tensor).proxy,),
{},
name="wait_comm"
)
# HACK: update the proxy for the inplace output
set_proxy_slot(e._tensor, tracer, proxy_res)
# For eager mode, simply wait.
# During tracing, still need to wait here, to make sure the
# execution during tracing is correct.
e._work.wait()


return e._tensor
else:
return e

unwrapped_args = tree_map(unwrap, args)
unwrapped_kwargs = tree_map(unwrap, kwargs)

if "allreduce_" in func.__name__:
if tracer is not None:
# in tracing mode, get proxies for args
proxy_args, proxy_kwargs = tree_map_only(
_ProxyTensor,
lambda e: e.proxy,
tree_map_only(
torch.Tensor,
fetch_tensor_proxy(tracer),
(unwrapped_args, unwrapped_kwargs)
),
)

# get proxy for output tuple
proxy_res = func(*proxy_args, **proxy_kwargs)
# insert a node that wraps the output tuple into
# CommResult(tensor, work)
comm_result_proxy = tracer.create_proxy(
'call_function',
wrap_comm_result,
(proxy_res, ),
{},
name="comm_result"
)

with no_dispatch():
# disable dispatch to avoid triggering ProxyTorchDispatchMode logic
out = func(*unwrapped_args, **unwrapped_kwargs)

# wrap output with the proxy of CommResult, so that subsequent
# ops can link to it.
track_tensor_tree(out, comm_result_proxy, constant=None, tracer=tracer)

# N.B.: we still need to remember the work handle here, and wait
# for it later to make sure the execution during tracing is
# correct.
args[0][0]._work = out[1]
# remember comm is already launched
args[0][0]._after_comm = True

# HACK: update the proxy on the input argument as this is an
# inplace collective communication.
set_proxy_slot(unwrapped_args[0][0], tracer, get_proxy(out[0][0]))
return out
else:
# in eager mode, simply remember work handle as an attribute
out = func(*unwrapped_args, **unwrapped_kwargs)
args[0][0]._work = out[1]
args[0][0]._after_comm = True
return out
else:
if after_comm:
return func(*unwrapped_args, **unwrapped_kwargs)
else:
# we need to propagate CommTensor wrapping until the first
# subsequent operation has waited for it.
return CommTensor(func(*unwrapped_args, **unwrapped_kwargs))


class CompilerTest(MultiProcessTestCase):
def setUp(self):
super(CompilerTest, self).setUp()
self._spawn_processes()

def tearDown(self):
super(CompilerTest, self).tearDown()
try:
os.remove(self.file_name)
except OSError:
pass

def _get_default_group(self):
raise NotImplementedError("To be implemented by subclass")

def _test_work_wait(self, x: torch.Tensor):
pg = self._get_default_group()

def fn(x: torch.Tensor) -> torch.Tensor:
# N.B.: explicitly wrapping with CommTensor instead of updating
# all_reduce Python implementation, as the latter will need more
# discussion.
y = CommTensor(x + x)
work = dist.all_reduce(y, group=pg, async_op=True)
# this wait() will be ignored in tracing mode as
# ProxyTorchDispatchMode only supports torch.Tensor, _ProxyTensor,
# and torch.nn.Parameter objects
work.wait()
return y * 2

xx = x.clone()

# trace fn into a GraphModule
traced_fn = make_fx(fn)(xx)
traced_fn.graph.lint()
traced_fn.graph.eliminate_dead_code()

if self.rank == 0:
traced_fn.graph.print_tabular()

# make sure the mul op indeed waits for comm
for node in traced_fn.graph.nodes:
if node.op == "call_function" and "mul.Tensor" in node.target.__name__:
prev = node.args[0]
curr = None
waited = False
commed = False
while prev is not None and not commed:
curr = prev
waited |= all([
curr.op == "call_function",
curr.target == wait_comm,
])
commed |= all([
curr.op == "call_function",
"allreduce_" in curr.target.__name__
])

prev = curr.args[0]

self.assertTrue(waited)
self.assertTrue(commed)

# Update input to make sure we are not recording it as constant during
# tracing.
x += 1
xx += 1

y = fn(x)
yy = traced_fn(xx)

# check correctness
self.assertEqual(y, yy)

xx += 1
yy = traced_fn(xx)
self.assertFalse(y.allclose(yy))


if __name__ == "__main__":
assert (
not torch.cuda._initialized
18 changes: 18 additions & 0 deletions test/distributed/test_c10d_gloo.py
@@ -2339,6 +2339,24 @@ def test_gloo_warn_not_in_group(self):
self._test_warn_not_in_group(backend="gloo")


class CompilerTest(test_c10d_common.CompilerTest):

@property
def world_size(self):
return 2

def _get_default_group(self):
store = c10d.FileStore(self.file_name, self.world_size)
return c10d.ProcessGroupGloo(store, self.rank, self.world_size)

def test_work_wait_cpu(self):
self._test_work_wait(torch.ones(2, 2) * self.rank)

@skip_if_lt_x_gpu(2)
def test_work_wait_gpu(self):
self._test_work_wait(torch.ones(2, 2, device=self.rank) * self.rank)


if __name__ == "__main__":
assert (
not torch.cuda._initialized
16 changes: 16 additions & 0 deletions test/distributed/test_c10d_nccl.py
@@ -2747,6 +2747,22 @@ def test_nccl_warn_not_in_group_debug_info(self):
def test_nccl_warn_not_in_group_debug_off(self):
self._test_warn_not_in_group(backend="nccl")


class CompilerTest(test_c10d_common.CompilerTest):

@property
def world_size(self):
return 2

def _get_default_group(self):
store = c10d.FileStore(self.file_name, self.world_size)
return c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

@skip_if_lt_x_gpu(2)
def test_work_wait_gpu(self):
self._test_work_wait(torch.ones(2, 2, device=self.rank) * self.rank)


if __name__ == "__main__":
assert (
not torch.cuda._initialized
@@ -269,6 +269,7 @@
("aten::unsafe_split_with_sizes.out", datetime.date(2022, 9, 1)),
("aten::vsplit.array", datetime.date(2022, 9, 1)),
("aten::vsplit.int", datetime.date(2022, 9, 1)),
("c10d::allreduce_", datetime.date(2022, 10, 1)),
]

ALLOW_LIST_COMPILED = [