Commit

Mathstral checkpoints (Lightning-AI#1587)
rasbt authored Jul 16, 2024
1 parent d7d93a5 commit 5f5df90
Showing 4 changed files with 65 additions and 37 deletions.
README.md: 1 change (1 addition, 0 deletions)
@@ -124,6 +124,7 @@ Every model is written from scratch to maximize performance and remove layers of
| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
+| Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) |
| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama)
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
litgpt/config.py: 20 changes (20 additions, 0 deletions)
@@ -1467,6 +1467,26 @@ def norm_class(self) -> Type:
#############
# Mistral AI
#############

+configs.append(
+    # https://huggingface.co/mistralai/mathstral-7B-v0.1/blob/main/config.json
+    dict(
+        name="Mathstral-7B-v0.1",
+        hf_config=dict(org="mistralai", name="mathstral-7B-v0.1"),
+        padded_vocab_size=32768,
+        block_size=32768,
+        n_layer=32,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        norm_eps=1e-05,
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=14336,
+    )
+)

mistral = [
    # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
    dict(
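For reference, the new entry can be exercised directly through `Config.from_name`, the same helper the updated test below relies on. A minimal sketch, assuming this litgpt checkout is importable; the printed values are the fields from the diff above:

```python
# Minimal sketch: look up the Mathstral config added above and inspect a few fields.
# Assumes the litgpt checkout containing this commit is on the Python path.
from litgpt.config import Config

config = Config.from_name("Mathstral-7B-v0.1")
print(config.hf_config)          # {'org': 'mistralai', 'name': 'mathstral-7B-v0.1'}
print(config.block_size)         # 32768
print(config.intermediate_size)  # 14336
```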
tests/test_model.py: 79 changes (42 additions, 37 deletions)
@@ -386,45 +386,50 @@ def test_against_hf_phi_3(model_name, device, dtype):
),
],
)
-def test_against_hf_mistral(device, dtype):
+def test_against_hf_models(device, dtype):
    torch.set_default_dtype(dtype)

-    ours_config = Config.from_name(
-        "Mistral-7B-Instruct-v0.1",
-        padded_vocab_size=10000,
-        n_layer=2,
-        n_embd=32,
-        n_head=8,
-        n_query_groups=2,
-        intermediate_size=86,
-    )
-    T = 5
-    theirs_config = MistralConfig(
-        vocab_size=ours_config.padded_vocab_size,
-        hidden_size=ours_config.n_embd,
-        num_attention_heads=ours_config.n_head,
-        num_hidden_layers=ours_config.n_layer,
-        intermediate_size=ours_config.intermediate_size,
-        max_position_embeddings=T,
-        rms_norm_eps=ours_config.norm_eps,
-        num_key_value_heads=ours_config.n_query_groups,
-        rope_theta=ours_config.rope_base,
-    )
-    assert ours_config.intermediate_size == theirs_config.intermediate_size
-
-    theirs_model = MistralForCausalLM(theirs_config).to(device)
-    theirs_state_dict = theirs_model.state_dict()
-    state_dict = {}
-    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
-    ours_model = GPT(ours_config).to(device)
-    ours_model.load_state_dict(state_dict)
-
-    # test end to end
-    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
-    assert x.size(1) == T
-    ours_y = ours_model(x)
-    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
-    torch.testing.assert_close(ours_y, theirs_y)
+    model_names = ["Mistral-7B-Instruct-v0.1", "Mathstral-7B-v0.1"]
+    for model_name in model_names:
+
+        ours_config = Config.from_name(
+            model_name,
+            padded_vocab_size=10000,
+            n_layer=2,
+            n_embd=32,
+            n_head=8,
+            n_query_groups=2,
+            intermediate_size=86,
+        )
+
+        T = 5
+        theirs_config = MistralConfig(
+            vocab_size=ours_config.padded_vocab_size,
+            hidden_size=ours_config.n_embd,
+            num_attention_heads=ours_config.n_head,
+            num_hidden_layers=ours_config.n_layer,
+            intermediate_size=ours_config.intermediate_size,
+            max_position_embeddings=T,
+            rms_norm_eps=ours_config.norm_eps,
+            num_key_value_heads=ours_config.n_query_groups,
+            rope_theta=ours_config.rope_base,
+        )
+
+        assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+        theirs_model = MistralForCausalLM(theirs_config).to(device)
+        theirs_state_dict = theirs_model.state_dict()
+        state_dict = {}
+        copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+        ours_model = GPT(ours_config).to(device)
+        ours_model.load_state_dict(state_dict)
+
+        # test end to end
+        x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+        assert x.size(1) == T
+        ours_y = ours_model(x)
+        theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+        torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
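To exercise just the renamed test against both checkpoints, one option is to invoke pytest from Python. A sketch, assuming pytest, torch, and transformers are installed in the environment:

```python
# Sketch: run only the consolidated Mistral/Mathstral comparison test.
# Assumes pytest, torch, and transformers are available in the environment.
import pytest

raise SystemExit(pytest.main(["tests/test_model.py", "-k", "test_against_hf_models", "-q"]))
```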
tutorials/download_model_weights.md: 2 changes (2 additions, 0 deletions)
@@ -18,6 +18,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
+| Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) |
| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama)
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
@@ -118,6 +119,7 @@ meta-llama/Meta-Llama-3-8B-Instruct
microsoft/phi-1_5
microsoft/phi-2
microsoft/Phi-3-mini-4k-instruct
+mistralai/mathstral-7B-v0.1
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
mistralai/Mistral-7B-v0.1
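A quick way to confirm that the repo id added to the tutorial is actually registered is to scan the module-level `configs` list that the config.py change appends to. A hedged sketch, assuming the same checkout:

```python
# Sketch: verify that mistralai/mathstral-7B-v0.1 resolves to a registered config.
# `configs` is the module-level list of config dicts that this commit appends to.
from litgpt.config import configs

matches = [
    c for c in configs
    if c.get("hf_config", {}).get("org") == "mistralai"
    and c.get("hf_config", {}).get("name") == "mathstral-7B-v0.1"
]
assert len(matches) == 1
print(matches[0]["name"])  # Mathstral-7B-v0.1
```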
