From 9ebbe823290a3ccca850283b884b051e69e6656f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 24 Jul 2024 08:08:13 -0700 Subject: [PATCH] Add Llama 3.1 405B config (#1622) Co-authored-by: Sebastian Raschka --- README.md | 2 +- litgpt/config.py | 21 +++++++++++++++++++++ tests/test_model.py | 1 + tests/test_prompts.py | 1 + tutorials/download_model_weights.md | 5 ++++- 5 files changed, 28 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 29f63ce06e..810f140c36 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Every model is written from scratch to maximize performance and remove layers of | Model | Model size | Author | Reference | |----|----|----|----| -| Llama 3 & 3.1 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | +| Llama 3 & 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) | diff --git a/litgpt/config.py b/litgpt/config.py index 0dfe986579..f819cf51ea 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -877,6 +877,7 @@ def norm_class(self) -> Type: intermediate_size=14336, rope_base=500000, ), + # https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/config.json dict( name="Llama-3.1-8B{}", hf_config=dict(org="meta-llama", name="Meta-Llama-3.1-8B{}"), @@ -913,6 +914,7 @@ def norm_class(self) -> Type: intermediate_size=28672, rope_base=500000, ), + # https://huggingface.co/meta-llama/Meta-Llama-3.1-70B/blob/main/config.json dict( name="Llama-3.1-70B{}", hf_config=dict(org="meta-llama", name="Meta-Llama-3.1-70B{}"), @@ -931,6 +933,25 @@ def norm_class(self) -> Type: intermediate_size=28672, rope_base=500000, ), + # https://huggingface.co/meta-llama/Meta-Llama-3.1-405B/blob/main/config.json + dict( + name="Llama-3.1-405B{}", + hf_config=dict(org="meta-llama", name="Meta-Llama-3.1-405B{}"), + block_size=131072, + vocab_size=128000, + padded_vocab_size=128256, + n_layer=126, + n_head=128, + n_embd=16384, + n_query_groups=16, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=53248, + rope_base=500000, + ), ] for c in llama_3: for kind in ("", "-Instruct"): diff --git a/tests/test_model.py b/tests/test_model.py index 22177dc2fe..890d72119e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -213,6 +213,7 @@ def test_against_original_open_llama_3b(device, dtype): {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1}, {"name": "Llama-3-8B"}, {"name": "Llama-3-8B-Instruct"}, + {"name": "Llama-3.1-405B", "n_query_groups": 4}, {"name": "Llama-3.1-8B"}, {"name": "Llama-3.1-8B-Instruct"}, ], diff --git a/tests/test_prompts.py b/tests/test_prompts.py index b10af5498a..65f4a0e71b 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -53,6 +53,7 @@ def test_prompt_style_from_config(): "Llama-2-70b-chat-hf", "Llama-3-8B-Instruct", "Llama-3-70B-Instruct", + "Llama-3.1-405B-Instruct", "Gemma-2b-it", "Gemma-7b-it", "FreeWilly2", diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 9a172fde7d..7872259b9b 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -17,7 +17,8 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | | Gemma 2 | 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf) | | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | -| Llama 3 & 3.1 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | +| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | +| Llama 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) | | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) @@ -117,6 +118,8 @@ meta-llama/Meta-Llama-3-70B meta-llama/Meta-Llama-3-70B-Instruct meta-llama/Meta-Llama-3-8B meta-llama/Meta-Llama-3-8B-Instruct +meta-llama/Meta-Llama-3.1-405B +meta-llama/Meta-Llama-3.1-405B-Instruct meta-llama/Meta-Llama-3.1-70B meta-llama/Meta-Llama-3.1-70B-Instruct meta-llama/Meta-Llama-3.1-8B