Apply proposed changes by @RDouglasSharp from issue Lightning-AI#57 Lightning-AI#58 (Lightning-AI#60)

Co-authored-by: Luca Antiga <[email protected]>
ArturK-85 and lantiga authored May 16, 2023
1 parent e09ad1d commit ed8c9f5
Showing 4 changed files with 67 additions and 18 deletions.
59 changes: 43 additions & 16 deletions lit_parrot/adapter.py
@@ -8,15 +8,15 @@

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Any
from typing import Optional, Tuple, Any, List

import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self

from lit_parrot.config import Config as BaseConfig
from lit_parrot.model import MLP, Parrot as BaseModel, build_rope_cache, apply_rope
from lit_parrot.model import MLP, Parrot as BaseModel, build_rope_cache, apply_rope, RoPECache, KVCache


@dataclass
@@ -48,12 +48,19 @@ def __init__(self, config: Config, block_idx: int) -> None:
self.n_embd = config.n_embd
self.block_size = config.block_size
self.rotary_percentage = config.rotary_percentage
self.rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
self.block_idx = block_idx
self.adapter_prompt_length = config.adapter_prompt_length
self.adapter_start_layer = config.adapter_start_layer

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
rope: RoPECache,
mask: torch.Tensor,
max_seq_length: int,
input_pos: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple[torch.Tensor, Optional[KVCache]]:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

qkv = self.attn(x)
@@ -62,20 +69,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
q, k, v = qkv.split(head_size, dim=-1) # (B, nh, T, hs)

n_elem = int(self.rotary_percentage * head_size)
if self.rope_cache is None:
self.rope_cache = build_rope_cache(self.block_size, n_elem, x.dtype, x.device)
cos, sin = self.rope_cache
cos, sin = cos[:T], sin[:T]

cos, sin = rope
q_roped = apply_rope(q[..., :n_elem], cos, sin)
k_roped = apply_rope(k[..., :n_elem], cos, sin)
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)

if kv_cache is not None:
cache_k, cache_v = kv_cache
# check if reached token limit
if input_pos[-1] >= max_seq_length:
input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
# shift 1 position to the left
cache_k = torch.roll(cache_k, -1, dims=2)
cache_v = torch.roll(cache_v, -1, dims=2)
k = cache_k.index_copy(2, input_pos, k)
v = cache_v.index_copy(2, input_pos, v)
kv_cache = k, v

# efficient attention using Flash Attention CUDA kernels
y = F.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=1.0 / math.sqrt(head_size)
)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, scale=1.0 / math.sqrt(head_size))

if self.block_idx >= self.adapter_start_layer:
prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
@@ -94,7 +108,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# output projection
y = self.proj(y)

return y
return y, kv_cache
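For readers following the cache logic above, here is a minimal, self-contained sketch of the same roll-and-write update, assuming a KVCache made of two tensors of shape (B, n_head, max_seq_length, head_size). The function name and shapes are illustrative, not part of the repository.

    import torch
    from typing import Tuple

    def update_kv_cache(
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        k: torch.Tensor,
        v: torch.Tensor,
        input_pos: torch.Tensor,
        max_seq_length: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # cache_k, cache_v: (B, n_head, max_seq_length, head_size); k, v: (B, n_head, T, head_size)
        cache_k, cache_v = kv_cache
        if input_pos[-1] >= max_seq_length:
            # cache is full: drop the oldest position and write the new token into the last slot
            input_pos = torch.tensor([max_seq_length - 1], device=input_pos.device)
            cache_k = torch.roll(cache_k, -1, dims=2)
            cache_v = torch.roll(cache_v, -1, dims=2)
        # write k and v at the given positions and return the updated full-length tensors
        k = cache_k.index_copy(2, input_pos, k)
        v = cache_v.index_copy(2, input_pos, v)
        return k, v

The 1-D input_pos in the full-cache branch assumes single-token decoding once the cache limit is reached, which is the usual generation loop.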


class Block(nn.Module):
@@ -110,13 +124,22 @@ def __init__(self, config: Config, block_idx: int) -> None:

self.parallel_residual = config.parallel_residual

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
rope: RoPECache,
mask: torch.Tensor,
max_seq_length: int,
input_pos: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple[torch.Tensor, Optional[KVCache]]:
h, new_kv_cache = self.attn(self.norm_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
if self.parallel_residual:
x = x + self.attn(self.norm_1(x)) + self.mlp(self.norm_2(x))
x = x + h + self.mlp(self.norm_2(x))
else:
x = x + self.attn(self.norm_1(x))
x = x + h
x = x + self.mlp(self.norm_2(x))
return x
return x, new_kv_cache
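To see how the new (output, kv_cache) tuple is meant to be consumed, here is a rough sketch of a caller threading a per-block cache list through a stack of blocks. The surrounding model code is not part of this diff, so the function and variable names below are illustrative only.

    def forward_blocks(blocks, x, rope, mask, max_seq_length, input_pos=None, kv_caches=None):
        # blocks: an iterable of Block modules; kv_caches: one Optional[KVCache] per block
        if kv_caches is None:
            kv_caches = [None] * len(blocks)
        new_caches = []
        for block, kv_cache in zip(blocks, kv_caches):
            x, kv_cache = block(x, rope, mask, max_seq_length, input_pos, kv_cache)
            new_caches.append(kv_cache)
        return x, new_caches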


class Parrot(BaseModel):
@@ -137,6 +160,10 @@ def __init__(self, config: Config) -> None:
)
)

self.rope_cache: Optional[RoPECache] = None
self.mask_cache: Optional[torch.Tensor] = None
self.kv_caches: List[KVCache] = []
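RoPECache and KVCache used throughout this diff come from lit_parrot/model.py, which is not shown here; they are presumably plain tuple aliases along these lines (an assumption, check that module for the exact definitions):

    from typing import Tuple
    import torch

    RoPECache = Tuple[torch.Tensor, torch.Tensor]  # (cos, sin) rotary-embedding tables
    KVCache = Tuple[torch.Tensor, torch.Tensor]    # (k, v), e.g. (B, n_head, block_size, head_size) each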

@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))
4 changes: 3 additions & 1 deletion lit_parrot/config.py
@@ -24,7 +24,9 @@ def __post_init__(self):

@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(**configs[name], **kwargs)
conf_dict = configs[name].copy()
conf_dict.update(kwargs)
return cls(**conf_dict)
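The copy()/update() form lets explicit keyword arguments override the preset values. The previous cls(**configs[name], **kwargs) raises a TypeError ("got multiple values for keyword argument ...") whenever the same key appears in both mappings, which is exactly the from_name("pythia-70m", block_size=4096) case exercised by the new test. A standalone illustration with plain dicts, not the lit_parrot classes:

    preset = {"block_size": 2048, "n_embd": 512}   # stands in for configs[name]
    overrides = {"block_size": 4096}               # stands in for **kwargs

    # old style: duplicate keyword arguments blow up
    # dict(**preset, **overrides)  # TypeError: got multiple values for keyword argument 'block_size'

    # new style: the override wins
    conf_dict = preset.copy()
    conf_dict.update(overrides)
    assert conf_dict == {"block_size": 4096, "n_embd": 512}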


# fmt: off
2 changes: 1 addition & 1 deletion lit_parrot/model.py
@@ -191,7 +191,7 @@ def forward(
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)

if kv_cache is not None:
if input_pos is not None and kv_cache is not None:
cache_k, cache_v = kv_cache
# check if reached token limit
if input_pos[-1] >= max_seq_length:
20 changes: 20 additions & 0 deletions tests/test_config.py
@@ -0,0 +1,20 @@
from pathlib import Path


wd = Path(__file__).parent.parent.absolute()


def test_config():
from lit_parrot import Config

config = Config()
assert config.block_size == 4096

config = Config(block_size=2048)
assert config.block_size == 2048

config = Config.from_name("pythia-70m")
assert config.block_size == 2048

config = Config.from_name("pythia-70m", block_size=4096)
assert config.block_size == 4096
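The new test can also be run on its own, for example programmatically (assuming pytest is installed):

    import pytest

    # equivalent to running `pytest tests/test_config.py -q` from the repository root
    pytest.main(["tests/test_config.py", "-q"])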
