Commit
update trec_shard + fully_shard composable test to include checkpointing (pytorch#1004)

Summary:
Pull Request resolved: pytorch#1004

As titled, adds a trec_shard + fully_shard unit test that exercises checkpointing (a condensed sketch of the pattern appears just before the diff below).

Also removes the trec_shard + FSDP test, as that code path is no longer relevant (now that fully_shard is in a better state).

Reviewed By: fegin, rohan-varma

Differential Revision: D42906963

fbshipit-source-id: f21e2f8083946993a942236f2e423fd6ee2e7657
colin2328 authored and facebook-github-bot committed Feb 17, 2023
1 parent 997d40d commit e356c0c
Showing 1 changed file with 71 additions and 166 deletions.
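
For readers skimming the diff, the new test boils down to the composable setup sketched here. This is a condensed sketch assembled from the diff below, not a verbatim excerpt: the helper name shard_model and the import paths for fully_shard, ModuleWrapPolicy, and RowWiseAdagrad are assumptions based on torch 2.0-era and torchrec APIs, while the trec_shard, row_wise, and apply_optimizer_in_backward calls mirror the diff itself.

# Condensed sketch (not the verbatim test) of the composable sharding setup:
# torchrec row-wise sharding for the embedding collections, composable
# fully_shard for the dense parts, sparse optimizer fused into backward.
# Assumes a NCCL process group is already initialized and `m` is a
# TestSparseNN-like module with sparse.ebc, sparse.weighted_ebc, dense, over.
import torch
import torch.nn as nn
from torch.distributed._composable import fully_shard  # import path assumed
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)
from torchrec.distributed.shard import shard as trec_shard
from torchrec.distributed.sharding_plan import row_wise
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad  # import path assumed


def shard_model(m: nn.Module, device: torch.device) -> nn.Module:
    # Fuse the sparse optimizer into backward before sharding the embeddings.
    apply_optimizer_in_backward(RowWiseAdagrad, m.sparse.parameters(), {"lr": 0.01})
    # Row-wise shard each embedding bag collection in place.
    m.sparse.ebc = trec_shard(module=m.sparse.ebc, device=device, plan=row_wise())
    m.sparse.weighted_ebc = trec_shard(
        module=m.sparse.weighted_ebc, device=device, plan=row_wise()
    )
    # Composable fully_shard wraps only the dense and over arch.
    m.dense = fully_shard(
        m.dense, device_id=device.index, policy=ModuleWrapPolicy({nn.Linear})
    )
    m.over = fully_shard(
        m.over, device_id=device.index, policy=ModuleWrapPolicy({nn.Linear})
    )
    return m

Unlike the removed FSDP-wrapper test, nothing needs to be listed in ignored_modules: fully_shard is applied only to the dense and over submodules, so the torchrec-sharded embeddings never pass through the FSDP wrapper.
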
237 changes: 71 additions & 166 deletions torchrec/distributed/composable/tests/test_fsdp.py
@@ -29,17 +29,9 @@
from torch.distributed.optim import (
_apply_optimizer_in_backward as apply_optimizer_in_backward,
)
from torchrec.distributed.shard import (
shard as trec_shard,
shard_modules as trec_shard_modules,
)
from torchrec.distributed.sharding_plan import (
apply_to_all,
construct_module_sharding_plan,
row_wise,
)
from torchrec.distributed.shard import shard as trec_shard
from torchrec.distributed.sharding_plan import row_wise
from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN
from torchrec.distributed.types import ShardingPlan
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
@@ -48,19 +40,14 @@
from torchrec.test_utils import skip_if_asan


class FSDPTest(unittest.TestCase):
class FullyShardTest(unittest.TestCase):
@classmethod
def _run(cls, path: str) -> None:
def _run(cls, param_path: str, opt_path: str) -> None: # noqa
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.is_available():
device: torch.device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
else:
device: torch.device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(backend=backend)
device: torch.device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl")
num_float_features = 32

tables = [
@@ -87,35 +74,32 @@ def _run(cls, path: str) -> None:
weighted_tables=weighted_tables,
dense_device=device,
)
plan = ShardingPlan(
plan={
"sparse.ebc": construct_module_sharding_plan(
m.sparse.ebc,
apply_to_all(m.sparse.ebc, row_wise()),
),
"sparse.weighted_ebc": construct_module_sharding_plan(
m.sparse.weighted_ebc,
apply_to_all(m.sparse.weighted_ebc, row_wise()),
),
}
)
apply_optimizer_in_backward(
RowWiseAdagrad,
m.sparse.parameters(),
{"lr": 0.01},
)
trec_shard_modules(
module=m,
m.sparse.ebc = trec_shard(
module=m.sparse.ebc,
device=device,
plan=plan,
plan=row_wise(),
)
sharded_m = FullyShardedDataParallel(
module=m,
device_id=rank,
ignored_modules=[m.sparse],
# TODO enable once works
# use_orig_params=True,
m.sparse.weighted_ebc = trec_shard(
module=m.sparse.weighted_ebc,
device=device,
plan=row_wise(),
)
m.dense = fully_shard(
m.dense,
device_id=device.index,
policy=ModuleWrapPolicy({nn.Linear}),
)
m.over = fully_shard(
m.over,
device_id=device.index,
policy=ModuleWrapPolicy({nn.Linear}),
)

dense_opt = KeyedOptimizerWrapper(
dict(in_backward_optimizer_filter(m.named_parameters(), include=False)),
lambda params: torch.optim.Adam(
@@ -136,6 +120,7 @@ def _run(cls, path: str) -> None:
[
WarmupStage(
policy=WarmupPolicy.LINEAR,
max_iters=1000,
value=0.1,
lr_scale=1.0,
)
@@ -145,12 +130,14 @@ def _run(cls, path: str) -> None:
)
optims.append((name, warmup))
sparse_grad_parameter_names.add(name)
assert len(sparse_grad_parameter_names) == 5
fused_opt_scheduled = CombinedOptimizer(optims)
dense_opt_scheduled = WarmupOptimizer(
dense_opt,
[
WarmupStage(
policy=WarmupPolicy.LINEAR,
max_iters=1000,
value=0.15,
lr_scale=1.0,
)
@@ -164,10 +151,13 @@ def _run(cls, path: str) -> None:
# Runs a dummy optimizer step, which allows to initialize
# optimizer state, which is typically lazy.
# This allows us to do in-place loading of optimizer state from a checkpoint.
# Remark that fused optimizer needs speical case as its states are ShardedTensors.
# Remark that fused optimizer needs special case as its states are ShardedTensors.
# This is the reason we need to pass the sparse_grad_parameter_names as parameters.
opt.init_state(sparse_grad_parameter_names)
opt.save_param_groups(True)
model_param_names = set(dict(m.named_parameters()).keys())
opt_param_keys = set(opt.params.keys())
assert model_param_names.issubset(opt_param_keys)

######## run one iteration ########
_, local_batch = ModelInput.generate(
@@ -178,24 +168,23 @@ def _run(cls, path: str) -> None:
weighted_tables=weighted_tables,
)
batch = local_batch[0].to(device)
sharded_m(batch)[1].sum().backward()
m(batch)[1].sum().backward()
opt.step()

# TODO uncomment after fixing
# buffer = io.BytesIO()
# # Use FSDP state_dict() API instead of default
# opt_state_dict = FullyShardedDataParallel._optim_state_dict(sharded_m, opt)
# torch.save(opt_state_dict, buffer)
# buffer.seek(0)
state_dict = m.state_dict()
param_writer = FileSystemWriter(path=param_path)
param_reader = FileSystemReader(path=param_path)
save_state_dict(state_dict, param_writer)

writer = FileSystemWriter(path=path)
reader = FileSystemReader(path=path)
# TODO add StateDictType.SHARDED_STATE_DICT test
state_dict = sharded_m.state_dict()
save_state_dict(state_dict, writer)
# use FSDP.optim_state_dict() API
opt_state_dict = FullyShardedDataParallel.optim_state_dict(m, opt)
opt_writer = FileSystemWriter(path=opt_path)
opt_reader = FileSystemReader(path=opt_path)
# use Distributed checkpointing API
save_state_dict(opt_state_dict, opt_writer)

p_sum = torch.zeros(1, device=device)
for p in sharded_m.parameters():
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
@@ -206,7 +195,9 @@ def _run(cls, path: str) -> None:
assert p.sum() == 0
o_sum = torch.zeros(1, device=device)
for p_v in opt.state_dict()["state"].values():
for t in p_v.values():
for name, t in p_v.items():
if name == "step":
continue
if isinstance(t, ShardedTensor):
if not t.local_shards():
continue
@@ -215,13 +206,19 @@ def _run(cls, path: str) -> None:
t.zero_()
assert t.sum() == 0

state_dict = sharded_m.state_dict()
load_state_dict(state_dict, reader)
missing, unexpected = sharded_m.load_state_dict(state_dict)
load_state_dict(state_dict, param_reader)
missing, unexpected = m.load_state_dict(state_dict)
assert len(missing) == 0 and len(unexpected) == 0

load_state_dict(opt_state_dict, opt_reader)
# use FSDP.optim_state_dict_to_load() API
new_opt_state_dict = FullyShardedDataParallel.optim_state_dict_to_load(
opt_state_dict, m, opt, is_named_optimizer=True
)
opt.load_state_dict(new_opt_state_dict)

p_sum_loaded = torch.zeros(1, device=device)
for p in sharded_m.parameters():
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
@@ -230,118 +227,26 @@ def _run(cls, path: str) -> None:
p_sum_loaded += p.sum()
assert p_sum.allclose(p_sum_loaded)

# Use FSDP load_state_dict() API instead of default
# TODO uncomment after fixing
# FullyShardedDataParallel._load_optim_state_dict_pre_hook(sharded_m, opt, torch.load(buffer))
# o_sum_loaded = torch.zeros(1, device=device)
# for p_v in opt.state_dict()["state"].values():
# for t in p_v.values():
# if isinstance(t, ShardedTensor):
# if not t.local_shards():
# continue
# t = t.local_tensor()
# o_sum_loaded += t.sum()
# assert o_sum.allclose(o_sum_loaded)

@skip_if_asan
# pyre-ignore[56]
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_checkpoint(self) -> None:
with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as path:
lc = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=2,
run_id=str(uuid.uuid4()),
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file"},
start_method="spawn",
monitor_interval=1,
max_restarts=0,
)
elastic_launch(config=lc, entrypoint=self._run)(path)


class FSDPTestComposable(unittest.TestCase):
@classmethod
def _run(cls) -> None:
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device: torch.device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl")
num_float_features = 32

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(3)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(2)
]
m = TestSparseNN(
tables=tables,
num_float_features=num_float_features,
weighted_tables=weighted_tables,
dense_device=device,
)
m.sparse.ebc = trec_shard(
module=m.sparse.ebc,
device=device,
plan=row_wise(),
)
m.sparse.weighted_ebc = trec_shard(
module=m.sparse.weighted_ebc,
device=device,
plan=row_wise(),
)
m.dense = fully_shard(
m.dense,
device_id=device.index,
policy=ModuleWrapPolicy({nn.Linear}),
)
m.over = fully_shard(
m.over,
device_id=device.index,
policy=ModuleWrapPolicy({nn.Linear}),
)

######## run one iteration ########
_, local_batch = ModelInput.generate(
batch_size=8,
world_size=world_size,
num_float_features=num_float_features,
tables=tables,
weighted_tables=weighted_tables,
)
batch = local_batch[0].to(device)
m(batch)[1].sum().backward()
# TODO add checkpointing test once fully_shard supports
# TODO add apply_optimizer_in_backward() API and optimizer state checkpoint
o_sum_loaded = torch.zeros(1, device=device)
for p_v in opt.state_dict()["state"].values():
for name, t in p_v.items():
if name == "step":
continue
if isinstance(t, ShardedTensor):
if not t.local_shards():
continue
t = t.local_tensor()
o_sum_loaded += t.sum()
assert o_sum.allclose(o_sum_loaded)

@skip_if_asan
# pyre-ignore[56]
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_composable_forward(self) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
def test_composable_checkpoint(self) -> None:
with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as param_path, tempfile.TemporaryDirectory() as opt_path:
lc = LaunchConfig(
min_nodes=1,
max_nodes=1,
@@ -354,4 +259,4 @@ def test_composable_forward(self) -> None:
monitor_interval=1,
max_restarts=0,
)
elastic_launch(config=lc, entrypoint=self._run)()
elastic_launch(config=lc, entrypoint=self._run)(param_path, opt_path)
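
The checkpointing half of the test follows the round trip sketched here. Again, this is a condensed sketch assembled from the diff above, not a verbatim excerpt: the helper name checkpoint_round_trip is made up for illustration, and the torch.distributed.checkpoint import path for FileSystemWriter/FileSystemReader and save_state_dict/load_state_dict is an assumption that may differ by PyTorch version, while the FSDP.optim_state_dict / optim_state_dict_to_load calls mirror the diff.

# Condensed sketch of the parameter/optimizer checkpoint round trip exercised
# above. `m` is the sharded model, `opt` the combined optimizer, and
# `param_path` / `opt_path` are directories for the distributed checkpoints.
from torch.distributed.checkpoint import (  # import path assumed (torch 2.0-era)
    FileSystemReader,
    FileSystemWriter,
    load_state_dict,
    save_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel


def checkpoint_round_trip(m, opt, param_path: str, opt_path: str) -> None:
    # Save model parameters with the distributed checkpointing API.
    state_dict = m.state_dict()
    save_state_dict(state_dict, FileSystemWriter(path=param_path))

    # Save optimizer state via FSDP.optim_state_dict(), which handles both the
    # fully_shard-managed dense params and the ShardedTensor sparse state.
    opt_state_dict = FullyShardedDataParallel.optim_state_dict(m, opt)
    save_state_dict(opt_state_dict, FileSystemWriter(path=opt_path))

    # ... the test zeroes all parameters and optimizer state at this point ...

    # Load parameters back in place.
    load_state_dict(state_dict, FileSystemReader(path=param_path))
    missing, unexpected = m.load_state_dict(state_dict)
    assert not missing and not unexpected

    # Load optimizer state, converting it into a loadable form first.
    load_state_dict(opt_state_dict, FileSystemReader(path=opt_path))
    new_opt_state_dict = FullyShardedDataParallel.optim_state_dict_to_load(
        opt_state_dict, m, opt, is_named_optimizer=True
    )
    opt.load_state_dict(new_opt_state_dict)

Parameters and optimizer state go to separate checkpoint directories (param_path and opt_path, as in the test) so the two round trips stay independent, and is_named_optimizer=True reflects that opt is torchrec's keyed/combined optimizer, whose state is keyed by parameter name.
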
