Skip to content

Commit

Permalink
Add API to set a module as a leaf node when recursively setting Z3 ho…
Browse files Browse the repository at this point in the history
…oks (microsoft#4966)

ZeRO3 does not work with MoE models because the order of executing
modules can change at every forward/backward pass (microsoft#4094, microsoft#4808).

This PR adds an API to stop breaking down a module for parameter
fetching. The following shows an example of the usage:
```python
import torch
import deepspeed
import deepspeed.comm as dist
from transformers.deepspeed import HfDeepSpeedConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model_id = "mistralai/Mixtral-8x7B-v0.1"
ds_config = {
      "bf16": {
          "enabled": True,
      },
      "zero_optimization": {
          "stage": 3,
      },
      "train_micro_batch_size_per_gpu": 1,
  }

hfdsc = HfDeepSpeedConfig(ds_config)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
model.eval()

ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()
model = ds_engine.module

inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to("cuda")
outputs = model.generate(inputs, max_new_tokens=200)
output_str = tokenizer.decode(outputs[0])
if dist.get_rank() == 0:
  print(f"output: {output_str}")
```

By passing names of modules to `set_z3_leaf_modules`, DeepSpeed engine
stops breaking down the module.

In this example, `MixtralSparseMoeBlock` has multiple experts as its
submodule. Using `set_z3_leaf_modules`, the DeepSpeed engine fetches
parameters of all the submodules when pre-fetching the parameters of
`MixtralSparseMoeBlock`.
  • Loading branch information
tohtana authored Jan 19, 2024
1 parent 5dea776 commit 96c5a87
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 11 deletions.
8 changes: 5 additions & 3 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import torch
from collections import OrderedDict
from deepspeed.utils import z3_leaf_module
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import _init_external_params
Expand Down Expand Up @@ -383,9 +384,10 @@ def _register_hooks_recursively(self, module, count=[0]):

#print(f"{module.__class__} : {module.id}")

for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)
if not z3_leaf_module(module):
for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)

@instrument_w_nvtx
def _pre_forward_module_hook(module, *args):
Expand Down
19 changes: 11 additions & 8 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Deque, Set

from deepspeed import comm as dist
from deepspeed.utils import z3_leaf_module
from deepspeed.utils.logging import logger
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import *
Expand Down Expand Up @@ -188,7 +189,7 @@ def record_parameters(self, sub_module: Module) -> None:
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id):
for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id):
self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id))

def construct_parameter_trace_from_module_trace(self):
Expand Down Expand Up @@ -261,14 +262,14 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
"""
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
"inflight": [p.ds_id for p in self.__inflight_param_registry],
}))

params_to_fetch = frozenset(iter_params(current_submodule))
params_to_fetch = frozenset(iter_params(current_submodule, recurse=z3_leaf_module(current_submodule)))
fetch_numel = sum(
[p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
if fetch_numel > 0:
Expand Down Expand Up @@ -390,8 +391,8 @@ def release_sub_module(self, submodule: Module, backward: bool) -> None:
"""release the parameters of a sub module, assuming they meet conditions to
be released."""
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule)))
for param in iter_params(submodule):
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param, backward)
Expand Down Expand Up @@ -473,7 +474,9 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set
if not self.is_complete_trace():
raise RuntimeError("expected trace to be complete")

params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) if not p.ds_persist)
params_to_release = set(
p.ds_id for p in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release))
if not p.ds_persist)

# Problem: When prefetcher scans the param trace, it skips AVAILABLE params.
# This creates issues if those params are released before the skipped uses:
Expand All @@ -482,7 +485,7 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set
# diverges from the trace.
# Solution: Don't release params whose reuse was skipped by prefetch. This is
# possible because we detect such skips during prefetch and mark those params.
for param in iter_params(submodule_to_release):
for param in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release)):
if self.__most_recent_step_id_param_fetched_for[param] > step_id:
params_to_release.discard(param.ds_id)

Expand All @@ -493,7 +496,7 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set
for module in self.__submodule_order[step_id:]:
if params_traversed >= self.__max_reuse_dist_in_numel:
break
for param in iter_params(module):
for param in iter_params(module, recurse=z3_leaf_module(submodule_to_release)):
params_to_release.discard(param.ds_id)
params_traversed += param.ds_numel

Expand Down
1 change: 1 addition & 0 deletions deepspeed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, z3_leaf_module
from .mixed_precision_linkage import link_hp_params
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd
48 changes: 48 additions & 0 deletions deepspeed/utils/z3_leaf_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from typing import List, Type


def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type], flag: bool) -> None:
assert all(isinstance(module_class, type) for module_class in leaf_module_classes), \
f'leaf_module_classes must be a list of types, got {leaf_module_classes}'

def _set_z3_leaf_flag(model: torch.nn.Module):
if model.__class__ in leaf_module_classes:
model._z3_leaf = flag

model.apply(_set_z3_leaf_flag)


def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
"""Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
"""
_do_set_z3_leaf_modules(model, leaf_module_classes, True)


def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
"""Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
"""
_do_set_z3_leaf_modules(model, leaf_module_classes, False)


def z3_leaf_module(model: torch.nn.Module) -> bool:
"""Returns whether a module in `model` has been flagged as a 'leaf' module.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
"""
return hasattr(model, '_z3_leaf') and model._z3_leaf
90 changes: 90 additions & 0 deletions tests/unit/runtime/zero/test_zero_leaf_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed.comm as dist
import torch

from unit.common import DistributedTest
from unit.simple_model import random_dataloader

import deepspeed
from deepspeed.utils import set_z3_leaf_modules, z3_leaf_module


class MyModel(torch.nn.Module):

def __init__(self, hidden_dim):
super(MyModel, self).__init__()
self.linears = torch.nn.ModuleList(
[torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
torch.nn.Linear(hidden_dim, hidden_dim, bias=False)])
self.act = torch.nn.ReLU()
self.cel = torch.nn.CrossEntropyLoss()
self.counter = 0

def forward(self, x, y):
# This fails without setting this module as a leaf module.
# See the comment in `set_z3_leaf_modules()`.
x = self.linears[self.counter % len(self.linears)](x)
x = self.act(x)
loss = self.cel(x, y)
self.counter += 1
return x, loss


def run_model(model, config_dict, hidden_dim, dtype):
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model,
total_samples=10,
hidden_dim=hidden_dim,
device=model.device,
dtype=dtype)
dist.barrier()
for batch in data_loader:
loss = model(batch[0], batch[1])
loss = loss[1]
model.backward(loss)
model.step()

# Needed in ZeRO 3. Not doing so can give memory leak
model.destroy()


class TestSetZ3LeafModule(DistributedTest):
# Need multiple gpus to test possible hanging
world_size = 2
reuse_dist_env = True

def test_set_z3_leaf_modules(self):
hidden_dim = 128

# `stage3_max_reuse_distance` is set to 0 to cause an error if the module is not set as a leaf module
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-6
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": 3,
"stage3_prefetch_bucket_size": hidden_dim**2,
"stage3_param_persistence_threshold": 0,
"stage3_max_reuse_distance": 0,
}
}

model = MyModel(hidden_dim)

assert not z3_leaf_module(model)
set_z3_leaf_modules(model, [MyModel])
assert z3_leaf_module(model)

run_model(model, config_dict, hidden_dim, torch.float16)

0 comments on commit 96c5a87

Please sign in to comment.