Add CodeLlama support (Lightning-AI#472)
Co-authored-by: Sebastian Raschka <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored Sep 1, 2023
1 parent ce39a99 commit b3cdb6e
Showing 12 changed files with 280 additions and 54 deletions.
29 changes: 15 additions & 14 deletions README.md
@@ -25,20 +25,21 @@ Hackable [implementation](lit_gpt/model.py) of state-of-the-art open-source larg

Supports the following popular model checkpoints:

| Model and usage | Reference |
|---------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|
| Meta AI [Llama 2](tutorials/download_llama_2.md) | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Stability AI [FreeWilly2](tutorials/download_freewilly_2.md) | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Stability AI StableCode | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TII UAE [Falcon](tutorials/download_falcon.md) | [TII 2023](https://falconllm.tii.ae) |
| OpenLM Research [OpenLLaMA](tutorials/download_openllama.md) | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| LMSYS [Vicuna](tutorials/download_vicuna.md) | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |
| LMSYS [LongChat](tutorials/download_longchat.md) | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| Together [RedPajama-INCITE](tutorials/download_redpajama_incite.md) | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| EleutherAI [Pythia](tutorials/download_pythia.md) | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| StabilityAI [StableLM](tutorials/download_stablelm.md) | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| Platypus | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| NousResearch Nous-Hermes | [Org page](https://huggingface.co/NousResearch) |
| Meta AI [Code Llama](tutorials/download_code_llama.md) | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |

This implementation builds on [Lit-LLaMA](https://github.com/lightning-AI/lit-llama) and [nanoGPT](https://github.com/karpathy/nanoGPT), and it is **powered by [Lightning Fabric](https://lightning.ai/docs/fabric/stable/)**.

8 changes: 8 additions & 0 deletions chat/base.py
@@ -296,6 +296,14 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

if re.search("CodeLlama", checkpoint_name):
# we don't set a default system prompt, but it is supported:
# https://huggingface.co/blog/codellama#conversational-instructions
b_inst, e_inst = "<s>[INST]", "[/INST]"
system_prompt = f"{b_inst} {{prompt}} {e_inst}"
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

# default format
return "{prompt}", ([tokenizer.eos_id],)

196 changes: 174 additions & 22 deletions lit_gpt/config.py
@@ -49,6 +49,7 @@ class Config:
_mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
intermediate_size: Optional[int] = None
condense_ratio: int = 1
rope_base: int = 10000

def __post_init__(self):
# error checking
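
The new `rope_base` field sets the base of the rotary-embedding frequency schedule; 10000 matches the original RoPE formulation, and the Code Llama presets added below override it to 1000000. A hedged illustration of what the value controls (`rope_inv_freq` is a hypothetical helper, not part of `lit_gpt`):

```python
import torch

def rope_inv_freq(n_elem: int, base: int = 10000) -> torch.Tensor:
    # theta_i = base^(-2i / n_elem): a larger base yields slower-rotating,
    # longer-wavelength position channels, which is how Code Llama keeps
    # positions distinguishable out to its 16384-token block size.
    return 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))

print(rope_inv_freq(128)[-1])                  # default base, 10000
print(rope_inv_freq(128, base=1_000_000)[-1])  # Code Llama's long-context base
```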
@@ -154,7 +155,6 @@ def norm_class(self) -> Type:
block_size=2048,
n_layer=32,
n_embd=2560,
n_head=32,
padding_multiple=256,
rotary_percentage=1.0,
parallel_residual=False,
@@ -165,7 +165,6 @@ def norm_class(self) -> Type:
name="RedPajama-INCITE-7B-{}",
block_size=2048,
n_layer=32,
n_head=32,
padding_multiple=256,
rotary_percentage=1.0,
parallel_residual=False,
@@ -176,7 +175,6 @@ def norm_class(self) -> Type:
name="RedPajama-INCITE-{}-7B-v0.1",
block_size=2048,
n_layer=32,
n_head=32,
padding_multiple=256,
rotary_percentage=1.0,
parallel_residual=False,
@@ -241,7 +239,6 @@ def norm_class(self) -> Type:
vocab_size=32000,
padding_multiple=64,
n_layer=26,
n_head=32,
n_embd=3200,
rotary_percentage=1.0,
parallel_residual=False,
@@ -259,7 +256,6 @@ def norm_class(self) -> Type:
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
@@ -302,7 +298,6 @@ def norm_class(self) -> Type:
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
@@ -353,7 +348,6 @@ def norm_class(self) -> Type:
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
@@ -368,7 +362,6 @@ def norm_class(self) -> Type:
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
@@ -405,7 +398,6 @@ def norm_class(self) -> Type:
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
condense_ratio=4,
@@ -426,7 +418,6 @@ def norm_class(self) -> Type:
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
@@ -467,11 +458,8 @@ def norm_class(self) -> Type:
dict(
org="NousResearch",
name="Nous-Hermes-llama-2-7b",
block_size=4096,
padded_vocab_size=32000,
n_layer=32,
n_head=32,
n_embd=4096,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
@@ -501,7 +489,6 @@ def norm_class(self) -> Type:
dict(
org="NousResearch",
name="Nous-Hermes-Llama2-13b",
block_size=4096,
padded_vocab_size=32032,
n_layer=40,
n_head=40,
@@ -529,7 +516,6 @@ def norm_class(self) -> Type:
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
@@ -603,6 +589,179 @@ def norm_class(self) -> Type:
configs.extend(freewilly_2)


code_llama = [
# https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json
dict(
org="codellama",
name="CodeLlama-7b-hf",
block_size=16384,
vocab_size=32016,
padding_multiple=16,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json
dict(
org="codellama",
name="CodeLlama-13b-hf",
block_size=16384,
vocab_size=32016,
padding_multiple=16,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json
dict(
org="codellama",
name="CodeLlama-34b-hf",
block_size=16384,
vocab_size=32000,
padding_multiple=64,
n_layer=48,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMLP",
intermediate_size=22016,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json
dict(
org="codellama",
name="CodeLlama-7b-Python-hf",
block_size=16384,
vocab_size=32000,
padding_multiple=64,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json
dict(
org="codellama",
name="CodeLlama-13b-Python-hf",
block_size=16384,
vocab_size=32000,
padding_multiple=64,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json
dict(
org="codellama",
name="CodeLlama-34b-Python-hf",
block_size=16384,
vocab_size=32000,
padding_multiple=64,
n_layer=48,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMLP",
intermediate_size=22016,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json
dict(
org="codellama",
name="CodeLlama-7b-Instruct-hf",
block_size=16384,
vocab_size=32016,
padding_multiple=16,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json
dict(
org="codellama",
name="CodeLlama-13b-Instruct-hf",
block_size=2048,
vocab_size=32016,
padding_multiple=16,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json
dict(
org="codellama",
name="CodeLlama-34b-Instruct-hf",
block_size=16384,
vocab_size=32000,
padding_multiple=64,
n_layer=48,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMLP",
intermediate_size=22016,
rope_base=1000000,
),
]
configs.extend(code_llama)
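
As a quick sanity check, the new presets should resolve by name like every other entry in this file; this assumes `Config.from_name` matches on the `name` field, as it does for the existing configs:

```python
from lit_gpt.config import Config

config = Config.from_name("CodeLlama-7b-hf")
# Both long-context settings come straight from the dict above.
print(config.block_size, config.rope_base)  # 16384 1000000
```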


########################
# garage-bAInd Platypus
########################
@@ -630,7 +789,6 @@ def norm_class(self) -> Type:
name="Platypus2-7B",
padded_vocab_size=32000,
n_layer=32,
n_embd=4096,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
@@ -744,31 +902,25 @@ def norm_class(self) -> Type:
dict(
org="stabilityai",
name="stablecode-completion-alpha-3b",
block_size=4096,
vocab_size=49152,
n_layer=32,
n_head=32,
n_embd=2560,
condense_ratio=4,
),
# https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
dict(
org="stabilityai",
name="stablecode-completion-alpha-3b-4k",
block_size=4096,
vocab_size=49152,
n_layer=32,
n_head=32,
n_embd=2560,
),
# https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
dict(
org="stabilityai",
name="stablecode-instruct-alpha-3b",
block_size=4096,
vocab_size=49152,
n_layer=32,
n_head=32,
n_embd=2560,
),
]
1 change: 1 addition & 0 deletions lit_gpt/model.py
@@ -114,6 +114,7 @@ def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
dtype=torch.get_default_dtype(),
device=idx.device,
condense_ratio=self.config.condense_ratio,
base=self.config.rope_base,
)

def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
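
To make the new plumbing concrete, here is a self-contained sketch of how `base` and `condense_ratio` enter the cached cos/sin tables and how such a cache is consumed. Shapes are assumed and the half-split rotation is a simplification, not the exact `lit_gpt` implementation:

```python
import torch

def rope_cache_sketch(seq_len: int, n_elem: int, base: int = 10000, condense_ratio: int = 1):
    # A larger `base` stretches the wavelengths (Code Llama: 1_000_000);
    # a larger `condense_ratio` compresses positions (the 16k StableCode
    # preset in this file uses 4).
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    seq_idx = torch.arange(seq_len).float() / condense_ratio
    idx_theta = torch.outer(seq_idx, theta)
    return torch.cos(idx_theta), torch.sin(idx_theta)

def apply_rope_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Half-split rotation: channel i is paired with channel i + d/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

q = torch.randn(1, 8, 16, 64)  # (batch, n_head, seq_len, head_dim)
cos, sin = rope_cache_sketch(seq_len=16, n_elem=64, base=1_000_000)
q_rot = apply_rope_sketch(q, cos, sin)  # cos/sin broadcast over batch and heads
print(q_rot.shape)  # torch.Size([1, 8, 16, 64])
```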