Skip to content

Commit

Permalink
Support llama3.1-8b generation (pytorch#947)
Browse files Browse the repository at this point in the history
* add llama 3.1 8b support

* replace math.pi with torch.pi

* add 3.1 8b base and 70b
  • Loading branch information
Gasoonjia authored Jul 24, 2024
1 parent 63cb0a0 commit 3e28e5d
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 1 deletion.
1 change: 1 addition & 0 deletions build/known_model_params/Meta-Llama-3.1-70B.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true}
1 change: 1 addition & 0 deletions build/known_model_params/Meta-Llama-3.1-8B.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true}
31 changes: 30 additions & 1 deletion build/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import json
import os

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional
Expand Down Expand Up @@ -38,6 +39,7 @@ class ModelArgs:
ffn_dim_multiplier: Optional[int] = None
use_tiktoken: bool = False
max_seq_length: int = 8192
use_scaled_rope: bool = False

def __post_init__(self):
if self.n_local_heads == -1:
Expand Down Expand Up @@ -178,6 +180,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
self.config.dim // self.config.n_heads,
self.config.block_size * 2,
self.config.rope_base,
use_scaled = self.config.use_scaled_rope,
)
self.register_buffer("freqs_cis", freqs_cis, persistent=True)
causal_mask = torch.tril(
Expand Down Expand Up @@ -361,15 +364,41 @@ def forward(self, x: Tensor) -> Tensor:
return output * self.weight


def apply_scaling(freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * torch.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

def precompute_freqs_cis(
n_elem: int, seq_len: int, base: int = 10000, dtype=None
n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False
) -> Tensor:
if not dtype:
dtype = get_precision()
freqs = 1.0 / (
base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
)
t = torch.arange(seq_len, device=freqs.device)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
Expand Down
17 changes: 17 additions & 0 deletions config/data/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@
"distribution_path": "meta-llama/Meta-Llama-3-70B-Instruct",
"transformer_params_key": "Meta-Llama-3-70B"
},
"meta-llama/Meta-Llama-3.1-8B": {
"aliases": ["llama3.1-base"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Meta-Llama-3.1-8B"
},
"meta-llama/Meta-Llama-3.1-8B-Instruct": {
"aliases": ["llama3.1", "llama3.1-chat", "llama3.1-instruct"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"transformer_params_key": "Meta-Llama-3.1-8B"
},
"meta-llama/Meta-Llama-3.1-70B-Instruct": {
"aliases": ["llama3.1-70b"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"transformer_params_key": "Meta-Llama-3.1-70B"
},
"meta-llama/CodeLlama-7b-Python-hf": {
"aliases": ["codellama", "codellama-7b"],
"distribution_channel": "HuggingFaceSnapshot",
Expand Down

0 comments on commit 3e28e5d

Please sign in to comment.