From 5c089c6268881868380d0b0bd871b3fcf91a18ec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 17 Nov 2023 17:10:21 -0500
Subject: [PATCH] Add test for #735 (#748)

---
 lit_gpt/adapter_v2.py                |  5 +----
 lit_gpt/lora.py                      |  5 +----
 tests/test_adapter.py                |  3 +--
 tests/test_adapter_v2.py             | 18 +++++++++++++----
 tests/test_convert_lit_checkpoint.py | 22 +++++++++++---------
 tests/test_gptq.py                   |  1 -
 tests/test_lora.py                   | 29 ++++++++++++++-------------
 tests/test_model.py                  | 30 +++++++++++++++++-----------
 tests/test_rope.py                   |  3 ++-
 tests/test_utils.py                  |  4 ++--
 10 files changed, 67 insertions(+), 53 deletions(-)

diff --git a/lit_gpt/adapter_v2.py b/lit_gpt/adapter_v2.py
index 8389002a4b..caa6e555a3 100644
--- a/lit_gpt/adapter_v2.py
+++ b/lit_gpt/adapter_v2.py
@@ -89,10 +89,7 @@ def _init_weights(self, module: nn.Module) -> None:
 
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
         """For compatibility with base checkpoints."""
-        mapping = {
-            "lm_head.weight": "lm_head.linear.weight",
-            "lm_head.bias": "lm_head.linear.bias"
-        }
+        mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
         state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
 
diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py
index f4fc208df6..2e41b701a0 100644
--- a/lit_gpt/lora.py
+++ b/lit_gpt/lora.py
@@ -514,10 +514,7 @@ def _init_weights(self, module: nn.Module) -> None:
 
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
         """For compatibility with base checkpoints."""
-        mapping = {
-            "lm_head.weight": "lm_head.linear.weight",
-            "lm_head.bias": "lm_head.linear.bias"
-        }
+        mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
         state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
 
diff --git a/tests/test_adapter.py b/tests/test_adapter.py
index c6894d5faf..024a032537 100644
--- a/tests/test_adapter.py
+++ b/tests/test_adapter.py
@@ -4,9 +4,8 @@
 from unittest.mock import Mock
 
 import torch
-from lightning import Fabric
-
 from conftest import RunIf
+from lightning import Fabric
 
 
 def test_config_identical():
diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py
index 2656ba1760..99e7016d54 100644
--- a/tests/test_adapter_v2.py
+++ b/tests/test_adapter_v2.py
@@ -1,11 +1,19 @@
+import sys
 from contextlib import redirect_stdout
 from io import StringIO
+from pathlib import Path
 from unittest.mock import Mock
 
+import pytest
 import torch
+from conftest import RunIf
 from lightning import Fabric
 
-from conftest import RunIf
+# support running without installing as a package
+wd = Path(__file__).parent.parent.resolve()
+sys.path.append(str(wd))
+
+import lit_gpt.config as config_module
 
 
 def test_config_identical():
@@ -121,14 +129,16 @@ def test_adapter_v2_gpt_init_weights():
     assert (param == 0).all()
 
 
-def test_base_model_can_be_adapter_v2_loaded():
+@pytest.mark.parametrize("name", [c["name"] for c in config_module.configs])
+def test_base_model_can_be_adapter_v2_loaded(name):
     from lit_gpt.adapter_v2 import GPT as AdapterV2GPT
     from lit_gpt.adapter_v2 import adapter_filter
     from lit_gpt.model import GPT as BaseGPT
 
-    base_model = BaseGPT.from_name("pythia-70m", bias=True, n_layer=2)
+    kwargs = {"n_layer": 2, "n_head": 8, "n_embd": 16, "padded_vocab_size": 32}
+    base_model = BaseGPT.from_name(name, **kwargs)
     base_model_state_dict = base_model.state_dict()
-    lora_model = AdapterV2GPT.from_name("pythia-70m", bias=True, n_layer=2, adapter_start_layer=0)
+    lora_model = AdapterV2GPT.from_name(name, **kwargs, adapter_start_layer=0)
     keys = lora_model.load_state_dict(base_model_state_dict, strict=False)
     assert not keys.unexpected_keys
     for k in keys.missing_keys:
diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py
index 6b776a7b7e..d57b5826d2 100644
--- a/tests/test_convert_lit_checkpoint.py
+++ b/tests/test_convert_lit_checkpoint.py
@@ -36,11 +36,12 @@ def test_convert_lit_checkpoint(tmp_path):
 
 @torch.inference_mode()
 def test_against_falcon_40b():
-    from lit_gpt import GPT, Config
-    from scripts.convert_lit_checkpoint import copy_weights_falcon as copy_to_theirs
     from transformers.models.falcon.configuration_falcon import FalconConfig
     from transformers.models.falcon.modeling_falcon import FalconForCausalLM
 
+    from lit_gpt import GPT, Config
+    from scripts.convert_lit_checkpoint import copy_weights_falcon as copy_to_theirs
+
     ours_config = Config.from_name("falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32)
     theirs_config = FalconConfig(
         vocab_size=ours_config.padded_vocab_size,
@@ -71,9 +72,10 @@ def test_against_falcon_40b():
 
 @torch.inference_mode()
 def test_against_original_gpt_neox():
+    from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
+
     from lit_gpt import GPT, Config
     from scripts.convert_lit_checkpoint import copy_weights_gpt_neox as copy_to_theirs
-    from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
 
     ours_config = Config(block_size=64, vocab_size=100, n_layer=4, n_head=8, n_embd=16)
     assert ours_config.padded_vocab_size == 512
@@ -114,11 +116,12 @@
     "ours_kwargs", [{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf"}]
 )
 def test_against_hf_llama2(ours_kwargs):
-    from lit_gpt import GPT, Config
-    from scripts.convert_lit_checkpoint import copy_weights_llama
     from transformers.models.llama.configuration_llama import LlamaConfig
     from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
+    from lit_gpt import GPT, Config
+    from scripts.convert_lit_checkpoint import copy_weights_llama
+
     ours_config = Config.from_name(
         padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
     )
@@ -152,11 +155,12 @@ def test_against_hf_llama2(ours_kwargs):
 
 @torch.inference_mode()
 def test_against_original_open_llama_3b():
-    from lit_gpt import GPT, Config
-    from scripts.convert_lit_checkpoint import copy_weights_llama
     from transformers.models.llama.configuration_llama import LlamaConfig
     from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
+    from lit_gpt import GPT, Config
+    from scripts.convert_lit_checkpoint import copy_weights_llama
+
     ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)
     T = 5
     theirs_config = LlamaConfig(
@@ -190,11 +194,11 @@ def test_against_hf_phi():
     if not file_path.is_file():
         urlretrieve(url=url, filename=file_path)
 
+    from original_phi_1_5 import MixFormerSequentialConfig, MixFormerSequentialForCausalLM
+
     from lit_gpt import GPT, Config
     from scripts.convert_lit_checkpoint import copy_weights_phi
 
-    from original_phi_1_5 import MixFormerSequentialConfig, MixFormerSequentialForCausalLM
-
     ours_config = Config.from_name(
         "phi-1_5", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
     )
diff --git a/tests/test_gptq.py b/tests/test_gptq.py
index 0813a1776f..2896bcffe6 100644
--- a/tests/test_gptq.py
+++ b/tests/test_gptq.py
@@ -1,7 +1,6 @@
 import lightning as L
 import pytest
 import torch
-
 from conftest import RunIf
 
 
diff --git a/tests/test_lora.py b/tests/test_lora.py
index ebe4431688..86cae056bd 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -1,13 +1,20 @@
+import sys
 from contextlib import redirect_stdout
 from io import StringIO
 from itertools import product
+from pathlib import Path
 from unittest.mock import Mock
 
 import pytest
 import torch
+from conftest import RunIf
 from lightning import Fabric
 
-from conftest import RunIf
+# support running without installing as a package
+wd = Path(__file__).parent.parent.resolve()
+sys.path.append(str(wd))
+
+import lit_gpt.config as config_module
 
 
 def test_lora_layer_replacement():
@@ -305,9 +312,10 @@ def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, ap
     ],
 )
 def test_lora_qkv_linear_compare_conv1d(n_head, enable_lora):
-    from lit_gpt.lora import LoRAQKVLinear
     from torch.nn import functional as F
 
+    from lit_gpt.lora import LoRAQKVLinear
+
     C = 12
     layer = LoRAQKVLinear(C, 3 * C, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora)
     x = torch.randn((1, 1, C))
@@ -431,24 +439,17 @@ def test_lora_gpt_init_weights():
     assert (param == 0).all()
 
 
-def test_base_model_can_be_lora_loaded():
+@pytest.mark.parametrize("name", [c["name"] for c in config_module.configs])
+def test_base_model_can_be_lora_loaded(name):
     from lit_gpt.lora import GPT as LoRAGPT
     from lit_gpt.lora import lora_filter
     from lit_gpt.model import GPT as BaseGPT
 
-    base_model = BaseGPT.from_name("pythia-70m", bias=True, n_layer=2)
+    kwargs = {"n_layer": 2, "n_head": 8, "n_embd": 16, "padded_vocab_size": 32}
+    base_model = BaseGPT.from_name(name, **kwargs)
     base_model_state_dict = base_model.state_dict()
     lora_model = LoRAGPT.from_name(
-        "pythia-70m",
-        bias=True,
-        n_layer=2,
-        r=1,
-        to_query=True,
-        to_key=True,
-        to_value=True,
-        to_projection=True,
-        to_mlp=True,
-        to_head=True,
+        name, **kwargs, r=1, to_query=True, to_key=True, to_value=True, to_projection=True, to_mlp=True, to_head=True
     )
     keys = lora_model.load_state_dict(base_model_state_dict, strict=False)
     assert not keys.unexpected_keys
diff --git a/tests/test_model.py b/tests/test_model.py
index e3e84bcf0b..2175e8d8a8 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -5,9 +5,8 @@
 import pytest
 import torch
-from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_2
-
 from conftest import RunIf
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_2
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
@@ -38,9 +37,10 @@
     ],
 )
 def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residual, device, dtype) -> None:
+    from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
+
     from lit_gpt import GPT, Config
     from scripts.convert_hf_checkpoint import copy_weights_gpt_neox
-    from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
 
     torch.set_default_dtype(dtype)
 
@@ -110,9 +110,10 @@ def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residua
     ],
 )
 def test_against_hf_falcon(kwargs, device, dtype):
+    from transformers.models.falcon import FalconConfig, FalconForCausalLM
+
     from lit_gpt import GPT, Config
     from scripts.convert_hf_checkpoint import copy_weights_falcon
-    from transformers.models.falcon import FalconConfig, FalconForCausalLM
 
     torch.set_default_dtype(dtype)
 
@@ -160,11 +161,12 @@ def test_against_hf_falcon(kwargs, device, dtype):
     ],
 )
 def test_against_original_open_llama_3b(device, dtype):
-    from lit_gpt import GPT, Config
-    from scripts.convert_hf_checkpoint import copy_weights_hf_llama
     from transformers.models.llama.configuration_llama import LlamaConfig
     from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
+    from lit_gpt import GPT, Config
+    from scripts.convert_hf_checkpoint import copy_weights_hf_llama
+
     torch.set_default_dtype(dtype)
 
     ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)
@@ -215,11 +217,12 @@
     ],
 )
 def test_against_hf_llama2(ours_kwargs, device, dtype):
-    from lit_gpt import GPT, Config
-    from scripts.convert_hf_checkpoint import copy_weights_hf_llama
     from transformers.models.llama.configuration_llama import LlamaConfig
     from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
+    from lit_gpt import GPT, Config
+    from scripts.convert_hf_checkpoint import copy_weights_hf_llama
+
     torch.set_default_dtype(dtype)
 
     ours_config = Config.from_name(
@@ -327,11 +330,12 @@ def test_against_hf_phi(device, dtype):
     ],
 )
 def test_against_hf_mistral(device, dtype):
-    from lit_gpt import GPT, Config
-    from scripts.convert_hf_checkpoint import copy_weights_hf_llama
     from transformers.models.mistral.configuration_mistral import MistralConfig
     from transformers.models.mistral.modeling_mistral import MistralForCausalLM
 
+    from lit_gpt import GPT, Config
+    from scripts.convert_hf_checkpoint import copy_weights_hf_llama
+
     torch.set_default_dtype(dtype)
 
     ours_config = Config.from_name(
@@ -456,9 +460,10 @@ def test_model_kv_cache_amp():
 @pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
 @torch.inference_mode()
 def test_sdpa_choice(config):
-    from lit_gpt import GPT
     from torch.backends.cuda import SDPBackend
 
+    from lit_gpt import GPT
+
     torch.set_default_dtype(torch.float16)
 
     def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
@@ -500,9 +505,10 @@ def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
 @pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
 @torch.inference_mode()
 def test_sdpa_choice_kv_cache(config):
-    from lit_gpt import GPT
     from torch.backends.cuda import SDPBackend
 
+    from lit_gpt import GPT
+
     torch.set_default_dtype(torch.float16)
 
     def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
diff --git a/tests/test_rope.py b/tests/test_rope.py
index e8f6c4e0ee..076620737f 100644
--- a/tests/test_rope.py
+++ b/tests/test_rope.py
@@ -3,9 +3,10 @@
 
 @torch.inference_mode()
 def test_rope():
-    from lit_gpt.model import apply_rope, build_rope_cache
     from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding, apply_rotary_pos_emb
 
+    from lit_gpt.model import apply_rope, build_rope_cache
+
     bs, seq_len, n_head, n_embed = 1, 6, 2, 8
     head_size = n_embed // n_head
     x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 318832ee23..a704267b08 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,9 +5,8 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from lightning import Fabric
-
 from conftest import RunIf
+from lightning import Fabric
 
 
 def test_find_multiple():
@@ -144,6 +143,7 @@ def test_num_parameters():
 @pytest.mark.skip("To be fixed")
 def test_num_parameters_bitsandbytes(mode):
     from lightning.fabric.plugins import BitsandbytesPrecision
+
     from lit_gpt import GPT
     from lit_gpt.utils import num_parameters
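
Usage note (not part of the patch): the pattern the new tests adopt is to parametrize over every config name registered in lit_gpt.config, shrink the model dimensions so each case stays cheap, and then verify that a base GPT state dict loads into the fine-tuning variant. The snippet below is a minimal standalone sketch of that pattern under the assumption that lit_gpt is importable from the working directory; the test name is illustrative, not the one added by the patch.

import pytest

import lit_gpt.config as config_module


@pytest.mark.parametrize("name", [c["name"] for c in config_module.configs])
def test_base_state_dict_loads_into_lora_sketch(name):
    from lit_gpt.lora import GPT as LoRAGPT
    from lit_gpt.model import GPT as BaseGPT

    # Tiny overrides keep every config cheap enough to instantiate on CPU.
    kwargs = {"n_layer": 2, "n_head": 8, "n_embd": 16, "padded_vocab_size": 32}
    base_state_dict = BaseGPT.from_name(name, **kwargs).state_dict()
    lora_model = LoRAGPT.from_name(
        name, **kwargs, r=1, to_query=True, to_key=True, to_value=True, to_projection=True, to_mlp=True, to_head=True
    )
    # strict=False: only the LoRA-specific parameters are allowed to be absent from the base checkpoint.
    keys = lora_model.load_state_dict(base_state_dict, strict=False)
    assert not keys.unexpected_keys

Running this with pytest turns each entry of config_module.configs into its own test case, which is what gives the new tests per-config coverage.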