[Easy] Include SPMD and DTensor files in UFMT checks (pytorch#98148)
Pull Request resolved: pytorch#98148
Approved by: https://github.com/fegin
mrshenli authored and pytorchmergebot committed Apr 2, 2023
1 parent 38609cc commit 0217982
Showing 30 changed files with 285 additions and 386 deletions.
4 changes: 4 additions & 0 deletions .lintrunner.toml
@@ -859,7 +859,11 @@ include_patterns = [
'test/distributed/fsdp/**/*.py',
'torch/testing/_internal/common_fsdp.py',
'torch/distributed/_composable/**/*.py',
'torch/distributed/_spmd/**/*.py',
'torch/distributed/_tensor/**/*.py',
'test/distributed/_composable/**/*.py',
'test/distributed/_spmd/**/*.py',
'test/distributed/_tensor/**/*.py',
'torch/testing/_internal/common_dist_composable.py',
'test/test_value_ranges.py',
'torch/utils/_sympy/interp.py',
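UFMT is PyTorch's Python formatting lint; it applies import sorting and black-style formatting to every file matched by the include_patterns above, so the four new _spmd and _tensor patterns bring those trees under the formatter. The remaining hunks in this commit are the mechanical reformats that result. A minimal sketch of the typical rewrite, using a hypothetical snippet rather than lines from this diff:

# Hypothetical example: a call wrapped across several lines is collapsed onto
# one line once it fits within the formatter's line-length limit, which is
# what most of the hunks below do.
import torch

# Before formatting:
to_receive = torch.empty_like(
    torch.ones(3, 3)
)

# After formatting:
to_receive = torch.empty_like(torch.ones(3, 3))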
62 changes: 17 additions & 45 deletions test/distributed/_spmd/test_tracing.py
@@ -10,16 +10,16 @@
import torch.fx as fx
import torch.nn as nn
from torch.distributed._spmd.api import (
compile,
COMPILED_OBJECT_KEY,
Override,
Schema,
SPMD,
compile,
)
from torch.distributed._spmd.comm_tensor import CommTensor
from torch.distributed._tensor import DeviceMesh, Replicate
from torch.distributed._tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed.distributed_c10d import get_global_rank, get_world_size
from torch.fx.experimental.proxy_tensor import make_fx
@@ -123,12 +123,8 @@ def fn(to_receive: torch.Tensor, to_scatter: List[torch.Tensor]):

# use a local_tensor + 1 for tracing to make sure that we are not
# simply replaying recorded tensor value
to_receive = torch.empty_like(
scattered_tensors[mesh.get_coordinate()[dim]]
)
traced_fn = make_fx(fn)(
to_receive, [t + 1 for t in scattered_tensors]
)
to_receive = torch.empty_like(scattered_tensors[mesh.get_coordinate()[dim]])
traced_fn = make_fx(fn)(to_receive, [t + 1 for t in scattered_tensors])

received_tensor = traced_fn(to_receive, scattered_tensors)
self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)
@@ -162,9 +158,7 @@ def fn(gathered_list: List[torch.Tensor], tensor: torch.Tensor):

self.assertEqual(len(gathered_list), dim_group_size)
for idx, gathered_tensor in enumerate(gathered_list):
self.assertEqual(
gathered_tensor, torch.ones(3, 3) * global_ranks[idx]
)
self.assertEqual(gathered_tensor, torch.ones(3, 3) * global_ranks[idx])


class TraceDeviceMesh3DTest(DTensorTestBase, TraceDeviceMeshTestBase):
@@ -222,14 +216,10 @@ def _test_trace_replicate(self, model: nn.Module, x, *args, **kwargs):
spmd = SPMD(
deepcopy(model),
schema=Schema(
mesh=DeviceMesh(
self.device_type, torch.arange(self.world_size)
),
mesh=DeviceMesh(self.device_type, torch.arange(self.world_size)),
placements=[Replicate()],
),
input_schemas=kwargs["inp_schemas"]
if "inp_schemas" in kwargs
else None,
input_schemas=kwargs["inp_schemas"] if "inp_schemas" in kwargs else None,
)
if "inp_schemas" in kwargs:
del kwargs["inp_schemas"]
@@ -249,8 +239,7 @@ def _test_trace_replicate(self, model: nn.Module, x, *args, **kwargs):
# _Partial tensor shouldn't do that automatically. Hence explicitly
# do division here.
self.assertTrue(
p1.grad.allclose(p2.grad / self.world_size)
or p1.grad.allclose(p2.grad)
p1.grad.allclose(p2.grad / self.world_size) or p1.grad.allclose(p2.grad)
)

@with_comms
@@ -271,9 +260,7 @@ def forward(self, x):
inp_kwargs = {}
inp_kwargs["inp_schemas"] = [
Schema(
mesh=DeviceMesh(
self.device_type, torch.arange(self.world_size)
),
mesh=DeviceMesh(self.device_type, torch.arange(self.world_size)),
placements=[Replicate()],
)
]
@@ -331,9 +318,7 @@ def test_parallel(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.module_list = nn.ModuleList(
[nn.Linear(10, 10) for _ in range(2)]
)
self.module_list = nn.ModuleList([nn.Linear(10, 10) for _ in range(2)])

def forward(self, x):
return sum([m(x) for m in self.module_list])
@@ -359,9 +344,7 @@ def test_hybrid(self):
SPMD(
deepcopy(top_model),
schema=Schema(
mesh=DeviceMesh(
self.device_type, torch.arange(self.world_size)
),
mesh=DeviceMesh(self.device_type, torch.arange(self.world_size)),
placements=[Replicate()],
),
),
@@ -376,8 +359,7 @@ def test_hybrid(self):
# _Partial tensor shouldn't do that automatically. Hence explicitly
# do division here.
self.assertTrue(
p1.grad.allclose(p2.grad / self.world_size)
or p1.grad.allclose(p2.grad)
p1.grad.allclose(p2.grad / self.world_size) or p1.grad.allclose(p2.grad)
)


@@ -397,9 +379,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
positive = x[x >= 0]
negative = x[x < 0]

in_sizes = torch.tensor(
[positive.numel(), negative.numel()], dtype=torch.int32
)
in_sizes = torch.tensor([positive.numel(), negative.numel()], dtype=torch.int32)
out_sizes = torch.empty_like(in_sizes)
dist.all_to_all_single(
out_sizes,
@@ -409,9 +389,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)

xs = [positive, negative]
ys = [
torch.Tensor(out_sizes[i].item()) for i in range(out_sizes.numel())
]
ys = [torch.Tensor(out_sizes[i].item()) for i in range(out_sizes.numel())]
dist.all_to_all(ys, xs)

# some dummy compute
@@ -608,9 +586,7 @@ def test_train_step_override(self):
transform_targets = []

class DDMOverride(Override):
def replacement(
self, orig_submodule: torch.nn.Module
) -> torch.nn.Module:
def replacement(self, orig_submodule: torch.nn.Module) -> torch.nn.Module:
return DummyDDM()

def transform(
@@ -627,9 +603,7 @@ def transform(
# original logic, as we are testing the ability to
# modify graph after DTensor expansion.
with gm.graph.inserting_before(node):
new_node = gm.graph.call_function(
torch.add, args=node.args
)
new_node = gm.graph.call_function(torch.add, args=node.args)
node.replace_all_uses_with(new_node)

gm.graph.lint()
@@ -690,9 +664,7 @@ def train_step(mod, opt, inp):
self.assertEqual(graph_optimization.call_count, 1)
gm = train_step.__dict__[COMPILED_OBJECT_KEY].gm
train_step(mod, opt, inp)
self.assertEqual(
id(gm), id(train_step.__dict__[COMPILED_OBJECT_KEY].gm)
)
self.assertEqual(id(gm), id(train_step.__dict__[COMPILED_OBJECT_KEY].gm))
self.assertEqual(graph_optimization.call_count, 1)


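The import hunk at the top of this file shows the other half of the formatter: from-imports are kept sorted by module path, which is why the op_schema import moves above ops.utils. A small hypothetical sketch of the same effect on standard-library imports (not taken from this diff):

# Hypothetical example: "from" imports sorted lexicographically by module path.
# Before sorting:
from collections.abc import Iterable
from collections import OrderedDict

# After sorting:
from collections import OrderedDict
from collections.abc import Iterable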
1 change: 0 additions & 1 deletion test/distributed/_tensor/test_api.py
@@ -46,7 +46,6 @@ def test_distribute_tensor(self):
shard_spec = [Shard(0)]

for requires_grad in [True, False]:

tensor_to_shard = torch.randn(
3 * self.world_size, 3, requires_grad=requires_grad
)