Fix initialization of RoPE using FSDP (Lightning-AI#887)
Co-authored-by: Carlos Mocholí <[email protected]>
diegodoimo and carmocca authored Jan 18, 2024
1 parent af4ce8d commit 5a301d3
Showing 4 changed files with 28 additions and 5 deletions.
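
Background: under FSDP the model is instantiated inside fabric.init_module(empty_init=True), so every buffer, including the RoPE cos/sin caches, starts out on the meta device. The diffs below rebuild that cache once a real device is available: generate/sequentially.py and generate/tp.py call model.rope_cache() explicitly after setting max_seq_length, and GPT.reset_parameters (the hook FSDP is expected to run when materializing meta-initialized modules) now recomputes the cache directly. A minimal sketch of the PyTorch mechanism the generation scripts rely on, namely using a torch.device as a context manager (PyTorch 2.x) to set the default device, which is what `with root:` and fabric.init_tensor() provide for the rebuilt cache; the shapes here are arbitrary and only illustrate device placement:

    import torch

    root = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Tensors created while the device context is active are allocated on `root`,
    # so a call such as model.rope_cache() inside `with root:` produces cos/sin
    # tensors on the real device instead of the meta device.
    with root:
        cos = torch.empty(128, 64)

    print(cos.device)  # cuda:0 (or cpu), not meta
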
3 changes: 2 additions & 1 deletion generate/sequentially.py
@@ -51,8 +51,9 @@ def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int
submodule.attn.kv_cache = submodule.attn.build_kv_cache(1, max_seq_length, model.cos.size(-1), target_device)
# rebuild odd ends
with root:
# the rope cache which is on meta device
model.max_seq_length = max_seq_length
# the rope cache which is on meta device
model.cos, model.sin = model.rope_cache()
# the mask cache which cannot be created with `set_kv_cache` because that will set it for all layers
model.mask_cache = build_mask_cache(max_seq_length)
# and everything that is not a block in the root
2 changes: 2 additions & 0 deletions generate/tp.py
@@ -180,6 +180,8 @@ def main(
with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
# the rope cache which is on meta device
model.cos, model.sin = model.rope_cache()
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.eval()
6 changes: 2 additions & 4 deletions lit_gpt/model.py
@@ -51,17 +51,15 @@ def max_seq_length(self, value: int) -> None:
cos, sin = self.rope_cache()
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
# overrides
elif self.cos.device.type == "meta":
self.cos, self.sin = self.rope_cache()
# override
elif value != self.cos.size(0):
self.cos, self.sin = self.rope_cache(device=self.cos.device)
# the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
# if the kv cache is expected

def reset_parameters(self) -> None:
# Trigger resetting the rope-cache
self.max_seq_length = self.config.block_size
self.cos, self.sin = self.rope_cache()

def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
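
Note on the lit_gpt/model.py change: the meta-device special case in the max_seq_length setter is removed, and reset_parameters, which the FSDP setup path is expected to invoke when it materializes modules created on the meta device, now rebuilds the rope cache directly instead of re-assigning max_seq_length. That way cos and sin are recreated on the compute device even when the sequence length is unchanged; the test added below exercises this path with a two-GPU Fabric launch.
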
22 changes: 22 additions & 0 deletions tests/test_model.py
@@ -8,6 +8,7 @@
import pytest
import torch
from conftest import RunIf
from lightning import Fabric
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_2

# support running without installing as a package
@@ -752,3 +753,24 @@ def assert_sdpa_backend(original_fn, q, k, v, mask):
)
with torch.backends.cuda.sdp_kernel(enable_flash=False):
model(x, input_pos)


@RunIf(min_cuda_gpus=2, standalone=True)
def test_rope_init_under_fsdp():
"""Check that the rope cache is properly intialized"""
from lit_gpt import GPT

fabric = Fabric(devices=2, strategy="fsdp", accelerator="cuda")
fabric.launch()

with fabric.init_module(empty_init=True):
model = GPT.from_name("pythia-14m", n_layer=1)
assert model.cos.device.type == "meta"
assert model.sin.device.type == "meta"

model = fabric.setup(model)
assert model.cos.device.type == "cuda"
assert model.sin.device.type == "cuda"
cos, sin = model.rope_cache(device=fabric.device)
torch.testing.assert_close(model.cos, cos)
torch.testing.assert_close(model.sin, sin)
