
Commit

fix merge lora bug
1049451037 committed May 31, 2023
1 parent 36d72ff commit 3a2d4a6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lora_mixin.py
@@ -138,7 +138,7 @@ def replace_linear_with_lora(lin, partition, r, *args, **kw_args):
     return LoraLinear(original_cls, partition, in_dim, out_dim, r, *args, **kw_args)

 def merge_linear_lora(lin):
-    if type(lin.original) is HackLinear:
+    if lin.original.weight.data.dtype is not torch.uint8:
         weight = lin.original.weight
         out_dim, in_dim = weight.shape
         new_lin = nn.Linear(in_dim, out_dim)
@@ -230,7 +230,7 @@ def merge_lora(self):
             print(f'merge layer {i} lora attention back to linear')
             self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
             self.transformer.layers[i].attention.query_key_value = merge_linear_lora(self.transformer.layers[i].attention.query_key_value)
-            if parent_model.transformer.layers[i].is_decoder:
+            if self.transformer.layers[i].is_decoder:
                 print(f'merge layer {i} lora cross attention back to linear')
                 self.transformer.layers[i].cross_attention.dense = merge_linear_lora(self.transformer.layers[i].cross_attention.dense)
                 self.transformer.layers[i].cross_attention.query = merge_linear_lora(self.transformer.layers[i].cross_attention.query)
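For context, merging a LoRA adapter back into a plain linear layer means folding the low-rank update into the frozen base weight, which is only safe when that weight is a regular float tensor (not a quantized uint8 buffer, which is what the dtype check in this commit guards against). Below is a minimal sketch of that idea; the names `lora_A`, `lora_B`, and `scaling` are illustrative assumptions, not the exact attributes used in lora_mixin.py.

import torch
import torch.nn as nn

def merge_lora_into_linear(original: nn.Linear, lora_A: torch.Tensor,
                           lora_B: torch.Tensor, scaling: float) -> nn.Linear:
    """Fold a low-rank LoRA update into a fresh nn.Linear (illustrative sketch).

    Assumes A has shape (r, in_dim), B has shape (out_dim, r), and scaling
    follows the usual alpha / r convention; these are assumptions, not the
    repository's exact internals.
    """
    # A quantized base weight would need dequantization before merging,
    # mirroring the dtype check introduced in this commit.
    assert original.weight.data.dtype is not torch.uint8

    out_dim, in_dim = original.weight.shape
    merged = nn.Linear(in_dim, out_dim, bias=original.bias is not None)

    # W' = W + scaling * (B @ A)
    merged.weight.data = original.weight.data + scaling * (lora_B @ lora_A)
    if original.bias is not None:
        merged.bias.data = original.bias.data.clone()
    return merged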
