Skip to content

Commit

Permalink
merge_lora
Browse files Browse the repository at this point in the history
SkyTNT committed Oct 6, 2024
1 parent af9fbf5 commit dc62d5c
Showing 3 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
![](./banner.png)

## Updates

- v1.3: MIDITokenizerV2 and new MidiVisualizer
- v1.2: Optimised the tokenizer and dataset. The dataset was filtered with MIDITokenizer.check_quality; training the model on this higher-quality dataset significantly improves its performance.

## Demo
3 changes: 1 addition & 2 deletions app.py
Original file line number Diff line number Diff line change
@@ -290,8 +290,7 @@ def load_model(path, model_config, lora_path):
state_dict = ckpt.get("state_dict", ckpt)
model.load_state_dict(state_dict, strict=False)
if lora_path:
model.load_adapter(lora_path, "default")
model.set_adapter("default")
model = model.load_merge_lora(lora_path)
model.to(opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
return "success"

8 changes: 8 additions & 0 deletions midi_model.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
import torch.nn.functional as F
import tqdm
import lightning as pl
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
from transformers import LlamaModel, LlamaConfig
from transformers.integrations import PeftAdapterMixin

@@ -70,6 +71,13 @@ def __init__(self, config: MIDIModelConfig, flash=False, *args, **kwargs):
self.net_token = LlamaModel(config.net_token_config)
self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)

def load_merge_lora(self, model_id):
    """Load a LoRA adapter, merge it, and return the merged model.

    Reads the adapter configuration and weights identified by
    ``model_id``, wraps this model in a ``LoraModel`` under the adapter
    name ``"default"``, applies the loaded weights, and returns the
    result of ``merge_and_unload()``.

    Args:
        model_id: Identifier handed to ``PeftConfig.from_pretrained`` and
            ``load_peft_weights`` (presumably a local path or hub id --
            confirm against the peft documentation).

    Returns:
        Whatever ``LoraModel.merge_and_unload()`` returns; callers should
        use this return value as the model from here on.
    """
    adapter_name = "default"
    lora_config = PeftConfig.from_pretrained(model_id)
    # The wrapper must exist before the weights are applied.
    # NOTE(review): LoraModel presumably injects the adapter layers into
    # `self` in place, which is why the state dict is set on `self`
    # rather than on the wrapper -- confirm against peft.
    lora_wrapper = LoraModel(self, lora_config, adapter_name=adapter_name)
    target_device = str(self.device)
    lora_weights = load_peft_weights(model_id, device=target_device)
    set_peft_model_state_dict(self, lora_weights, adapter_name)
    merged = lora_wrapper.merge_and_unload()
    return merged

def forward_token(self, hidden_state, x=None):
"""

0 comments on commit dc62d5c

Please sign in to comment.