diff --git a/build/known_model_params/Meta-Llama-3.1-70B.json b/build/known_model_params/Meta-Llama-3.1-70B.json
new file mode 100644
index 000000000..d3e9a73fa
--- /dev/null
+++ b/build/known_model_params/Meta-Llama-3.1-70B.json
@@ -0,0 +1 @@
+{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true}
diff --git a/build/known_model_params/Meta-Llama-3.1-8B.json b/build/known_model_params/Meta-Llama-3.1-8B.json
new file mode 100644
index 000000000..0d3808205
--- /dev/null
+++ b/build/known_model_params/Meta-Llama-3.1-8B.json
@@ -0,0 +1 @@
+{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true}
diff --git a/build/model.py b/build/model.py
index 0405e3683..b2a883ca5 100644
--- a/build/model.py
+++ b/build/model.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import json
 import os
+
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, Optional
@@ -38,6 +39,7 @@ class ModelArgs:
     ffn_dim_multiplier: Optional[int] = None
     use_tiktoken: bool = False
     max_seq_length: int = 8192
+    use_scaled_rope: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -178,6 +180,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             self.config.dim // self.config.n_heads,
             self.config.block_size * 2,
             self.config.rope_base,
+            use_scaled=self.config.use_scaled_rope,
         )
         self.register_buffer("freqs_cis", freqs_cis, persistent=True)
         causal_mask = torch.tril(
@@ -361,8 +364,32 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight
 
 
+def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
+    # Values obtained from grid search
+    scale_factor = 8
+    low_freq_factor = 1
+    high_freq_factor = 4
+    old_context_len = 8192  # original llama3 context length
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * torch.pi / freq
+        if wavelen < high_freq_wavelen:  # high-frequency band: keep as-is
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:  # low-frequency band: compress by scale_factor
+            new_freqs.append(freq / scale_factor)
+        else:  # mid band: interpolate between the two
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 def precompute_freqs_cis(
-    n_elem: int, seq_len: int, base: int = 10000, dtype=None
+    n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False
 ) -> Tensor:
     if not dtype:
         dtype = get_precision()
@@ -370,6 +397,8 @@
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
     t = torch.arange(seq_len, device=freqs.device)
+    if use_scaled:
+        freqs = apply_scaling(freqs)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
diff --git a/config/data/models.json b/config/data/models.json
index 483f6c35f..ab0abb7d6 100644
--- a/config/data/models.json
+++ b/config/data/models.json
@@ -40,6 +40,23 @@
         "distribution_path": "meta-llama/Meta-Llama-3-70B-Instruct",
         "transformer_params_key": "Meta-Llama-3-70B"
     },
+    "meta-llama/Meta-Llama-3.1-8B": {
+        "aliases": ["llama3.1-base"],
+        "distribution_channel": "HuggingFaceSnapshot",
+        "distribution_path": "meta-llama/Meta-Llama-3.1-8B"
+    },
+    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
+        "aliases": ["llama3.1", "llama3.1-chat", "llama3.1-instruct"],
+        "distribution_channel": "HuggingFaceSnapshot",
+        "distribution_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "transformer_params_key": "Meta-Llama-3.1-8B"
+    },
+    "meta-llama/Meta-Llama-3.1-70B-Instruct": {
+        "aliases": ["llama3.1-70b"],
+        "distribution_channel": "HuggingFaceSnapshot",
+        "distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+        "transformer_params_key": "Meta-Llama-3.1-70B"
+    },
     "meta-llama/CodeLlama-7b-Python-hf": {
         "aliases": ["codellama", "codellama-7b"],
         "distribution_channel": "HuggingFaceSnapshot",