forked from THUDM/VisualGLM-6B
Commit 1e7910a · 1 parent 5d368f6
Showing 1 changed file with 257 additions and 0 deletions.
@@ -0,0 +1,257 @@
"""
In this mixin, I use a different implementation than lora.py:
I just use fake linear layers to replace the original linear layers of any model
when the LoRA mixin is applied.
"""

import torch
import torch.nn as nn
from sat.model.base_model import BaseMixin
import math
from sat.helpers import print_all
from sat.model.transformer import RowParallelLinear, ColumnParallelLinear


class HackLinear(nn.Linear):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if prefix + 'weight' in state_dict:
            self.weight.data.copy_(state_dict[prefix+'weight'])
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix+'bias'])
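# Note on HackLinear (above): it loads weight/bias non-strictly, skipping keys that
# are absent from the checkpoint instead of raising, so a LoRA-augmented model can
# still consume a checkpoint that was saved from a plain nn.Linear.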

try:
    from bitsandbytes.nn import LinearNF4
    def copy_nested_list(src, dst):
        for i in range(len(dst)):
            if type(dst[i]) is torch.Tensor:
                dst[i].copy_(src[i])
            elif type(dst[i]) is list:
                copy_nested_list(src[i], dst[i])
            else:
                dst[i] = src[i]
    class HackLinearNF4(LinearNF4):
        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if prefix + 'weight' in state_dict:
                self.weight.data.copy_(state_dict[prefix+'weight'])
                if self.weight.data.dtype == torch.uint8:
                    copy_nested_list(state_dict[prefix+'quant_state'], self.weight.quant_state)
            if prefix + 'bias' in state_dict:
                self.bias.data.copy_(state_dict[prefix+'bias'])
        def _save_to_state_dict(self, destination, prefix, keep_vars):
            super()._save_to_state_dict(destination, prefix, keep_vars)
            destination[prefix+'quant_state'] = self.weight.quant_state
except Exception as exception:
    print_all("Failed to load bitsandbytes:" + str(exception), level='WARNING')


class HackParameterList(nn.ParameterList):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        for i in range(len(self)):
            if prefix + str(i) in state_dict:
                self[i].data.copy_(state_dict[prefix+str(i)])

class LoraLinear(nn.Module):
    def __init__(self, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., qlora=False):
        super().__init__()
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
            self.original = HackLinearNF4(in_dim, out_dim)
        else:
            self.original = HackLinear(in_dim, out_dim)
        self.matrix_A = nn.Parameter(torch.empty((r, in_dim)))
        self.matrix_B = nn.Parameter(torch.empty((out_dim, r)))
        nn.init.kaiming_uniform_(self.matrix_A, a=math.sqrt(5))
        nn.init.zeros_(self.matrix_B)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
        if prefix + 'weight' in state_dict:
            # load from a plain Linear checkpoint
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            # load from a LoraLinear checkpoint
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        return self.original(x) + (self.lora_dropout(x) @ self.matrix_A.T @ self.matrix_B.T) * self.scaling
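# Note on LoraLinear.forward (above): the output is W x + b + (lora_alpha / r) * B A x,
# with dropout applied to the LoRA branch only. Because matrix_B is zero-initialised,
# a freshly created LoraLinear is numerically identical to the wrapped linear layer.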


class LoraQKV(nn.Module):
    def __init__(self, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False, num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False):
        """
        This layer is safe to use ONLY WHEN the query_key_value output is laid out
        in plain query, key, value order. If your model uses a different layout
        (e.g. ChatGLM's per-head interleaved order), set head_first=True and pass
        num_attention_heads and hidden_size_per_attention_head.
        """
        super().__init__()
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
            self.original = HackLinearNF4(in_dim, out_dim)
        else:
            self.original = HackLinear(in_dim, out_dim)
        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(3)])
        self.matrix_B = HackParameterList([nn.Parameter(torch.empty((out_dim // 3, r))) for _ in range(3)])
        for i in range(3):
            nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
            nn.init.zeros_(self.matrix_B[i])
        self.head_first = head_first
        if head_first:
            assert num_attention_heads is not None and hidden_size_per_attention_head is not None, "You should set num_attention_heads and hidden_size_per_attention_head if you use head_first=True!"
            self.num_attention_heads = num_attention_heads
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
        if prefix + 'weight' in state_dict:
            # load from a plain Linear checkpoint
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            # load from a LoraQKV checkpoint
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        mixed_raw_layer = self.original(x)
        lora_outputs = []
        for i in range(3):
            lora_outputs.append((self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling)
        if self.head_first:
            new_tensor_shape = lora_outputs[0].size()[:-1] + (
                self.num_attention_heads,
                self.hidden_size_per_attention_head,
            )
            for i in range(3):
                lora_outputs[i] = lora_outputs[i].view(*new_tensor_shape)
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1).view(*mixed_raw_layer.size())
        else:
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1)

        return mixed_raw_layer
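# Note on LoraQKV.forward (above): the fused query_key_value projection gets three
# independent low-rank adapters, one per Q/K/V slice. With head_first=True, each
# slice's delta is reshaped to (..., num_heads, head_dim) and concatenated along the
# last axis before being flattened back, so the correction matches a weight layout in
# which Q, K and V are interleaved per attention head.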


def replace_linear_with_lora(lin, base_cls, r, *args, **kw_args):
    # linear layers without bias are not supported for now
    out_dim, in_dim = lin.weight.shape
    return base_cls(in_dim, out_dim, r, *args, **kw_args)

def merge_linear_lora(lin):
    out_dim, in_dim = lin.original.weight.shape
    new_lin = nn.Linear(in_dim, out_dim)
    new_lin.bias.data = lin.original.bias.data
    new_lin.weight.data = lin.original.weight.data + (lin.matrix_A.data.T.float() @ lin.matrix_B.data.T.float() * lin.scaling).T.to(lin.original.weight.data.dtype)
    return new_lin

def merge_qkv_lora(lin):
    out_dim, in_dim = lin.original.weight.shape
    new_lin = nn.Linear(in_dim, out_dim)
    new_lin.bias.data = lin.original.bias.data
    new_qkv = []
    for i in range(3):
        new_qkv.append(lin.matrix_A[i].data.T.float() @ lin.matrix_B[i].data.T.float() * lin.scaling)
    if lin.head_first:
        ini_shape = new_qkv[0].shape
        new_qkv = [x.view(ini_shape[0], lin.num_attention_heads, -1) for x in new_qkv]
        new_qkv = torch.cat(new_qkv, -1).view(ini_shape[0], 3*ini_shape[1])
    else:
        new_qkv = torch.cat(new_qkv, -1)
    new_lin.weight.data = lin.original.weight.data + new_qkv.T.to(lin.original.weight.data.dtype)
    return new_lin
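# Note on the merge helpers (above): they fold the adapters back into dense weights,
# W_merged = W + (lora_alpha / r) * B @ A, computed in float32 and cast back to the
# original dtype, so inference after merging needs no extra matmuls. This assumes the
# wrapped layer holds full-precision weights (i.e. it was not created with qlora=True).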

class LoraMixin(BaseMixin):
    def __init__(self,
                 layer_num,
                 r: int = 0,
                 lora_alpha: int = 1,
                 lora_dropout: float = 0.,
                 layer_range = None,
                 head_first = False,
                 num_attention_heads = None,
                 hidden_size_per_attention_head = None,
                 qlora = False):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

        if layer_range is None:
            layer_range = [i for i in range(layer_num)]
        self.layer_range = layer_range

        self.scaling = self.lora_alpha / self.r
        self.head_first = head_first
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.qlora = qlora

    def reinit(self, parent_model):
        """
        Only the self-attention projections are supported;
        cross-attention is not supported for now.
        """
        for i in self.layer_range:
            print(f'replacing layer {i} with lora')
            parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].attention.dense, LoraLinear, self.r, self.lora_alpha, self.lora_dropout, self.qlora)
            parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(parent_model.transformer.layers[i].attention.query_key_value, LoraQKV, self.r, self.lora_alpha, self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads, hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora)
        if self.qlora:
            print('replacing chatglm linear layer with 4bit')
            def replace_linear_with_nf4(model, name=None, cache={}):
                if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear):
                    out_dim, in_dim = model.weight.shape
                    return HackLinearNF4(in_dim, out_dim)
                names = set()
                for name, child in model.named_children():
                    if name not in names:
                        if child in cache:
                            new_child = cache[child]
                        else:
                            new_child = replace_linear_with_nf4(child, name=name, cache=cache)
                            cache[child] = new_child
                        setattr(model, name, new_child)
                        names.add(name)
                flag = True
                while flag:
                    flag = False
                    for name, child in model.named_children():
                        if name not in names:
                            setattr(model, name, cache[child])
                            names.add(name)
                            flag = True
                return model
            replace_linear_with_nf4(parent_model.transformer, None, {})

    def merge_lora(self):
        for i in self.layer_range:
            print(f'merge layer {i} lora back to linear')
            self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
            self.transformer.layers[i].attention.query_key_value = merge_qkv_lora(self.transformer.layers[i].attention.query_key_value)
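
# A minimal usage sketch (hedged: it assumes a SwissArmyTransformer BaseModel subclass
# with the usual `add_mixin(name, mixin, reinit=...)` API, a `mixins` ModuleDict, and a
# `transformer.layers` attribute; the hyperparameter values are illustrative only):
#
#   model.add_mixin('lora', LoraMixin(num_layers, r=8, lora_alpha=16, lora_dropout=0.05),
#                   reinit=True)        # swaps attention.dense / query_key_value for LoRA layers
#   # ...finetune, training only matrix_A / matrix_B...
#   model.mixins['lora'].merge_lora()   # folds the low-rank deltas back into plain nn.Linear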

if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)

        def forward(self, x):
            return self.child(x)

    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    model.child = LoraLinear(100, 200, 10)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()
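    # A possible sanity check (a sketch, not asserted in the original script): matrix_B
    # is zero-initialised, so before any training the LoRA branch contributes nothing
    # and the three outputs should coincide.
    assert torch.allclose(out1, out2) and torch.allclose(out1, out3)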