Skip to content

Commit

Permalink
lit_config.json -> model_config.yaml (Lightning-AI#1096)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored and awaelchli committed Mar 15, 2024
1 parent 15ccab4 commit 674f315
Show file tree
Hide file tree
Showing 39 changed files with 102 additions and 90 deletions.
2 changes: 1 addition & 1 deletion eval/lm_eval_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def run_eval_harness(
check_valid_checkpoint_dir(checkpoint_dir)
tokenizer = Tokenizer(checkpoint_dir)

config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"

Expand Down
2 changes: 1 addition & 1 deletion litgpt/chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def main(

check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"

Expand Down
22 changes: 11 additions & 11 deletions litgpt/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import json
import yaml
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
Expand Down Expand Up @@ -107,29 +107,29 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(**conf_dict)

@classmethod
def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
with open(path, encoding="utf-8") as fp:
json_kwargs = json.load(fp)
json_kwargs.update(kwargs)
return cls(**json_kwargs)
file_kwargs = yaml.safe_load(fp)
file_kwargs.update(kwargs)
return cls(**file_kwargs)

@classmethod
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
"""Automatically load `lit_config.json` and if it doesn't exist - a matching config from `litgpt/config.py`."""
if (config_path := path / "lit_config.json").is_file():
return cls.from_json(config_path, **kwargs)
"""Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
if (config_path := path / "model_config.yaml").is_file():
return cls.from_file(config_path, **kwargs)
if (model_name := path.name) in name_to_config:
return cls.from_name(model_name, **kwargs)
raise FileNotFoundError(f"For {str(path)!r} neither 'lit_config.json' nor matching config exists.")
raise FileNotFoundError(f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists.")

@property
def mlp_class(self) -> Type:
# `self.mlp_class_name` cannot be the type to keep the config json serializable
# `self.mlp_class_name` cannot be the type to keep the config serializable
return getattr(litgpt.model, self.mlp_class_name)

@property
def norm_class(self) -> Type:
# `self.norm_class_name` cannot be the type to keep the config json serializable
# `self.norm_class_name` cannot be the type to keep the config serializable
if self.norm_class_name == "RMSNorm":
from functools import partial

Expand Down
2 changes: 1 addition & 1 deletion litgpt/generate/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def main(

check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"

Expand Down
2 changes: 1 addition & 1 deletion litgpt/generate/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def main(

check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"

Expand Down
2 changes: 1 addition & 1 deletion litgpt/generate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def main(

check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"

Expand Down
2 changes: 1 addition & 1 deletion litgpt/generate/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def main(

check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = finetuned_path

Expand Down
4 changes: 2 additions & 2 deletions litgpt/generate/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def main(

check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(
checkpoint_dir / "lit_config.json",
config = Config.from_file(
checkpoint_dir / "model_config.yaml",
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
Expand Down
2 changes: 1 addition & 1 deletion litgpt/generate/sequentially.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def main(

check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"

Expand Down
2 changes: 1 addition & 1 deletion litgpt/generate/tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def main(

check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_file(checkpoint_dir / "model_config.yaml")

model_file = "lit_model.pth"
checkpoint_path = checkpoint_dir / model_file
Expand Down
13 changes: 7 additions & 6 deletions litgpt/prompts.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import importlib
import json
import re
from abc import abstractmethod
from json import dumps
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Type, Tuple, Union

import yaml

from litgpt.config import Config

if TYPE_CHECKING:
Expand Down Expand Up @@ -338,13 +339,13 @@ def save_prompt_style(style: Union[str, PromptStyle], checkpoint_dir: Path) -> N
cls = type(style)
# Allow saving the full module path for user-defined prompt classes
config = {"class_path": f"{cls.__module__}.{cls.__name__}"}
with open(checkpoint_dir / "prompt_style.json", "w") as file:
json.dump(config, file)
with open(checkpoint_dir / "prompt_style.yaml", "w") as file:
yaml.dump(config, file)


def load_prompt_style(checkpoint_dir: Path) -> PromptStyle:
with open(checkpoint_dir / "prompt_style.json", "r") as file:
config = json.load(file)
with open(checkpoint_dir / "prompt_style.yaml", "r") as file:
config = yaml.safe_load(file)
# Support loading the full module path for user-defined prompt classes
full_module_path, cls_name = config["class_path"].rsplit(".", 1)
module = importlib.import_module(full_module_path)
Expand All @@ -353,4 +354,4 @@ def load_prompt_style(checkpoint_dir: Path) -> PromptStyle:


def has_prompt_style(checkpoint_dir: Path) -> bool:
return (checkpoint_dir / "prompt_style.json").is_file()
return (checkpoint_dir / "prompt_style.yaml").is_file()
2 changes: 1 addition & 1 deletion litgpt/scripts/convert_lit_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None:
@torch.inference_mode()
def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
"""Convert a LitGPT trained checkpoint into a Hugging Face Transformers checkpoint."""
config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_file(checkpoint_dir / "model_config.yaml")

output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "model.pth"
Expand Down
2 changes: 1 addition & 1 deletion litgpt/scripts/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def merge_lora(
precision = precision if precision is not None else lora_precision

fabric = L.Fabric(devices=1, precision=precision)
config = Config.from_json(checkpoint_dir / "lit_config.json", **lora_params)
config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)

with fabric.init_module(empty_init=True):
model = GPT(config)
Expand Down
10 changes: 5 additions & 5 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

"""Utility functions for training and inference."""
import json
import math
import pickle
import shutil
Expand All @@ -15,6 +14,7 @@
import torch
import torch.nn as nn
import torch.utils._device
import yaml
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
Expand Down Expand Up @@ -50,7 +50,7 @@ def check_valid_checkpoint_dir(checkpoint_dir: Path, lora: bool = False) -> None
model_filename = "lit_model.pth.lora" if lora else "lit_model.pth"
files = {
model_filename: (checkpoint_dir / model_filename).is_file(),
"lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
"model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
"tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file()
or (checkpoint_dir / "tokenizer.model").is_file(),
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
Expand Down Expand Up @@ -379,7 +379,7 @@ def __iter__(self) -> Self:
def copy_config_files(source_dir: Path, out_dir: Path) -> None:
"""Copies the specified configuration and tokenizer files into the output directory."""

config_files = ["generation_config.json", "lit_config.json"]
config_files = ["generation_config.json", "model_config.yaml"]
tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"]

for file_name in config_files + tokenizer_files:
Expand Down Expand Up @@ -410,8 +410,8 @@ def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:

def save_config(config: "Config", checkpoint_dir: Path) -> None:
config_dict = asdict(config)
with open(checkpoint_dir / "lit_config.json", "w") as json_config:
json.dump(config_dict, json_config)
with open(checkpoint_dir / "model_config.yaml", "w") as fp:
yaml.dump(config_dict, fp)


def parse_devices(devices: Union[str, int]) -> int:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def fake_checkpoint_dir(tmp_path):
checkpoint_dir = tmp_path / "checkpoints" / "tmp"
checkpoint_dir.mkdir(parents=True)
(checkpoint_dir / "lit_model.pth").touch()
(checkpoint_dir / "lit_config.json").touch()
(checkpoint_dir / "model_config.yaml").touch()
(checkpoint_dir / "tokenizer.json").touch()
(checkpoint_dir / "tokenizer_config.json").touch()
return checkpoint_dir
Expand Down
4 changes: 2 additions & 2 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
for checkpoint_dir in checkpoint_dirs:
assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {
"lit_model.pth",
"lit_config.json",
"model_config.yaml",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
"prompt_style.json",
"prompt_style.yaml",
}
assert (out_dir / "version_0" / "metrics.csv").is_file()

Expand Down
4 changes: 2 additions & 2 deletions tests/test_adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,11 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_pa
for checkpoint_dir in checkpoint_dirs:
assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {
"lit_model.pth",
"lit_config.json",
"model_config.yaml",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
"prompt_style.json",
"prompt_style.yaml",
}
assert (out_dir / "version_0" / "metrics.csv").is_file()

Expand Down
6 changes: 3 additions & 3 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import json
import subprocess
import sys
from contextlib import redirect_stderr, redirect_stdout
Expand All @@ -11,6 +10,7 @@

import pytest
import torch
import yaml


@pytest.mark.parametrize(
Expand Down Expand Up @@ -86,9 +86,9 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
# these values will be iteratively provided for each `input()` call
mocked_input.side_effect = ["Hello", stop_iteration]

config_path = fake_checkpoint_dir / "lit_config.json"
config_path = fake_checkpoint_dir / "model_config.yaml"
config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
config_path.write_text(json.dumps(config))
config_path.write_text(yaml.dump(config))

load_mock = Mock()
load_mock.return_value = load_mock
Expand Down
11 changes: 6 additions & 5 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path

import pytest
import yaml

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
Expand Down Expand Up @@ -65,7 +66,7 @@ def test_from_checkpoint(tmp_path):
from litgpt import Config

# 1. Neither `lit_config.py` nor matching config exists.
with pytest.raises(FileNotFoundError, match="neither 'lit_config.json' nor matching config exists"):
with pytest.raises(FileNotFoundError, match="neither 'model_config.yaml' nor matching config exists"):
Config.from_checkpoint(tmp_path / "non_existing_checkpoint")

# 2. If `lit_config.py` doesn't exists, but there is a matching config in `litgpt/config.py`.
Expand All @@ -76,17 +77,17 @@ def test_from_checkpoint(tmp_path):

# 3. If only `lit_config.py` exists.
config_data = {"name": "pythia-14m", "block_size": 24, "n_layer": 2}
with open(tmp_path / "lit_config.json", "w") as file:
json.dump(config_data, file)
with open(tmp_path / "model_config.yaml", "w") as file:
yaml.dump(config_data, file)
config = Config.from_checkpoint(tmp_path)
assert config.name == "pythia-14m"
assert config.block_size == 24
assert config.n_layer == 2

# 4. Both `lit_config.py` and a matching config exist, but `lit_config.py` supersedes matching config
(tmp_path / "pythia-14m").mkdir()
with open(tmp_path / "pythia-14m/lit_config.json", "w") as file:
json.dump(config_data, file)
with open(tmp_path / "pythia-14m/model_config.yaml", "w") as file:
yaml.dump(config_data, file)
config = Config.from_checkpoint(tmp_path / "pythia-14m")
assert config.name == "pythia-14m"
assert config.block_size == 24
Expand Down
4 changes: 2 additions & 2 deletions tests/test_convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ def test_convert_hf_checkpoint(tmp_path):
convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name="pythia-14m")
load.assert_called_with(bin_file)

assert {p.name for p in tmp_path.glob("*")} == {"foo.bin", "lit_config.json", "lit_model.pth"}
assert {p.name for p in tmp_path.glob("*")} == {"foo.bin", "model_config.yaml", "lit_model.pth"}

# ensure that the config dict can be loaded
from litgpt import Config

config = Config.from_json(tmp_path / "lit_config.json")
config = Config.from_file(tmp_path / "model_config.yaml")
assert isinstance(config, Config)
8 changes: 5 additions & 3 deletions tests/test_convert_lit_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import pytest
import torch
import yaml

from conftest import RunIf

wd = Path(__file__).parent.parent.absolute()
Expand All @@ -21,14 +23,14 @@ def test_convert_lit_checkpoint(tmp_path):
ours_config = Config.from_name("Llama-2-7b-hf", block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128)
ours_model = GPT(ours_config)
checkpoint_path = tmp_path / "lit_model.pth"
config_path = tmp_path / "lit_config.json"
config_path = tmp_path / "model_config.yaml"
torch.save(ours_model.state_dict(), checkpoint_path)
with open(config_path, "w") as fp:
json.dump(asdict(ours_config), fp)
yaml.dump(asdict(ours_config), fp)
output_dir = tmp_path / "out_dir"

convert_lit_checkpoint(checkpoint_path.parent, output_dir)
assert set(os.listdir(tmp_path)) == {"lit_model.pth", "lit_config.json", "out_dir"}
assert set(os.listdir(tmp_path)) == {"lit_model.pth", "model_config.yaml", "out_dir"}
assert os.path.isfile(output_dir / "model.pth")

# check checkpoint is unwrapped
Expand Down
2 changes: 1 addition & 1 deletion tests/test_convert_pretrained_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_convert_pretrained_checkpoint(tmp_path, fake_checkpoint_dir):
convert_pretrained_checkpoint(checkpoint_dir=fake_checkpoint_dir, output_dir=(tmp_path / "converted"))

assert set(os.listdir(tmp_path / "converted")) == {
"lit_model.pth", "lit_config.json", "tokenizer_config.json", "tokenizer.json"
"lit_model.pth", "model_config.yaml", "tokenizer_config.json", "tokenizer.json"
}
converted_checkpoint = torch.load(tmp_path / "converted" / "lit_model.pth")
assert list(converted_checkpoint.keys()) == ["some.module.weight", "some.other.module.weight"]
4 changes: 2 additions & 2 deletions tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
for checkpoint_dir in checkpoint_dirs:
assert set(os.listdir(out_dir / checkpoint_dir)) == {
"lit_model.pth",
"lit_config.json",
"model_config.yaml",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
"prompt_style.json",
"prompt_style.yaml",
}
assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file()

Expand Down
Loading

0 comments on commit 674f315

Please sign in to comment.