Skip to content

Commit

Permalink
Be able to set custom accelerator, precision and dtype for MPS accele…
Browse files Browse the repository at this point in the history
…rators (Apple M1 silicon) (Lightning-AI#94)
  • Loading branch information
agmo1993 authored May 29, 2023
1 parent 513c393 commit b1101fc
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions finetune_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import torch
from lightning.fabric.strategies import DeepSpeedStrategy
from lightning.fabric.accelerators.mps import MPSAccelerator

from generate import generate
from lit_parrot.adapter import Parrot, Config, mark_only_adapter_as_trainable, adapter_state_from_state_dict
Expand Down Expand Up @@ -44,14 +45,14 @@ def main(
data_dir: Path = Path("data/alpaca"),
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir: Path = Path("out/adapter/alpaca"),
precision = "bf16-mixed",
):
check_valid_checkpoint_dir(checkpoint_dir)

fabric = L.Fabric(
accelerator="cuda",
devices=devices,
strategy=(DeepSpeedStrategy(config=ds_config) if devices > 1 else "auto"),
precision="bf16-mixed",
precision=precision,
)
fabric.launch()
fabric.seed_everything(1337 + fabric.global_rank)
Expand All @@ -63,7 +64,7 @@ def main(

config = Config.from_name(name=checkpoint_dir.name, block_size=max_seq_length)

with EmptyInitOnDevice(device=fabric.device, dtype=torch.bfloat16):
with EmptyInitOnDevice(device=fabric.device, dtype=torch.float32 if fabric._precision.precision == "32-true" else torch.bfloat16):
model = Parrot(config)
with lazy_load(checkpoint_dir / "lit_model.pth") as checkpoint:
model.load_state_dict(checkpoint, strict=False)
Expand Down Expand Up @@ -184,7 +185,11 @@ def pad_right(x, pad_id):

x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))

if isinstance(fabric.accelerator, MPSAccelerator):
x, y = fabric.to_device((x, y))
else:
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))

return x, y

Expand Down

0 comments on commit b1101fc

Please sign in to comment.