Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Nov 17, 2023
1 parent 9391759 commit 5c089c6
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 53 deletions.
5 changes: 1 addition & 4 deletions lit_gpt/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ def _init_weights(self, module: nn.Module) -> None:

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"lm_head.weight": "lm_head.linear.weight",
"lm_head.bias": "lm_head.linear.bias"
}
mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

Expand Down
5 changes: 1 addition & 4 deletions lit_gpt/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,10 +514,7 @@ def _init_weights(self, module: nn.Module) -> None:

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"lm_head.weight": "lm_head.linear.weight",
"lm_head.bias": "lm_head.linear.bias"
}
mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

Expand Down
3 changes: 1 addition & 2 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from unittest.mock import Mock

import torch
from lightning import Fabric

from conftest import RunIf
from lightning import Fabric


def test_config_identical():
Expand Down
18 changes: 14 additions & 4 deletions tests/test_adapter_v2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import sys
from contextlib import redirect_stdout
from io import StringIO
from pathlib import Path
from unittest.mock import Mock

import pytest
import torch
from conftest import RunIf
from lightning import Fabric

from conftest import RunIf
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

import lit_gpt.config as config_module


def test_config_identical():
Expand Down Expand Up @@ -121,14 +129,16 @@ def test_adapter_v2_gpt_init_weights():
assert (param == 0).all()


def test_base_model_can_be_adapter_v2_loaded():
@pytest.mark.parametrize("name", [c["name"] for c in config_module.configs])
def test_base_model_can_be_adapter_v2_loaded(name):
from lit_gpt.adapter_v2 import GPT as AdapterV2GPT
from lit_gpt.adapter_v2 import adapter_filter
from lit_gpt.model import GPT as BaseGPT

base_model = BaseGPT.from_name("pythia-70m", bias=True, n_layer=2)
kwargs = {"n_layer": 2, "n_head": 8, "n_embd": 16, "padded_vocab_size": 32}
base_model = BaseGPT.from_name(name, **kwargs)
base_model_state_dict = base_model.state_dict()
lora_model = AdapterV2GPT.from_name("pythia-70m", bias=True, n_layer=2, adapter_start_layer=0)
lora_model = AdapterV2GPT.from_name(name, **kwargs, adapter_start_layer=0)
keys = lora_model.load_state_dict(base_model_state_dict, strict=False)
assert not keys.unexpected_keys
for k in keys.missing_keys:
Expand Down
22 changes: 13 additions & 9 deletions tests/test_convert_lit_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ def test_convert_lit_checkpoint(tmp_path):

@torch.inference_mode()
def test_against_falcon_40b():
from lit_gpt import GPT, Config
from scripts.convert_lit_checkpoint import copy_weights_falcon as copy_to_theirs
from transformers.models.falcon.configuration_falcon import FalconConfig
from transformers.models.falcon.modeling_falcon import FalconForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_lit_checkpoint import copy_weights_falcon as copy_to_theirs

ours_config = Config.from_name("falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32)
theirs_config = FalconConfig(
vocab_size=ours_config.padded_vocab_size,
Expand Down Expand Up @@ -71,9 +72,10 @@ def test_against_falcon_40b():

@torch.inference_mode()
def test_against_original_gpt_neox():
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_lit_checkpoint import copy_weights_gpt_neox as copy_to_theirs
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

ours_config = Config(block_size=64, vocab_size=100, n_layer=4, n_head=8, n_embd=16)
assert ours_config.padded_vocab_size == 512
Expand Down Expand Up @@ -114,11 +116,12 @@ def test_against_original_gpt_neox():
"ours_kwargs", [{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf"}]
)
def test_against_hf_llama2(ours_kwargs):
from lit_gpt import GPT, Config
from scripts.convert_lit_checkpoint import copy_weights_llama
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_lit_checkpoint import copy_weights_llama

ours_config = Config.from_name(
padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
)
Expand Down Expand Up @@ -152,11 +155,12 @@ def test_against_hf_llama2(ours_kwargs):

@torch.inference_mode()
def test_against_original_open_llama_3b():
from lit_gpt import GPT, Config
from scripts.convert_lit_checkpoint import copy_weights_llama
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_lit_checkpoint import copy_weights_llama

ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)
T = 5
theirs_config = LlamaConfig(
Expand Down Expand Up @@ -190,11 +194,11 @@ def test_against_hf_phi():
if not file_path.is_file():
urlretrieve(url=url, filename=file_path)

from original_phi_1_5 import MixFormerSequentialConfig, MixFormerSequentialForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_lit_checkpoint import copy_weights_phi

from original_phi_1_5 import MixFormerSequentialConfig, MixFormerSequentialForCausalLM

ours_config = Config.from_name(
"phi-1_5", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
)
Expand Down
1 change: 0 additions & 1 deletion tests/test_gptq.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import lightning as L
import pytest
import torch

from conftest import RunIf


Expand Down
29 changes: 15 additions & 14 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import sys
from contextlib import redirect_stdout
from io import StringIO
from itertools import product
from pathlib import Path
from unittest.mock import Mock

import pytest
import torch
from conftest import RunIf
from lightning import Fabric

from conftest import RunIf
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

import lit_gpt.config as config_module


def test_lora_layer_replacement():
Expand Down Expand Up @@ -305,9 +312,10 @@ def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, ap
],
)
def test_lora_qkv_linear_compare_conv1d(n_head, enable_lora):
from lit_gpt.lora import LoRAQKVLinear
from torch.nn import functional as F

from lit_gpt.lora import LoRAQKVLinear

C = 12
layer = LoRAQKVLinear(C, 3 * C, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora)
x = torch.randn((1, 1, C))
Expand Down Expand Up @@ -431,24 +439,17 @@ def test_lora_gpt_init_weights():
assert (param == 0).all()


def test_base_model_can_be_lora_loaded():
@pytest.mark.parametrize("name", [c["name"] for c in config_module.configs])
def test_base_model_can_be_lora_loaded(name):
from lit_gpt.lora import GPT as LoRAGPT
from lit_gpt.lora import lora_filter
from lit_gpt.model import GPT as BaseGPT

base_model = BaseGPT.from_name("pythia-70m", bias=True, n_layer=2)
kwargs = {"n_layer": 2, "n_head": 8, "n_embd": 16, "padded_vocab_size": 32}
base_model = BaseGPT.from_name(name, **kwargs)
base_model_state_dict = base_model.state_dict()
lora_model = LoRAGPT.from_name(
"pythia-70m",
bias=True,
n_layer=2,
r=1,
to_query=True,
to_key=True,
to_value=True,
to_projection=True,
to_mlp=True,
to_head=True,
name, **kwargs, r=1, to_query=True, to_key=True, to_value=True, to_projection=True, to_mlp=True, to_head=True
)
keys = lora_model.load_state_dict(base_model_state_dict, strict=False)
assert not keys.unexpected_keys
Expand Down
30 changes: 18 additions & 12 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import pytest
import torch
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_2

from conftest import RunIf
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_2

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
Expand Down Expand Up @@ -38,9 +37,10 @@
],
)
def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residual, device, dtype) -> None:
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_gpt_neox
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

torch.set_default_dtype(dtype)

Expand Down Expand Up @@ -110,9 +110,10 @@ def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residua
],
)
def test_against_hf_falcon(kwargs, device, dtype):
from transformers.models.falcon import FalconConfig, FalconForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_falcon
from transformers.models.falcon import FalconConfig, FalconForCausalLM

torch.set_default_dtype(dtype)

Expand Down Expand Up @@ -160,11 +161,12 @@ def test_against_hf_falcon(kwargs, device, dtype):
],
)
def test_against_original_open_llama_3b(device, dtype):
from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_hf_llama
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_hf_llama

torch.set_default_dtype(dtype)

ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)
Expand Down Expand Up @@ -215,11 +217,12 @@ def test_against_original_open_llama_3b(device, dtype):
],
)
def test_against_hf_llama2(ours_kwargs, device, dtype):
from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_hf_llama
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_hf_llama

torch.set_default_dtype(dtype)

ours_config = Config.from_name(
Expand Down Expand Up @@ -327,11 +330,12 @@ def test_against_hf_phi(device, dtype):
],
)
def test_against_hf_mistral(device, dtype):
from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_hf_llama
from transformers.models.mistral.configuration_mistral import MistralConfig
from transformers.models.mistral.modeling_mistral import MistralForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_hf_llama

torch.set_default_dtype(dtype)

ours_config = Config.from_name(
Expand Down Expand Up @@ -456,9 +460,10 @@ def test_model_kv_cache_amp():
@pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
@torch.inference_mode()
def test_sdpa_choice(config):
from lit_gpt import GPT
from torch.backends.cuda import SDPBackend

from lit_gpt import GPT

torch.set_default_dtype(torch.float16)

def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
Expand Down Expand Up @@ -500,9 +505,10 @@ def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
@pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
@torch.inference_mode()
def test_sdpa_choice_kv_cache(config):
from lit_gpt import GPT
from torch.backends.cuda import SDPBackend

from lit_gpt import GPT

torch.set_default_dtype(torch.float16)

def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

@torch.inference_mode()
def test_rope():
from lit_gpt.model import apply_rope, build_rope_cache
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding, apply_rotary_pos_emb

from lit_gpt.model import apply_rope, build_rope_cache

bs, seq_len, n_head, n_embed = 1, 6, 2, 8
head_size = n_embed // n_head
x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import pytest
import torch
import torch.nn.functional as F
from lightning import Fabric

from conftest import RunIf
from lightning import Fabric


def test_find_multiple():
Expand Down Expand Up @@ -144,6 +143,7 @@ def test_num_parameters():
@pytest.mark.skip("To be fixed")
def test_num_parameters_bitsandbytes(mode):
from lightning.fabric.plugins import BitsandbytesPrecision

from lit_gpt import GPT
from lit_gpt.utils import num_parameters

Expand Down

0 comments on commit 5c089c6

Please sign in to comment.