Add Phi 1.5 support (Lightning-AI#569)
Co-authored-by: rasbt <[email protected]>
carmocca and rasbt authored Sep 20, 2023
1 parent 966646a commit 2cee32f
Showing 14 changed files with 273 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -14,3 +14,4 @@ out
wandb

tests/original_falcon_40b.py
tests/original_phi_1_5.py
1 change: 1 addition & 0 deletions README.md
@@ -40,6 +40,7 @@ Supports the following popular model checkpoints:
| Platypus | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| NousResearch Nous-Hermes | [Org page](https://huggingface.co/NousResearch) |
| Meta AI [Code Llama](tutorials/download_code_llama.md) | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Microsoft Research [phi-1.5](tutorials/download_phi15.md) | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |

This implementation extends on [Lit-LLaMA](https://github.com/lightning-AI/lit-llama) and [nanoGPT](https://github.com/karpathy/nanoGPT), and it's **powered by [Lightning Fabric](https://lightning.ai/docs/fabric/stable/)**.

5 changes: 5 additions & 0 deletions chat/base.py
@@ -302,6 +302,11 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

if re.search("phi", checkpoint_name):
system_prompt = "{prompt}\n\nAnswer:"
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

# default format
return "{prompt}", ([tokenizer.eos_id],)

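To illustrate the new phi branch, here is a minimal sketch of how the returned template would be filled before tokenization; the example question is made up, and generation then stops at the tokenizer's eos_id:

# Sketch only: applying the phi-1.5 prompt template from prompt_config above.
system_prompt = "{prompt}\n\nAnswer:"
filled = system_prompt.format(prompt="What is the capital of France?")
print(filled)
# What is the capital of France?
#
# Answer:
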
2 changes: 1 addition & 1 deletion lit_gpt/adapter.py
@@ -33,7 +33,7 @@ def __init__(self, config: Config) -> None:
assert config.padded_vocab_size is not None
self.config = config

self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
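The same one-line change appears in lit_gpt/adapter_v2.py, lit_gpt/lora.py, and lit_gpt/model.py below: the output head's bias now follows the config instead of being hard-coded to False, since the phi-1.5 checkpoint ships an lm_head bias. A rough sketch of the pattern, using a simplified stand-in for the real Config:

import torch.nn as nn
from dataclasses import dataclass

@dataclass
class MiniConfig:  # simplified stand-in for lit_gpt.Config, not the real class
    n_embd: int = 2048
    padded_vocab_size: int = 51200
    lm_head_bias: bool = True  # phi-1.5 sets this; the default stays False

config = MiniConfig()
lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
assert lm_head.bias is not None  # the bias tensor exists only when the flag is set
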
4 changes: 3 additions & 1 deletion lit_gpt/adapter_v2.py
@@ -66,7 +66,7 @@ def __init__(self, config: Config) -> None:
assert config.padded_vocab_size is not None
self.config = config

self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=False)
self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
@@ -156,6 +156,8 @@ def __init__(self, config: Config) -> None:
self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
24 changes: 24 additions & 0 deletions lit_gpt/config.py
@@ -24,6 +24,7 @@ class Config:
rotary_percentage: float = 0.25
parallel_residual: bool = True
bias: bool = True
lm_head_bias: bool = False
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
@@ -49,6 +50,7 @@ class Config:
_norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
_mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
gelu_approximate: str = "none"
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
rope_base: int = 10000
@@ -984,4 +986,26 @@ def norm_class(self) -> Type:
]
configs.extend(together_llama2_32k)


################
# Microsoft Phi
################
phi = [
# https://huggingface.co/microsoft/phi-1_5/blob/main/config.json
dict(
org="microsoft",
name="phi-1_5",
vocab_size=50257,
padded_vocab_size=51200,
block_size=2048,
n_embd=2048,
n_layer=24,
rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64
shared_attention_norm=True,
lm_head_bias=True,
gelu_approximate="tanh",
)
]
configs.extend(phi)

name_to_config = {config["name"]: config for config in configs}
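As a quick sanity sketch (assuming a lit-gpt install that includes this commit), the new entry resolves through name_to_config and Config.from_name like any other checkpoint, and the rotary_percentage comment works out to 32 rotary dimensions per 64-dimensional head:

from lit_gpt.config import Config

cfg = Config.from_name("phi-1_5")
head_size = cfg.n_embd // cfg.n_head                 # 2048 / 32 = 64
assert int(cfg.rotary_percentage * head_size) == 32  # rotary dims per head
assert cfg.padded_vocab_size == 51200                # padded up from vocab_size=50257
assert cfg.lm_head_bias and cfg.gelu_approximate == "tanh"
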
4 changes: 3 additions & 1 deletion lit_gpt/lora.py
@@ -460,7 +460,7 @@ def __init__(self, config: Config) -> None:
self.lm_head = LoRALinear(
config.n_embd,
config.padded_vocab_size,
bias=False,
bias=config.lm_head_bias,
r=(config.r if config.to_head else 0),
lora_alpha=config.alpha,
lora_dropout=config.dropout,
@@ -596,6 +596,8 @@ def __init__(self, config: Config) -> None:
lora_dropout=config.dropout,
)

self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
6 changes: 4 additions & 2 deletions lit_gpt/model.py
@@ -22,7 +22,7 @@ def __init__(self, config: Config) -> None:
assert config.padded_vocab_size is not None
self.config = config

self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
@@ -283,9 +283,11 @@ def __init__(self, config: Config) -> None:
self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

self.config = config

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc(x)
x = torch.nn.functional.gelu(x)
x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
return self.proj(x)


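The approximate argument is forwarded straight to PyTorch's GELU, which is presumably also why this commit stores self.config on the MLP modules in adapter_v2.py and lora.py, whose forward reuses this implementation. A small comparison of the exact and tanh variants, the latter being what the phi-1.5 config selects:

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=7)
exact = F.gelu(x, approximate="none")  # default erf-based GELU
tanh = F.gelu(x, approximate="tanh")   # the variant selected by the phi-1.5 config
print((exact - tanh).abs().max())      # small but nonzero difference
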
64 changes: 59 additions & 5 deletions scripts/convert_hf_checkpoint.py
@@ -161,16 +161,68 @@ def copy_weights_hf_llama(
q = load_param(q, f"layer {i} q", dtype)
k = load_param(k, f"layer {i} k", dtype)
v = load_param(v, f"layer {i} v", dtype)
qkv = torch.cat((q, k, v))
q_per_kv = config.n_head // config.n_query_groups
qs = torch.split(q, config.head_size * q_per_kv)
ks = torch.split(k, config.head_size)
vs = torch.split(v, config.head_size)
cycled = [t for group in zip(qs, ks, vs) for t in group]
qkv = torch.cat(cycled)
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
qkv = qkv.view(total_qkv, config.n_query_groups, -1).transpose(0, 1)
qkv = qkv.reshape(config.n_embd * 3, -1)
state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
del qkv_weights[i]
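
For intuition, a toy 1-D version of the interleave above (the real code splits rows of the q, k, and v weight matrices; the sizes here are made up): with n_head=4, n_query_groups=2, and head_size=2, each query group's queries are followed by its key and value slice.

import torch

n_head, n_query_groups, head_size = 4, 2, 2  # hypothetical toy sizes
q = torch.arange(0, n_head * head_size)                   # 8 "query rows"
k = torch.arange(100, 100 + n_query_groups * head_size)   # 4 "key rows"
v = torch.arange(200, 200 + n_query_groups * head_size)   # 4 "value rows"

q_per_kv = n_head // n_query_groups          # 2 queries share each key/value
qs = torch.split(q, head_size * q_per_kv)
ks = torch.split(k, head_size)
vs = torch.split(v, head_size)
qkv = torch.cat([t for group in zip(qs, ks, vs) for t in group])
print(qkv)  # group 0: queries, k0, v0, then group 1: queries, k1, v1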


def copy_weights_phi(
config: Config,
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
weight_map = {
"layers.0.wte.weight": "transformer.wte.weight",
"layers.{}.ln.bias": "transformer.h.{}.norm_1.bias",
"layers.{}.ln.weight": "transformer.h.{}.norm_1.weight",
"layers.{}.mixer.Wqkv.bias": "transformer.h.{}.attn.attn.bias",
"layers.{}.mixer.Wqkv.weight": "transformer.h.{}.attn.attn.weight",
"layers.{}.mixer.out_proj.bias": "transformer.h.{}.attn.proj.bias",
"layers.{}.mixer.out_proj.weight": "transformer.h.{}.attn.proj.weight",
"layers.{}.mixer.rotary_emb.inv_freq": None,
"layers.{}.mlp.fc1.bias": "transformer.h.{}.mlp.fc.bias",
"layers.{}.mlp.fc1.weight": "transformer.h.{}.mlp.fc.weight",
"layers.{}.mlp.fc2.bias": "transformer.h.{}.mlp.proj.bias",
"layers.{}.mlp.fc2.weight": "transformer.h.{}.mlp.proj.weight",
f"layers.{config.n_layer + 1}.ln.bias": "transformer.ln_f.bias",
f"layers.{config.n_layer + 1}.ln.weight": "transformer.ln_f.weight",
f"layers.{config.n_layer + 1}.linear.weight": "lm_head.weight",
f"layers.{config.n_layer + 1}.linear.bias": "lm_head.bias",
}

for name, param in hf_weights.items():
if "layers" in name:
from_name, number = layer_template(name, 1)
if number in (0, config.n_layer + 1):
# these are part of the layers in phi, but not in our implementation
to_name = weight_map[name]
else:
to_name = weight_map[from_name]
if to_name is None:
continue
# the phi layer numbering is off by 1 compared to ours
to_name = to_name.format(number - 1)
else:
to_name = weight_map[name]
param = load_param(param, name, dtype)
if "Wqkv" in name:
q_per_kv = config.n_head // config.n_query_groups
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
param = param.view(total_qkv, config.n_query_groups, -1).transpose(0, 1)
param = param.reshape(config.n_embd * 3, -1)
if "bias" in name:
param = param.squeeze()
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
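
A name-only illustration of the off-by-one handling above, for a hypothetical 24-layer phi checkpoint: HF phi stores the embedding as layers.0 and the final norm plus head as layers.{n_layer + 1}, so checkpoint blocks layers.1 through layers.24 map to transformer.h.0 through transformer.h.23.

n_layer = 24  # phi-1.5
examples = {
    "layers.0.wte.weight": "transformer.wte.weight",                   # embedding wrapper
    "layers.1.mixer.Wqkv.weight": "transformer.h.0.attn.attn.weight",  # first block
    f"layers.{n_layer}.mlp.fc2.bias": f"transformer.h.{n_layer - 1}.mlp.proj.bias",  # last block
    f"layers.{n_layer + 1}.linear.bias": "lm_head.bias",               # output head
}
for hf_name, lit_name in examples.items():
    print(f"{hf_name:32s} -> {lit_name}")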


def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
split = layer_name.split(".")
number = int(split[idx])
@@ -214,6 +266,8 @@ def convert_hf_checkpoint(
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
elif "phi" in model_name:
copy_fn = partial(copy_weights_phi, config)
else:
copy_fn = copy_weights_gpt_neox

57 changes: 45 additions & 12 deletions scripts/convert_lit_checkpoint.py
@@ -124,31 +124,64 @@ def copy_weights_llama(
k = "model.layers.{}.self_attn.k_proj.weight".format(number)
v = "model.layers.{}.self_attn.v_proj.weight".format(number)
qkv = load_param(param, name, None)
qp, kp, vp = tensor_split(qkv, config)
qp, kp, vp = qkv_split(qkv, config)
for to_name, param in zip((q, k, v), (qp, kp, vp)):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
elif "transformer.h" in name:
from_name, number = layer_template(name, 2)
to_name = weight_map[from_name]
if to_name is None:
continue
to_name = to_name.format(number)
else:
if "transformer.h" in name:
from_name, number = layer_template(name, 2)
to_name = weight_map[from_name]
to_name = to_name.format(number)
else:
to_name = weight_map[name]
param = load_param(param, name, None)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param


def copy_weights_phi(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "layers.0.wte.weight",
"transformer.h.{}.norm_1.bias": "layers.{}.ln.bias",
"transformer.h.{}.norm_1.weight": "layers.{}.ln.weight",
"transformer.h.{}.attn.attn.bias": "layers.{}.mixer.Wqkv.bias",
"transformer.h.{}.attn.attn.weight": "layers.{}.mixer.Wqkv.weight",
"transformer.h.{}.attn.proj.bias": "layers.{}.mixer.out_proj.bias",
"transformer.h.{}.attn.proj.weight": "layers.{}.mixer.out_proj.weight",
"transformer.h.{}.mlp.fc.bias": "layers.{}.mlp.fc1.bias",
"transformer.h.{}.mlp.fc.weight": "layers.{}.mlp.fc1.weight",
"transformer.h.{}.mlp.proj.bias": "layers.{}.mlp.fc2.bias",
"transformer.h.{}.mlp.proj.weight": "layers.{}.mlp.fc2.weight",
"transformer.ln_f.bias": f"layers.{config.n_layer + 1}.ln.bias",
"transformer.ln_f.weight": f"layers.{config.n_layer + 1}.ln.weight",
"lm_head.weight": f"layers.{config.n_layer + 1}.linear.weight",
"lm_head.bias": f"layers.{config.n_layer + 1}.linear.bias",
}

for name, param in lit_weights.items():
if "transformer.h" in name:
from_name, number = layer_template(name, 2)
to_name = weight_map[from_name]
to_name = to_name.format(number + 1)
else:
to_name = weight_map[name]
param = load_param(param, name, None)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
param = load_param(param, name, None)
if "attn.attn." in name:
param = torch.cat(qkv_split(param, config))
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param


def tensor_split(
def qkv_split(
param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q_per_kv = config.n_head // config.n_query_groups
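The full qkv_split body is truncated above, but its role here is to undo the interleaved attn.attn layout produced in scripts/convert_hf_checkpoint.py, so that torch.cat(qkv_split(param, config)) yields the contiguous q|k|v order of phi's Wqkv weights. A toy 1-D round trip under the same made-up sizes as earlier (this only sketches the expected behaviour, not the actual implementation):

import torch

n_head, n_query_groups, head_size = 4, 2, 2  # hypothetical toy sizes
q = torch.arange(0, 8)
k = torch.arange(100, 104)
v = torch.arange(200, 204)

q_per_kv = n_head // n_query_groups
qs = torch.split(q, head_size * q_per_kv)
ks = torch.split(k, head_size)
vs = torch.split(v, head_size)
interleaved = torch.cat([t for group in zip(qs, ks, vs) for t in group])  # lit-gpt layout

# De-interleave per query group, mirroring what qkv_split is expected to do:
sizes = [head_size * q_per_kv, head_size, head_size] * n_query_groups
chunks = torch.split(interleaved, sizes)
q2, k2, v2 = (torch.cat(chunks[i::3]) for i in range(3))
assert torch.equal(torch.cat((q2, k2, v2)), torch.cat((q, k, v)))  # back to HF Wqkv order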