Skip to content

Commit

Permalink
Some improvements for KV caching (Lightning-AI#1891)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrei-Aksionov <[email protected]>
  • Loading branch information
mseeger and Andrei-Aksionov authored Dec 31, 2024
1 parent 470f14e commit 17a58df
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 363 deletions.
78 changes: 11 additions & 67 deletions litgpt/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
Expand All @@ -28,56 +28,27 @@ class Config(BaseConfig):


class GPT(BaseModel):
"""The implementation is identical to `litgpt.model.GPT` with the exception that
the `Block` saves the layer index and passes it down to the attention layer."""

# Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config

self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.lm_head = nn.Linear(
config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
h=nn.ModuleList(
Block(config, block_idx)
for block_idx in range(config.n_layer)
),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.max_seq_length = self.config.block_size
self.mask_cache: Optional[torch.Tensor] = None

def forward(
self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
) -> Union[torch.Tensor, List[torch.Tensor]]:
T = idx.size(1)
if self.max_seq_length < T:
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

if input_pos is not None: # use the kv cache
cos = self.cos.index_select(0, input_pos)
sin = self.sin.index_select(0, input_pos)
if self.mask_cache is None:
raise TypeError("You need to call `gpt.set_kv_cache()`")
mask = self.mask_cache.index_select(2, input_pos)
else:
cos = self.cos[:T]
sin = self.sin[:T]
mask = None

x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
if self.config.scale_embeddings:
x = x * (self.config.n_embd**0.5)
for block in self.transformer.h:
x = block(x, cos, sin, mask, input_pos)
x = self.transformer.ln_f(x)
if lm_head_chunk_size > 0:
# chunk the lm head logits to reduce the peak memory used by autograd
return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
x = self.lm_head(x) # (b, t, vocab_size)
if self.config.final_logit_softcapping is not None:
x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping
return x
self.max_seq_length = self.config.block_size

@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
Expand All @@ -91,30 +62,9 @@ def _init_weights(self, module: nn.Module) -> None:


class Block(BaseBlock):
"""The implementation is identical to `litgpt.model.Block` with the exception that
we replace the attention layer where adaption is implemented."""

def __init__(self, config: Config, block_idx: int) -> None:
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
nn.Module.__init__(self)
if not config.parallel_residual and config.shared_attention_norm:
raise NotImplementedError(
"No checkpoint amongst the ones we support uses this configuration:"

" non-parallel residual and shared attention norm."
)
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
super().__init__(config, block_idx)
self.attn = CausalSelfAttention(config, block_idx)
self.post_attention_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
)
self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
self.mlp = config.mlp_class(config)
self.post_mlp_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
)

self.config = config


class CausalSelfAttention(BaseCausalSelfAttention):
Expand All @@ -130,12 +80,6 @@ def __init__(self, config: Config, block_idx: int) -> None:
self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
# kv cache for inference
self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
self.block_idx = block_idx
self.apply_sliding_window_attention = (
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_stride == 0
)
self.config = config

def scaled_dot_product_attention(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
Expand Down
126 changes: 37 additions & 89 deletions litgpt/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,17 @@
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Type, Optional

import torch
import torch.nn as nn
from typing_extensions import Self

import litgpt
from litgpt.adapter import GPT as BaseModel
from litgpt.adapter import Block as BaseBlock
from litgpt.model import Block as BaseBlock
from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.adapter import Config as BaseConfig
from litgpt.model import KVCache
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights

Expand Down Expand Up @@ -64,54 +63,27 @@ def reset_parameters(self) -> None:


class GPT(BaseModel):
# Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
def __init__(self, config: Config) -> None:
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config

self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.lm_head = AdapterV2Linear(
config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
h=nn.ModuleList(
Block(config, block_idx)
for block_idx in range(config.n_layer)
),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.max_seq_length = self.config.block_size
self.mask_cache: Optional[torch.Tensor] = None

def forward(
self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
) -> Union[torch.Tensor, List[torch.Tensor]]:
T = idx.size(1)
if self.max_seq_length < T:
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

if input_pos is not None: # use the kv cache
cos = self.cos.index_select(0, input_pos)
sin = self.sin.index_select(0, input_pos)
if self.mask_cache is None:
raise TypeError("You need to call `gpt.set_kv_cache()`")
mask = self.mask_cache.index_select(2, input_pos)
else:
cos = self.cos[:T]
sin = self.sin[:T]
mask = None

x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
if self.config.scale_embeddings:
x = x * (self.config.n_embd**0.5)
for block in self.transformer.h:
x = block(x, cos, sin, mask, input_pos)
x = self.transformer.ln_f(x)
if lm_head_chunk_size > 0:
# chunk the lm head logits to reduce the peak memory used by autograd
return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
x = self.lm_head(x) # (b, t, vocab_size)
if self.config.final_logit_softcapping is not None:
x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping
return x
self.max_seq_length = self.config.block_size

@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
Expand All @@ -131,61 +103,30 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa


class Block(BaseBlock):
"""The implementation is identical to `litgpt.model.Block` with the exception that
we replace the attention layer where adaption is implemented."""

def __init__(self, config: Config, block_idx: int) -> None:
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
nn.Module.__init__(self)
if not config.parallel_residual and config.shared_attention_norm:
raise NotImplementedError(
"No checkpoint amongst the ones we support uses this configuration:"
" non-parallel residual and shared attention norm."
)
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
super().__init__(config, block_idx)
self.attn = CausalSelfAttention(config, block_idx)
self.post_attention_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
)
self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
self.mlp = config.mlp_class(config)
self.post_mlp_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
)

self.config = config


class CausalSelfAttention(BaseCausalSelfAttention):
"""A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""

# Copy&paste from :class:`model.CausalSelfAttention`
def __init__(self, config: Config, block_idx: int) -> None:
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
super().__init__(config, block_idx)
# key, query, value projections for all heads, but in a batch
self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
self.qkv = AdapterV2Linear(
in_features=config.n_embd,
out_features=shape,
bias=config.bias or config.attn_bias
)
# output projection
# if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
# disabled by default
self.kv_cache: Optional[KVCache] = None

if block_idx >= config.adapter_start_layer:
# adapter embedding layer
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
# gate for adaption
self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
# kv cache for inference
self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
self.block_idx = block_idx
self.apply_sliding_window_attention = (
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_stride == 0
self.proj = AdapterV2Linear(
config.head_size * config.n_head, config.n_embd, bias=config.bias
)

self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base and/or legacy checkpoints."""
mapping = {
Expand All @@ -211,9 +152,12 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa
class GptNeoxMLP(litgpt.model.GptNeoxMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

self.fc = AdapterV2Linear(
config.n_embd, config.intermediate_size, bias=config.bias
)
self.proj = AdapterV2Linear(
config.intermediate_size, config.n_embd, bias=config.bias
)
self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
Expand All @@ -231,10 +175,15 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa
class LLaMAMLP(litgpt.model.LLaMAMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

self.fc_1 = AdapterV2Linear(
config.n_embd, config.intermediate_size, bias=config.bias
)
self.fc_2 = AdapterV2Linear(
config.n_embd, config.intermediate_size, bias=config.bias
)
self.proj = AdapterV2Linear(
config.intermediate_size, config.n_embd, bias=config.bias
)
self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -264,7 +213,6 @@ def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False)
self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))

self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
Expand Down
28 changes: 23 additions & 5 deletions litgpt/generate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from pathlib import Path
from pprint import pprint
from typing import Any, Literal, Optional, Tuple, List, Union, Iterator
from typing import Any, Literal, Optional, Tuple, List, Union, Iterator, Dict
import warnings

import lightning as L
Expand Down Expand Up @@ -73,15 +73,23 @@ def sample(
return torch.argmax(logits, dim=-1, keepdim=True)


def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
logits = model(x, input_pos)
_next = sample(logits, **kwargs).to(dtype=torch.int64)
def next_token(
model: GPT,
input_pos: torch.Tensor,
x: torch.Tensor,
input_pos_maxp1: Optional[torch.Tensor] = None,
**sample_kwargs: Dict[str, Any],
) -> torch.Tensor:
logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)
_next = sample(logits, **sample_kwargs).to(dtype=torch.int64)
return _next


def batched_sample(logits: list[torch.Tensor], kwargs: list[dict]) -> torch.Tensor:
assert len(logits) == len(kwargs), "logits and kwargs must have the same length."
return torch.stack([sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(kwargs, logits)], dim=0)


def batched_next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, kwargs: Union[dict, list[dict]]) -> torch.Tensor:
# Where:
# input_pos is a 1d tensor of shape [seq_length...]
Expand Down Expand Up @@ -166,10 +174,19 @@ def generate_fn(
token = prompt
prefill_token = True
input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
input_pos_maxp1 = torch.tensor(prompt_size, device=device)
for current_idx in range(max_returned_tokens - prompt_size):

# Generate the token
token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p)
token = next_token(
model,
input_pos,
token.view(1, -1),
input_pos_maxp1=input_pos_maxp1,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
tokens.append(token)
int_token = token.item()

Expand Down Expand Up @@ -205,6 +222,7 @@ def generate_fn(
input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
else:
input_pos.add_(1)
input_pos_maxp1.add_(1)

# Yield any remaining tokens
if yielded_idx < len(tokens):
Expand Down
Loading

0 comments on commit 17a58df

Please sign in to comment.