Use Fabric's lazy_load implementation (Lightning-AI#613)
carmocca authored Oct 5, 2023
1 parent 3bd596c commit 38a6e45
Showing 20 changed files with 56 additions and 247 deletions.
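The change repeated across these files: the repository's own context-manager-style lazy_load (together with its NotYetLoadedTensor and LazyLoadingUnpickler machinery) is deleted from lit_gpt/utils.py and replaced by Fabric's _lazy_load, re-exported under the old lazy_load name. Because the Fabric function returns the state dict directly instead of acting as a context manager, the `with` blocks and their extra indentation level disappear throughout the diff below. A minimal sketch of the new calling convention; the helper name load_weights is illustrative and not part of the commit:

from pathlib import Path

import torch
from lightning.fabric.utilities.load import _lazy_load as lazy_load


def load_weights(model: torch.nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
    # Fabric's lazy_load returns the checkpoint mapping directly; tensors are
    # materialized lazily as load_state_dict touches them, so no `with` block is needed.
    checkpoint = lazy_load(checkpoint_path)
    # Some checkpoints nest the weights under a "model" key.
    state_dict = checkpoint.get("model", checkpoint)
    model.load_state_dict(state_dict, strict=strict)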
4 changes: 2 additions & 2 deletions chat/base.py
@@ -151,8 +151,8 @@ def main(
     fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     with fabric.init_module(empty_init=True), quantization(quantize):
         model = GPT(config)
-    with lazy_load(checkpoint_path) as checkpoint:
-        model.load_state_dict(checkpoint.get("model", checkpoint), strict=quantize is None)
+    checkpoint = lazy_load(checkpoint_path)
+    model.load_state_dict(checkpoint.get("model", checkpoint), strict=quantize is None)
 
     model.eval()
     model = fabric.setup_module(model)
4 changes: 2 additions & 2 deletions eval/lm_eval_harness.py
@@ -68,8 +68,8 @@ def __init__(
         fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
 
         t0 = time.perf_counter()
-        with lazy_load(checkpoint_path) as checkpoint:
-            model.load_state_dict(checkpoint.get("model", checkpoint), strict=quantize is None)
+        checkpoint = lazy_load(checkpoint_path)
+        model.load_state_dict(checkpoint.get("model", checkpoint), strict=quantize is None)
         fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
 
         model.eval()
6 changes: 3 additions & 3 deletions finetune/adapter.py
@@ -93,9 +93,9 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
     fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
     with fabric.init_module(empty_init=False):
         model = GPT(config)
-    with lazy_load(checkpoint_path) as checkpoint:
-        # strict=False because missing keys due to adapter weights not contained in state dict
-        model.load_state_dict(checkpoint, strict=False)
+    checkpoint = lazy_load(checkpoint_path)
+    # strict=False because missing keys due to adapter weights not contained in state dict
+    model.load_state_dict(checkpoint, strict=False)
 
     mark_only_adapter_as_trainable(model)
 
6 changes: 3 additions & 3 deletions finetune/adapter_v2.py
@@ -93,9 +93,9 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
     fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
     with fabric.init_module(empty_init=False):
         model = GPT(config)
-    with lazy_load(checkpoint_path) as checkpoint:
-        # strict=False because missing keys due to adapter weights not contained in state dict
-        model.load_state_dict(checkpoint, strict=False)
+    checkpoint = lazy_load(checkpoint_path)
+    # strict=False because missing keys due to adapter weights not contained in state dict
+    model.load_state_dict(checkpoint, strict=False)
 
     mark_only_adapter_v2_as_trainable(model)
 
7 changes: 4 additions & 3 deletions generate/adapter.py
@@ -82,9 +82,10 @@ def main(
     fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
 
     t0 = time.perf_counter()
-    with lazy_load(checkpoint_path) as checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
-        checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
-        model.load_state_dict(checkpoint, strict=quantize is None)
+    checkpoint = lazy_load(checkpoint_path)
+    adapter_checkpoint = lazy_load(adapter_path)
+    checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
+    model.load_state_dict(checkpoint, strict=quantize is None)
     fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
7 changes: 4 additions & 3 deletions generate/adapter_v2.py
@@ -82,9 +82,10 @@ def main(
     fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
 
     t0 = time.perf_counter()
-    with lazy_load(checkpoint_path) as checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
-        checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
-        model.load_state_dict(checkpoint, strict=quantize is None)
+    checkpoint = lazy_load(checkpoint_path)
+    adapter_checkpoint = lazy_load(adapter_path)
+    checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
+    model.load_state_dict(checkpoint, strict=quantize is None)
     fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
7 changes: 4 additions & 3 deletions generate/lora.py
@@ -103,9 +103,10 @@ def main(
     fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
 
     t0 = time.perf_counter()
-    with lazy_load(checkpoint_path) as checkpoint, lazy_load(lora_path) as lora_checkpoint:
-        checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
-        model.load_state_dict(checkpoint, strict=quantize is None)
+    checkpoint = lazy_load(checkpoint_path)
+    lora_checkpoint = lazy_load(lora_path)
+    checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
+    model.load_state_dict(checkpoint, strict=quantize is None)
     fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
152 changes: 2 additions & 150 deletions lit_gpt/utils.py
@@ -1,10 +1,7 @@
 """Utility functions for training and inference."""
-import os
 import pickle
 import sys
-import warnings
 from contextlib import contextmanager
-from functools import partial
 from io import BytesIO
 from pathlib import Path
 from typing import Dict, List, Mapping, Optional, TypeVar, Union
@@ -13,7 +10,7 @@
 import torch.nn as nn
 import torch.utils._device
 from lightning.fabric.strategies import FSDPStrategy
-from lightning.fabric.utilities.load import _lazy_load
+from lightning.fabric.utilities.load import _lazy_load as lazy_load
 from torch.serialization import normalize_storage_type
 
 
@@ -83,151 +80,6 @@ def __init__(self, *args, **kwargs):
     torch.nn.Linear = torch_linear_cls
 
 
-# this is taken from torchhacks https://github.com/lernapparat/torchhacks
-
-
-class NotYetLoadedTensor:
-    def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
-        self.metatensor = metatensor
-        self.archiveinfo = archiveinfo
-        self.storageinfo = storageinfo
-        self.rebuild_args = rebuild_args
-
-    @classmethod
-    def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None):
-        ret = func(*args)
-        if isinstance(ret, NotYetLoadedTensor):
-            old_lt = ret._load_tensor
-
-            def _load_tensor():
-                t = old_lt()
-                return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)
-
-            ret._load_tensor = _load_tensor
-            return ret
-        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)
-
-    @classmethod
-    def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None):
-        if isinstance(data, NotYetLoadedTensor):
-            old_lt = data._load_tensor
-
-            def _load_tensor():
-                t = old_lt()
-                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)
-
-            data._load_tensor = _load_tensor
-            return data
-        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)
-
-    @classmethod
-    def rebuild_tensor_v2(
-        cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None
-    ):
-        rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
-        metatensor = torch._utils._rebuild_tensor_v2(
-            storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
-        )
-        storageinfo = storage.archiveinfo
-        return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
-
-    def _load_tensor(self):
-        name, storage_cls, fn, device, size = self.storageinfo
-        dtype = self.metatensor.dtype
-
-        uts = (
-            self.archiveinfo.zipfile_context.zf.get_storage_from_record(
-                f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage
-            )
-            ._typed_storage()
-            ._untyped_storage
-        )
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True)
-        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
-
-    @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args]
-        return func(*loaded_args, **kwargs)
-        # gc.collect would be costly here, maybe do it optionally
-
-    def __getattr__(self, name):
-        # properties
-        ## TODO: device, is_...??
-        ## TODO: mH, mT, H, T, data, imag, real
-        ## name ???
-        if name in {
-            "dtype",
-            "grad",
-            "grad_fn",
-            "is_meta",
-            "layout",
-            "names",
-            "ndim",
-            "output_nr",
-            "requires_grad",
-            "retains_grad",
-            "shape",
-            "volatile",
-        }:
-            return getattr(self.metatensor, name)
-        if name in {"size"}:
-            return getattr(self.metatensor, name)
-        # materializing with contiguous is needed for quantization
-        if name in {"contiguous"}:
-            return getattr(self._load_tensor(), name)
-
-        raise AttributeError(f"{type(self)} does not have {name}")
-
-    def __repr__(self):
-        return f"NotYetLoadedTensor({repr(self.metatensor)})"
-
-
-class LazyLoadingUnpickler(pickle.Unpickler):
-    def __init__(self, file, zipfile_context):
-        super().__init__(file)
-        self.zipfile_context = zipfile_context
-
-    def find_class(self, module, name):
-        res = super().find_class(module, name)
-        if module == "torch._utils" and name == "_rebuild_tensor_v2":
-            return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)
-        if module == "torch._tensor" and name == "_rebuild_from_type_v2":
-            return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)
-        if module == "torch._utils" and name == "_rebuild_parameter":
-            return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)
-        return res
-
-    def persistent_load(self, pid):
-        name, cls, fn, device, size = pid
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
-        s.archiveinfo = pid
-        return s
-
-
-class lazy_load:
-    def __init__(self, path: Union[Path, str]) -> None:
-        if not os.path.isfile(path):
-            raise FileNotFoundError(f"Path {str(path)!r} does not exist or is not a file.")
-        self.zf = torch._C.PyTorchFileReader(str(path))
-        with BytesIO(self.zf.get_record("data.pkl")) as pkl:
-            mup = LazyLoadingUnpickler(pkl, self)
-            self.sd = mup.load()
-
-    def __enter__(self):
-        return self.sd
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        del self.zf  # I don't think there is a way to force closing...
-        self.zf = None
-
-
 def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
     files = {
         "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
Expand Down Expand Up @@ -484,6 +336,6 @@ def load_checkpoint(fabric, model, checkpoint_path: Path, strict: bool = True) -
if isinstance(fabric.strategy, FSDPStrategy):
fabric.load_raw(checkpoint_path, model, strict=strict)
else:
state_dict = _lazy_load(checkpoint_path)
state_dict = lazy_load(checkpoint_path)
state_dict = state_dict.get("model", state_dict)
model.load_state_dict(state_dict, strict=strict)
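For context, a hedged sketch of how the load_checkpoint helper above might be driven from a Fabric script; the Fabric settings, model name, and checkpoint path are illustrative assumptions, not taken from this commit:

from pathlib import Path

import lightning as L

from lit_gpt import GPT, Config
from lit_gpt.utils import load_checkpoint

fabric = L.Fabric(devices=1, precision="bf16-true")  # illustrative settings
fabric.launch()

with fabric.init_module(empty_init=True):
    model = GPT(Config.from_name("pythia-70m"))  # hypothetical model choice
model = fabric.setup_module(model)

# Under FSDP this dispatches to fabric.load_raw; otherwise it loads the
# checkpoint lazily and calls model.load_state_dict.
checkpoint_path = Path("checkpoints/EleutherAI/pythia-70m/lit_model.pth")  # hypothetical path
load_checkpoint(fabric, model, checkpoint_path)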
4 changes: 2 additions & 2 deletions quantize/gptq.py
@@ -607,8 +607,8 @@ def main(
     t0 = time.perf_counter()
     with fabric.init_module(empty_init=True):
         model = GPT(config)
-    with lazy_load(checkpoint_path) as checkpoint:
-        model.load_state_dict(checkpoint)
+    checkpoint = lazy_load(checkpoint_path)
+    model.load_state_dict(checkpoint)
     print(f"Time to load model: {time.perf_counter() - t0:.02f} seconds.")
 
     model.eval()
3 changes: 2 additions & 1 deletion scripts/convert_hf_checkpoint.py
@@ -8,13 +8,14 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
 from lit_gpt import Config
-from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load
+from lit_gpt.utils import incremental_save, lazy_load
 
 
 def copy_weights_gpt_neox(
13 changes: 7 additions & 6 deletions scripts/convert_lit_checkpoint.py
@@ -5,13 +5,14 @@
 from typing import Dict, Optional, Tuple, Union
 
 import torch
+from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
 from lit_gpt import Config
-from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load
+from lit_gpt.utils import incremental_save, lazy_load
 from scripts.convert_hf_checkpoint import layer_template, load_param
 
 
@@ -220,11 +221,11 @@ def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path
     # initialize a new empty state dict to hold our new weights
     sd = {}
     with incremental_save(output_path) as saver:
-        with lazy_load(checkpoint_path) as lit_weights:
-            lit_weights = lit_weights.get("model", lit_weights)
-            check_conversion_supported(lit_weights)
-            copy_fn(sd, lit_weights, saver=saver)
-            gc.collect()
+        lit_weights = lazy_load(checkpoint_path)
+        lit_weights = lit_weights.get("model", lit_weights)
+        check_conversion_supported(lit_weights)
+        copy_fn(sd, lit_weights, saver=saver)
+        gc.collect()
         saver.save(sd)
 
 
7 changes: 4 additions & 3 deletions scripts/merge_lora.py
@@ -63,9 +63,10 @@ def merge_lora(
     with fabric.init_module(empty_init=True):
         model = GPT(config)
     checkpoint_path = checkpoint_dir / "lit_model.pth"
-    with lazy_load(checkpoint_path) as checkpoint, lazy_load(lora_path) as lora_checkpoint:
-        checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
-        model.load_state_dict(checkpoint)
+    checkpoint = lazy_load(checkpoint_path)
+    lora_checkpoint = lazy_load(lora_path)
+    checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
+    model.load_state_dict(checkpoint)
 
     merge_lora_weights(model)
 
7 changes: 2 additions & 5 deletions tests/test_adapter.py
@@ -66,11 +66,8 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch):
     model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)
     monkeypatch.setitem(name_to_config, "tmp", model_config)
 
-    load_mock = Mock()
-    load_mock.return_value = load_mock
-    load_mock.__enter__ = Mock()
-    load_mock.__exit__ = Mock()
-    monkeypatch.setattr(module, "lazy_load", load_mock)
+    monkeypatch.setattr(module, "lazy_load", Mock())
+    monkeypatch.setattr(module.GPT, "load_state_dict", Mock())
 
     tokenizer_mock = Mock()
     tokenizer_mock.return_value = tokenizer_mock
7 changes: 2 additions & 5 deletions tests/test_adapter_v2.py
@@ -81,11 +81,8 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch):
     model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)
     monkeypatch.setitem(name_to_config, "tmp", model_config)
 
-    load_mock = Mock()
-    load_mock.return_value = load_mock
-    load_mock.__enter__ = Mock()
-    load_mock.__exit__ = Mock()
-    monkeypatch.setattr(module, "lazy_load", load_mock)
+    monkeypatch.setattr(module, "lazy_load", Mock())
+    monkeypatch.setattr(module.GPT, "load_state_dict", Mock())
 
     tokenizer_mock = Mock()
     tokenizer_mock.return_value = tokenizer_mock
7 changes: 2 additions & 5 deletions tests/test_generate_adapter.py
@@ -21,11 +21,8 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like):
     config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
     config_path.write_text(json.dumps(config))
 
-    load_mock = Mock()
-    load_mock.return_value = load_mock
-    load_mock.__enter__ = Mock()
-    load_mock.__exit__ = Mock()
-    monkeypatch.setattr(generate, "lazy_load", load_mock)
+    monkeypatch.setattr(generate, "lazy_load", Mock())
+    monkeypatch.setattr(generate.GPT, "load_state_dict", Mock())
     tokenizer_mock = Mock()
     tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
     tokenizer_mock.return_value.decode.return_value = "### Response:foo bar baz"
7 changes: 2 additions & 5 deletions tests/test_generate_lora.py
@@ -26,11 +26,8 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
     }
     config_path.write_text(json.dumps(config))
 
-    load_mock = Mock()
-    load_mock.return_value = load_mock
-    load_mock.__enter__ = Mock()
-    load_mock.__exit__ = Mock()
-    monkeypatch.setattr(generate, "lazy_load", load_mock)
+    monkeypatch.setattr(generate, "lazy_load", Mock())
+    monkeypatch.setattr(generate.GPT, "load_state_dict", Mock())
     tokenizer_mock = Mock()
     tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
     tokenizer_mock.return_value.decode.return_value = "### Response:foo bar baz"