[SPMD] Enable fused Adam in full train step tracing (pytorch#98113)
mrshenli authored and pytorchmergebot committed Apr 1, 2023
1 parent bccf2ef commit e8d3960
Showing 6 changed files with 152 additions and 24 deletions.
30 changes: 22 additions & 8 deletions test/distributed/_spmd/test_tracing.py
@@ -558,17 +558,14 @@ def train_step(mod, opt, inp):
torch.manual_seed(1)
# FIXME(@mrshenli): gradients for bias is missing
mod = nn.Linear(10, 10, bias=True).cuda(rank)
# FIXME(@mrshenli): we have to enable foreach to get better perf
opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=True)
inp = torch.randn(2, 10).cuda(rank)

ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
ddp_opt = torch.optim.SGD(ddp_mod.parameters(), lr=0.01, foreach=True)
self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step)

@skip_if_lt_x_gpu(2)
@with_comms
def test_adam(self):
def _test_adam(self, *, foreach: bool, fused: bool):
@compile()
def train_step(mod, opt, inp):
mod(inp).sum().backward()
@@ -580,16 +577,31 @@ def train_step(mod, opt, inp):
torch.manual_seed(0)
# FIXME(@mrshenli): gradients for bias is missing
mod = nn.Linear(10, 10, bias=False).cuda(rank)
# FIXME(@mrshenli): we have to enable foreach to get better perf
opt = torch.optim.Adam(
mod.parameters(), lr=0.01, foreach=True, capturable=True
mod.parameters(),
lr=0.01,
foreach=foreach,
fused=fused,
capturable=True,
)
inp = torch.randn(2, 10).cuda(rank)

ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
ddp_opt = torch.optim.Adam(ddp_mod.parameters(), lr=0.01, foreach=True)
ddp_opt = torch.optim.Adam(
ddp_mod.parameters(), lr=0.01, foreach=foreach, fused=fused
)
self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step)

@skip_if_lt_x_gpu(2)
@with_comms
def test_adam_foreach(self):
self._test_adam(foreach=True, fused=False)

@skip_if_lt_x_gpu(2)
@with_comms
def test_adam_fused(self):
self._test_adam(foreach=False, fused=True)

@skip_if_lt_x_gpu(2)
@with_comms
def test_train_step_override(self):
@@ -678,7 +690,9 @@ def train_step(mod, opt, inp):
self.assertEqual(graph_optimization.call_count, 1)
gm = train_step.__dict__[COMPILED_OBJECT_KEY].gm
train_step(mod, opt, inp)
self.assertEqual(id(gm), id(train_step.__dict__[COMPILED_OBJECT_KEY].gm))
self.assertEqual(
id(gm), id(train_step.__dict__[COMPILED_OBJECT_KEY].gm)
)
self.assertEqual(graph_optimization.call_count, 1)


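The test is now parametrized over the two multi-tensor Adam code paths: test_adam_foreach exercises the existing foreach implementation and test_adam_fused the new fused one, with the compiled optimizer additionally passing capturable=True. A minimal eager-mode sketch of the fused construction (hypothetical standalone usage, not part of this diff; fused=True requires CUDA floating-point parameters):

import torch
import torch.nn as nn

# Hypothetical sketch mirroring the optimizers built in _test_adam above.
mod = nn.Linear(10, 10, bias=False).cuda()
opt = torch.optim.Adam(mod.parameters(), lr=0.01, fused=True)
inp = torch.randn(2, 10).cuda()

mod(inp).sum().backward()
opt.step()  # runs the fused multi-tensor Adam kernel in eager mode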
26 changes: 26 additions & 0 deletions torch/_meta_registrations.py
@@ -1367,6 +1367,32 @@ def meta__foreach_pow_scalar_and_tensor(self, exponent):
return [torch.empty_like(e) for e in exponent]


@register_meta([aten._fused_adam_.default])
def meta__fused_adam_(
self,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
*,
lr,
beta1,
beta2,
weight_decay,
eps,
amsgrad,
maximize,
grad_scale=None,
found_inf=None,
):
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
check(
isinstance(l, List),
lambda: f"expected a tensor list but got {type(l)}",
)


@register_meta([aten._int_mm])
@out_wrapper()
def meta__int_mm(a, b):
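The meta kernel lets aten._fused_adam_ participate in fake/meta-tensor tracing even though no real CUDA kernel runs. A rough smoke test, assuming a build where the registration above is active (argument values are arbitrary, and only shape/type bookkeeping happens because every tensor lives on the meta device):

import torch

params = [torch.empty(10, 10, device="meta")]
grads = [torch.empty(10, 10, device="meta")]
exp_avgs = [torch.empty(10, 10, device="meta")]
exp_avg_sqs = [torch.empty(10, 10, device="meta")]
max_exp_avg_sqs = []  # empty when amsgrad=False
state_steps = [torch.empty((), device="meta")]

torch.ops.aten._fused_adam_.default(
    params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
    lr=0.01, beta1=0.9, beta2=0.999, weight_decay=0.0, eps=1e-8,
    amsgrad=False, maximize=False,
)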
43 changes: 43 additions & 0 deletions torch/distributed/_spmd/api.py
@@ -254,6 +254,48 @@ def _foreach_addcop_scalar_decomp(op, self, tensor1, tensor2, scalar=1):
s.copy_(s_u)


def _fused_adam_decomp(
self,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
*,
lr=1,
beta1=1,
beta2=1,
weight_decay=1,
eps=1,
amsgrad=True,
maximize=True,
grad_scale=None,
found_inf=None,
):
orig_tuple = (self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)
updated_tuple = aten._fused_adam.default(
self,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
lr=lr,
beta1=beta1,
beta2=beta2,
weight_decay=weight_decay,
eps=eps,
amsgrad=amsgrad,
maximize=maximize,
grad_scale=grad_scale,
found_inf=found_inf,
)

for orig, updated in zip(orig_tuple, updated_tuple):
for o, u in zip(orig, updated):
o.copy_(u)


FOREACH_DECOMP_TABLE = {
aten._foreach_add_.List: _foreach_add_decomp,
aten._foreach_add_.Scalar: partial(
@@ -280,6 +322,7 @@ def _foreach_addcop_scalar_decomp(op, self, tensor1, tensor2, scalar=1):
aten._foreach_sub_.Scalar: partial(
_foreach_binop_scalar_decomp, aten._foreach_sub.Scalar
),
aten._fused_adam_.default: _fused_adam_decomp,
}


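The decomposition swaps the in-place fused op for its functional counterpart and then copies every updated tensor back into the original parameter and state lists, so the mutation is still observed by the rest of the traced program. The copy-back pattern in isolation (a generic sketch, not code from this PR):

import torch
from typing import List, Tuple

def copy_back(
    orig_lists: Tuple[List[torch.Tensor], ...],
    updated_lists: Tuple[List[torch.Tensor], ...],
) -> None:
    # Mirror the loop at the end of _fused_adam_decomp: write each functional
    # result into the tensor that the in-place op would have mutated.
    for orig, updated in zip(orig_lists, updated_lists):
        for o, u in zip(orig, updated):
            o.copy_(u)

params, new_params = [torch.zeros(3)], [torch.ones(3)]
copy_back((params,), (new_params,))
assert torch.equal(params[0], torch.ones(3))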
9 changes: 7 additions & 2 deletions torch/distributed/_spmd/distribute.py
@@ -388,8 +388,13 @@ def _rebuild_graph(
if all(
[
isinstance(n.target, torch._ops.OpOverload)
and n.target._schema.name.startswith(
"aten::_foreach"
and (
n.target._schema.name.startswith(
"aten::_foreach"
)
or n.target._schema.name.startswith(
"aten::_fused_adam"
)
)
for n in [dtn, node]
]
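The condition in _rebuild_graph that special-cases foreach ops now also matches the fused Adam ops, keyed off the op's schema name. The schema-name check in isolation (assumes an ATen build that provides the op):

import torch

op = torch.ops.aten._fused_adam_.default
assert isinstance(op, torch._ops.OpOverload)
# OpOverload exposes its FunctionSchema; the name is namespace-qualified.
assert op._schema.name.startswith("aten::_fused_adam")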
45 changes: 44 additions & 1 deletion torch/distributed/_spmd/experimental_ops.py
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Optional, Sequence
from typing import List, Optional, Sequence, Tuple

import torch

@@ -147,6 +147,49 @@ def _prop__foreach_pow_scalar_and_tensor(op_schema: OpSchema):
return OutputSharding(output_spec=exponent)


@register_prop_rule([aten._fused_adam.default]) # pyre-ignore
def _prop__fused_adam(op_schema: OpSchema):
NT = 5
tensor_list_args: Tuple[List[DTensorSpec]] = op_schema.args_schema[:NT] # type: ignore[assignment]

assert all([isinstance(schema, list) for schema in tensor_list_args])
assert all(
[
isinstance(s, DTensorSpec)
for schema in tensor_list_args
for s in schema
]
)

tensor_schemas: Tuple[List[DTensorSpec]] = [ # type: ignore[assignment]
schema for schema in tensor_list_args if len(schema)
]

assert all([len(s) == len(tensor_schemas[0]) for s in tensor_schemas]), (
"expect the same number of gradients and states, but got "
f"{[len(s) for s in tensor_schemas]}."
)

if any([any([t != ts[0] for t in ts]) for ts in zip(*tensor_schemas)]):
new_schemas: Tuple[List[DTensorSpec]] = tuple( # type: ignore[assignment]
op_schema.args_schema[0] if len(s) else s for s in tensor_list_args
)
return OutputSharding(
output_spec=None,
schema_suggestions=[
OpSchema(
func_schema=op_schema.func_schema,
args_schema=new_schemas + op_schema.args_schema[NT:],
kwargs_schema=op_schema.kwargs_schema,
is_inplace=op_schema.is_inplace,
is_out_variant=op_schema.is_out_variant,
)
],
)
else:
return OutputSharding(output_spec=(op_schema.args_schema[0],) * NT) # type: ignore[arg-type]


@register_prop_rule(aten.native_layer_norm.default) # pyre-ignore
def _prop_native_layer_norm(op_schema: OpSchema) -> OutputSharding:
input, normalized_shape, weight, bias, eps = op_schema.args_schema
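The sharding propagation rule requires the parameter, gradient, and optimizer-state lists to share one sharding; otherwise it returns a suggestion that reshards every non-empty list to match the parameters. A toy sketch of that decision with plain strings standing in for DTensorSpec (illustrative only):

params = ["Shard(0)", "Shard(0)"]
grads = ["Replicate()", "Shard(0)"]  # one gradient disagrees with its param
exp_avgs = ["Shard(0)", "Shard(0)"]
exp_avg_sqs = ["Shard(0)", "Shard(0)"]
max_exp_avg_sqs = []  # empty when amsgrad=False

args = (params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)
non_empty = [a for a in args if len(a)]
if any(any(t != ts[0] for t in ts) for ts in zip(*non_empty)):
    # Suggest re-running the op with every non-empty list sharded like params.
    suggestion = tuple(params if len(a) else a for a in args)
    print(suggestion)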
23 changes: 10 additions & 13 deletions torch/distributed/_tensor/dispatch.py
@@ -25,7 +25,7 @@


def wrap(res: object, spec: OutputSpecType) -> object:
if isinstance(res, torch.Tensor):
def to_dt(res, spec):
assert spec is not None and isinstance(
spec, DTensorSpec
), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
@@ -39,6 +39,9 @@ def wrap(res: object, spec: OutputSpecType) -> object:
requires_grad=res.requires_grad,
stride=spec.tensor_meta.stride,
)

if isinstance(res, torch.Tensor):
return to_dt(res, spec)
elif isinstance(res, (list, tuple)):
assert spec is not None and isinstance(
spec, (list, tuple)
@@ -48,21 +51,15 @@ def wrap(res: object, spec: OutputSpecType) -> object:
# NOTE: local results might return Optional Tensor from ATen op, so we need
# to handle that case and make sure we don't wrap None with DTensor.
# (i.e. native_layer_norm.backward)
if e is not None and s is not None:
assert s.tensor_meta is not None
res_dt = dtensor.DTensor(
e,
s.mesh,
s.placements,
shape=s.tensor_meta.shape,
dtype=s.tensor_meta.dtype,
requires_grad=s.tensor_meta.requires_grad,
stride=s.tensor_meta.stride,
if isinstance(e, (list, tuple)) and isinstance(s, (list, tuple)):
res_list.append(
type(e)([to_dt(ee, ss) for ee, ss in zip(e, s)])
)
elif e is not None and s is not None:
res_list.append(to_dt(e, s))
else:
res_dt = None
res_list.append(None) # type: ignore[arg-type]

res_list.append(res_dt)
return tuple(res_list) if isinstance(res, tuple) else res_list
else:
# if the res contains only non tensor values, we simply return it without rewrapping
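wrap() is refactored around a to_dt helper so it can handle one extra level of nesting: _fused_adam returns a tuple of tensor lists rather than a flat tuple of tensors. The new branch in isolation, with a stand-in for to_dt (sketch, not the module's code):

from typing import Any, List, Sequence

def wrap_nested(res: Sequence, spec: Sequence, to_dt) -> Sequence:
    # Recurse one level when both the result element and its spec are
    # sequences, wrap single elements with to_dt, and pass None through.
    out: List[Any] = []
    for e, s in zip(res, spec):
        if isinstance(e, (list, tuple)) and isinstance(s, (list, tuple)):
            out.append(type(e)(to_dt(ee, ss) for ee, ss in zip(e, s)))
        elif e is not None and s is not None:
            out.append(to_dt(e, s))
        else:
            out.append(None)
    return tuple(out) if isinstance(res, tuple) else out

# Identity "wrapping" just to show the shape handling.
print(wrap_nested(([1, 2], 3, None), (["s", "s"], "s", None), lambda e, s: e))
# -> ([1, 2], 3, None)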
