forked from microsoft/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add API to set a module as a leaf node when recursively setting Z3 hooks (microsoft#4966)

ZeRO3 does not work with MoE models because the order of executing modules can change at every forward/backward pass (microsoft#4094, microsoft#4808). This PR adds an API to stop breaking down a module for parameter fetching. The following shows an example of the usage:

```python
import torch
import deepspeed
import deepspeed.comm as dist
from transformers.deepspeed import HfDeepSpeedConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model_id = "mistralai/Mixtral-8x7B-v0.1"
ds_config = {
    "bf16": {
        "enabled": True,
    },
    "zero_optimization": {
        "stage": 3,
    },
    "train_micro_batch_size_per_gpu": 1,
}

hfdsc = HfDeepSpeedConfig(ds_config)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
model.eval()

ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()
model = ds_engine.module

inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to("cuda")
outputs = model.generate(inputs, max_new_tokens=200)
output_str = tokenizer.decode(outputs[0])
if dist.get_rank() == 0:
    print(f"output: {output_str}")
```

By passing module classes to `set_z3_leaf_modules`, the DeepSpeed engine stops breaking down modules of those classes. In this example, `MixtralSparseMoeBlock` has multiple experts as its submodules. Using `set_z3_leaf_modules`, the DeepSpeed engine fetches the parameters of all the submodules when pre-fetching the parameters of `MixtralSparseMoeBlock`.
- Loading branch information
Showing
5 changed files
with
155 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import torch | ||
from typing import List, Type | ||
|
||
|
||
def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type], flag: bool) -> None:
    """Write `flag` into `_z3_leaf` on every module of `model` whose class is listed.

    Args:
        model (torch.nn.Module): Root module; it and all of its descendants are visited.
        leaf_module_classes (List[Type]): Classes to match. A module matches only when its
            exact class is in this list (subclasses are not matched).
        flag (bool): Value stored in the `_z3_leaf` attribute of each matching module.

    Raises:
        AssertionError: If any entry of `leaf_module_classes` is not a type.
    """
    assert not any(not isinstance(candidate, type) for candidate in leaf_module_classes), \
        f'leaf_module_classes must be a list of types, got {leaf_module_classes}'

    def _mark_if_listed(module: torch.nn.Module):
        # Modules whose class is not listed are left untouched.
        if module.__class__ not in leaf_module_classes:
            return
        module._z3_leaf = flag

    # `apply` calls the callback on every submodule and on `model` itself.
    model.apply(_mark_if_listed)
|
||
|
||
def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
    """Flag modules of the given classes inside `model` as ZeRO3 'leaf' modules.

    When ZeRO3 sets its hooks recursively, it stops descending once it reaches a module whose
    class is listed in `leaf_module_classes`, and prefetches the parameters of all child
    modules upon entering that module.

    This matters for Mixture of Experts (MoE) models: the order in which experts run varies
    across forward passes, which disrupts ZeRO3 because it tracks the computation order of
    modules to prefetch parameters efficiently. It is also useful for models with excessively
    fine-grained nested modules, where it avoids the overhead associated with hooks.

    Args:
        model (torch.nn.Module): The model to which the leaf module flag will be applied.
        leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
    """
    _do_set_z3_leaf_modules(model, leaf_module_classes, flag=True)
|
||
|
||
def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
    """Clear the ZeRO3 'leaf' flag on modules of the given classes inside `model`.

    After this call ZeRO3 resumes setting hooks recursively when it encounters a module whose
    class is listed in `leaf_module_classes`. The flag is set to False rather than removed.
    See `set_z3_leaf_modules` for more details.

    Args:
        model (torch.nn.Module): The model from which the leaf module flag will be cleared.
        leaf_module_classes (List[Type]): A list of module classes that should no longer be flagged as 'leaf' modules.
    """
    _do_set_z3_leaf_modules(model, leaf_module_classes, flag=False)
|
||
|
||
def z3_leaf_module(model: torch.nn.Module) -> bool:
    """Report whether `model` itself has been flagged as a ZeRO3 'leaf' module.

    Only the `_z3_leaf` attribute of the given module is inspected; its submodules are not.
    See `set_z3_leaf_modules` for more details.

    Args:
        model (torch.nn.Module): The module to check for the leaf flag.

    Returns:
        The stored flag value, or False when the module has never been flagged.
    """
    # A module that was never flagged has no `_z3_leaf` attribute at all.
    return getattr(model, '_z3_leaf', False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import deepspeed.comm as dist | ||
import torch | ||
|
||
from unit.common import DistributedTest | ||
from unit.simple_model import random_dataloader | ||
|
||
import deepspeed | ||
from deepspeed.utils import set_z3_leaf_modules, z3_leaf_module | ||
|
||
|
||
class MyModel(torch.nn.Module):
    """Toy model that sends each forward call through one of two linear layers.

    The layer used alternates with the number of forward calls made so far, so the
    submodule execution order differs from one pass to the next.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Two bias-free square layers; `forward` picks one of them per call.
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(2)])
        self.act = torch.nn.ReLU()
        self.cel = torch.nn.CrossEntropyLoss()
        # Number of forward calls completed so far.
        self.counter = 0

    def forward(self, x, y):
        """Return `(activations, loss)` using the layer selected by the call counter."""
        # This fails under ZeRO3 without setting this module as a leaf module,
        # because the layer that runs changes on every call.
        layer_index = self.counter % len(self.linears)
        selected_layer = self.linears[layer_index]
        x = self.act(selected_layer(x))
        loss = self.cel(x, y)
        self.counter += 1
        return x, loss
|
||
|
||
def run_model(model, config_dict, hidden_dim, dtype):
    """Wrap `model` in a DeepSpeed engine and train it on random data.

    Args:
        model: Module passed to `deepspeed.initialize`; its forward must return a
            sequence whose second element is the loss.
        config_dict: DeepSpeed configuration passed as `config`.
        hidden_dim: Hidden dimension forwarded to `random_dataloader`.
        dtype: Data type forwarded to `random_dataloader`.
    """
    engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
    loader = random_dataloader(model=engine,
                               total_samples=10,
                               hidden_dim=hidden_dim,
                               device=engine.device,
                               dtype=dtype)
    dist.barrier()
    for batch in loader:
        outputs = engine(batch[0], batch[1])
        # The second element of the forward output is the loss.
        engine.backward(outputs[1])
        engine.step()

    # Needed in ZeRO 3. Not doing so can give memory leak
    engine.destroy()
|
||
|
||
class TestSetZ3LeafModule(DistributedTest):
    """Checks that a module flagged as a ZeRO3 leaf can be trained under ZeRO stage 3."""

    # Need multiple gpus to test possible hanging
    world_size = 2
    reuse_dist_env = True

    def test_set_z3_leaf_modules(self):
        """Flag `MyModel` as a leaf module, verify the flag, then run training."""
        hidden_dim = 128

        optimizer_config = {
            "type": "Adam",
            "params": {
                "lr": 1e-6
            }
        }
        # `stage3_max_reuse_distance` is set to 0 to cause an error if the module is not set as a leaf module
        zero_config = {
            "stage": 3,
            "stage3_prefetch_bucket_size": hidden_dim**2,
            "stage3_param_persistence_threshold": 0,
            "stage3_max_reuse_distance": 0,
        }
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "steps_per_print": 1,
            "optimizer": optimizer_config,
            "fp16": {
                "enabled": True
            },
            "zero_optimization": zero_config,
        }

        model = MyModel(hidden_dim)

        # The flag must be absent before the call and present afterwards.
        assert not z3_leaf_module(model)
        set_z3_leaf_modules(model, [MyModel])
        assert z3_leaf_module(model)

        run_model(model, config_dict, hidden_dim, torch.float16)