Fix LoRA with grouped queries (Lightning-AI#377)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored Aug 14, 2023
1 parent 1cbb416 commit ca21fc3
Showing 2 changed files with 120 additions and 11 deletions.
58 changes: 49 additions & 9 deletions lit_gpt/lora.py
@@ -190,6 +190,8 @@ def __init__(
"""
super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
self.n_head = n_head
self.n_query_groups = n_query_groups
if isinstance(enable_lora, bool):
enable_lora = [enable_lora] * 3
assert len(enable_lora) == 3
@@ -205,8 +207,14 @@ def __init__(
self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r * sum(enable_lora), in_features))) # (4, 128)
enable_q, enable_k, enable_v = enable_lora
self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)
shape = self.linear.in_features * enable_q + self.kv_embd_size * enable_k + self.kv_embd_size * enable_v
self.lora_B = nn.Parameter(self.linear.weight.new_zeros(shape, r)) # (256, 2))
# qkv_shapes is used later to split the combined QKV weight tensor into its enabled parts
qkv_shapes = (
self.linear.in_features * enable_q,
self.kv_embd_size * enable_k,
self.kv_embd_size * enable_v,
)
self.qkv_shapes = [s for s in qkv_shapes if s]
self.lora_B = nn.Parameter(self.linear.weight.new_zeros(sum(self.qkv_shapes), r)) # (256, 2)
# Notes about shapes above
# - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
# 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
@@ -284,13 +292,46 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
x = x.transpose(0, 1)
result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384)
result = result.view(-1, self.linear.out_features) # (4096, 384)
enable_q, enable_k, enable_v = self.enable_lora
shape = self.linear.in_features * enable_q + self.kv_embd_size * enable_k + self.kv_embd_size * enable_v
result = result.index_copy(
1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, shape)
1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
) # (4096, 256)
return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384)
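
For context, here is a minimal editor's sketch (not part of the commit) of how `qkv_shapes` works out for a hypothetical grouped-query setting; all concrete numbers are assumptions chosen for illustration:

in_features = 12                 # embedding size (n_embd), assumed
n_head, n_query_groups = 6, 2    # grouped queries: 3 heads share each key/value pair
enable_q, enable_k, enable_v = True, False, True

kv_embd_size = in_features // (n_head // n_query_groups)   # 12 // 3 = 4
qkv_shapes = (
    in_features * enable_q,    # 12: the query part keeps the full embedding size
    kv_embd_size * enable_k,   # 0: key LoRA disabled, filtered out below
    kv_embd_size * enable_v,   # 4: the value part is shrunk by the grouping factor
)
qkv_shapes = [s for s in qkv_shapes if s]    # [12, 4]
# lora_B then has sum(qkv_shapes) = 16 rows, and zero_pad reshapes x to that width.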

def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.
If the number of heads is equal to the number of query groups - grouped queries are disabled
(see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the
input and weight matrices will be splitted in equally sized parts and applied separately (like having multiple
conv layers side by side).
Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually,
apply each part of the weight matrix to the corresponding input's part and concatenate the result.
Args:
input: input matrix of shape (B, C, T)
weight: weight matrix of shape (C_output, rank, 1).
"C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).
Returns:
A tensor with a shape (B, C_output, T)
"""
if self.n_head == self.n_query_groups:
return F.conv1d(input, weight, groups=sum(self.enable_lora)) # (B, C_output, T)

# Notation:
# ⚬ N: number of enabled LoRA layers (self.enable_lora)
# ⚬ C_output': embedding size of each LoRA layer (the parts are not equal in size)
# ⚬ r: rank, shared by all LoRA layers

input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T)
weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1)
return torch.cat(
[F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T)
) # (B, C_output, T)
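
To make the fallback branch concrete, here is a self-contained editor's sketch (not from the commit; the tensor sizes are made up) comparing the manual split-apply-concat path with what a grouped `conv1d` would do when the parts are equal:

import torch
import torch.nn.functional as F

B, T, r = 2, 5, 2
qkv_shapes = [12, 4, 4]                  # unequal q/k/v widths, as with grouped queries
n_lora = len(qkv_shapes)                 # number of enabled LoRA layers

x = torch.randn(B, n_lora * r, T)        # input of shape (B, C, T) with C = N * r
w = torch.randn(sum(qkv_shapes), r, 1)   # weight of shape (C_output, r, 1)

# Manual path: split the input into N equal chunks of r channels, split the
# weight by qkv_shapes, convolve each pair, and concatenate along channels.
x_parts = x.chunk(n_lora, dim=1)         # N * (B, r, T)
w_parts = w.split(qkv_shapes)            # N * (C_output', r, 1)
out = torch.cat([F.conv1d(a, b) for a, b in zip(x_parts, w_parts)], dim=1)
print(out.shape)                         # torch.Size([2, 20, 5])

# With equal parts (e.g. qkv_shapes = [12, 12, 12]) this reproduces
# F.conv1d(x, w, groups=n_lora), which is the fast path taken above.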

def merge(self):
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""

@@ -299,10 +340,9 @@ def merge(self):
# ⚬ self.lora_A.data: (4, 128)
# ⚬ self.lora_B.data: (256, 2)
if self.r > 0 and any(self.enable_lora) and not self.merged:
delta_w = F.conv1d(
delta_w = self.conv1d(
self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
groups=sum(self.enable_lora),
).squeeze(
0
) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
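
As an editor's aside (small made-up sizes, two enabled LoRA layers of equal width, i.e. no grouped queries), the `conv1d` call with a kernel size of 1 is just a per-part low-rank product, so `delta_w` is the familiar B @ A built separately for each enabled q/k/v part:

import torch
import torch.nn.functional as F

r, in_features = 2, 12
A = torch.randn(2 * r, in_features)   # lora_A for two enabled layers: (4, 12)
B = torch.randn(2 * in_features, r)   # lora_B with two equal parts:   (24, 2)

delta_w = F.conv1d(A.unsqueeze(0), B.unsqueeze(-1), groups=2).squeeze(0)     # (24, 12)
reference = torch.cat([B[:in_features] @ A[:r], B[in_features:] @ A[r:]])    # (24, 12)
print(torch.allclose(delta_w, reference))   # True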
@@ -339,14 +379,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
# ⚬ groups: split input into groups, in_channels should be divisible by the number of groups. Default: 1
# presumably iW - sequence width/length, kW - kernel width
after_B = F.conv1d(
after_B = self.conv1d(
after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
groups=sum(self.enable_lora),
).transpose(
-2, -1
) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
result += self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)

return result
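
For reference, a usage sketch distilled from the new grouped-query test added below (`test_lora_gpt_query_groups_merge_and_forward_no_exception`); it only shows the call pattern, and the tiny sizes come straight from that test:

import torch
from lit_gpt.lora import GPT, Config, merge_lora_weights

# 6 heads, 2 query groups, LoRA on the query and value projections only.
config = Config(
    n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1,
    r=2, alpha=8, dropout=0.1, n_query_groups=2,
    to_query=True, to_key=False, to_value=True,
)
model = GPT(config)
merge_lora_weights(model)              # folds each LoRA delta into the base weights
logits = model(torch.tensor([[1]]))    # forward pass after merging, as in the test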


73 changes: 71 additions & 2 deletions tests/test_lora.py
@@ -1,5 +1,6 @@
from contextlib import redirect_stdout
from io import StringIO
from itertools import product
from unittest.mock import Mock

import pytest
@@ -228,7 +229,7 @@ def test_lora_linear_utilization(apply_to, target_layer_names, mlp_class_name):
dropout=0.1,
_mlp_class=mlp_class_name,
intermediate_size=8 * 3,
**{apply_to: True}
**{apply_to: True},
)
model = GPT(config)
state_dict = model.state_dict()
@@ -248,8 +249,9 @@ def test_lora_linear_utilization(apply_to, target_layer_names, mlp_class_name):
assert lora_params == target_layer_names


@torch.inference_mode()
@pytest.mark.parametrize("apply_to", (None, "to_query", "to_key", "to_value", "to_projection", "to_mlp", "to_head"))
def test_lora_layer_forward_no_exception(apply_to):
def test_lora_gpt_apply_lora_forward_no_exception(apply_to):
from lit_gpt.lora import GPT, Config

config = Config(n_layer=1, n_head=4, n_embd=8, block_size=1, vocab_size=1, r=2, alpha=8, dropout=0.1)
@@ -262,6 +264,73 @@ def test_lora_layer_forward_no_exception(apply_to):
model(input_ids)


@torch.inference_mode()
@pytest.mark.parametrize("n_query_groups", (1, 2, 3, 6))
@pytest.mark.parametrize("apply_to", product((False, True), repeat=3))
def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, apply_to):
from lit_gpt.lora import GPT, Config, merge_lora_weights

keys = ("to_query", "to_key", "to_value")
values = apply_to
apply_to = dict(zip(keys, values))

config = Config(
n_layer=1,
n_head=6,
n_embd=12,
block_size=1,
vocab_size=1,
r=2,
alpha=8,
dropout=0.1,
n_query_groups=n_query_groups,
**apply_to,
)
model = GPT(config)
merge_lora_weights(model)
input_ids = torch.tensor([[1]])
model(input_ids)


@torch.inference_mode()
@pytest.mark.parametrize("n_head", (1, 2, 3, 6, 12))
@pytest.mark.parametrize(
"enable_lora",
[
(False, False, True),
(False, True, False),
(False, True, True),
(True, False, False),
(True, False, True),
(True, True, False),
(True, True, True),
],
)
def test_lora_qkv_linear_compare_conv1d(n_head, enable_lora):
from torch.nn import functional as F

from lit_gpt.lora import LoRAQKVLinear

C = 12
layer = LoRAQKVLinear(C, 3 * C, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora)
x = torch.randn((1, 1, C))
a = F.linear(x, layer.lora_A).transpose(-2, -1) # after_A
b = layer.lora_B.data.unsqueeze(-1)

# original PyTorch conv1d function output
conv1d_pytorch = F.conv1d(a, b, groups=sum(layer.enable_lora))

# custom conv1d
conv1d_custom = layer.conv1d(a, b)

# custom conv1d forced to split, apply and concat tensors
layer.n_head = layer.n_query_groups + 1
conv1d_custom_forced = layer.conv1d(a, b)

assert torch.allclose(conv1d_pytorch, conv1d_custom)
assert torch.allclose(conv1d_pytorch, conv1d_custom_forced)


@pytest.mark.parametrize(("rank", "expected_merged"), ((-1, False), (0, False), (1, True)))
def test_lora_linear_weights_merged_status(rank, expected_merged):
from lit_gpt.lora import LoRALinear
