Commit: support lora

SkyTNT committed Oct 5, 2024
1 parent 7c77774 commit 58fb183
Showing 4 changed files with 112 additions and 47 deletions.
17 changes: 14 additions & 3 deletions app.py
@@ -280,13 +280,16 @@ def undo_continuation(mid_seq, continuation_state):
     return mid_seq, continuation_state, send_msgs(end_msgs)


-def load_model(path, model_config):
+def load_model(path, model_config, lora_path):
     global model, tokenizer
     model = MIDIModel(config=MIDIModelConfig.from_name(model_config))
     tokenizer = model.tokenizer
     ckpt = torch.load(path, map_location="cpu")
     state_dict = ckpt.get("state_dict", ckpt)
     model.load_state_dict(state_dict, strict=False)
+    if lora_path:
+        model.load_adapter(lora_path, "default")
+        model.set_adapter("default")
     model.to(opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
     return "success"

@@ -295,6 +298,11 @@ def get_model_path():
     model_paths = sorted(glob.glob("**/*.ckpt", recursive=True))
     return gr.Dropdown(choices=model_paths)

+
+def get_lora_path():
+    lora_paths = sorted(glob.glob("**/adapter_config.json", recursive=True))
+    lora_paths = [lora_path.replace("adapter_config.json","") for lora_path in lora_paths]
+    return gr.Dropdown(choices=lora_paths)


 def load_javascript(dir="javascript"):
     scripts_list = glob.glob(f"{dir}/*.js")
@@ -353,12 +361,15 @@ def template_response(*args, **kwargs):
         with gr.Accordion(label="Model option", open=True):
             load_model_path_btn = gr.Button("Get Models")
             model_path_input = gr.Dropdown(label="model")
-            model_config_input = gr.Dropdown(label="config", choices=config_name_list, value=config_name_list[0])
+            model_config_input = gr.Dropdown(label="config", choices=config_name_list, value="tv2o-medium")
             load_model_path_btn.click(get_model_path, [], model_path_input)
+            load_lora_path_btn = gr.Button("Get Loras")
+            lora_path_input = gr.Dropdown(label="lora")
+            load_lora_path_btn.click(get_lora_path, [], lora_path_input)
             load_model_btn = gr.Button("Load")
             model_msg = gr.Textbox()
             load_model_btn.click(
-                load_model, [model_path_input, model_config_input], model_msg
+                load_model, [model_path_input, model_config_input, lora_path_input], model_msg
             )
         tab_select = gr.State(value=0)
         with gr.Tabs():
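For reference, the new loading path in load_model() boils down to the sequence below, a minimal sketch outside the Gradio UI. The checkpoint and adapter paths are placeholders; get_lora_path() treats any directory containing an adapter_config.json as a selectable LoRA.

import torch
from midi_model import MIDIModel, MIDIModelConfig

# Build the base model and load a full checkpoint, as load_model() does.
model = MIDIModel(config=MIDIModelConfig.from_name("tv2o-medium"))
ckpt = torch.load("path/to/model.ckpt", map_location="cpu")  # placeholder path
model.load_state_dict(ckpt.get("state_dict", ckpt), strict=False)

# Attach and activate the adapter; both methods come from PeftAdapterMixin.
model.load_adapter("path/to/lora/", "default")  # placeholder: a dir holding adapter_config.json
model.set_adapter("default")
model.eval()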
3 changes: 2 additions & 1 deletion midi_model.py
@@ -7,6 +7,7 @@
 import tqdm
 import lightning as pl
 from transformers import LlamaModel, LlamaConfig
+from transformers.integrations import PeftAdapterMixin

 from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
@@ -56,7 +57,7 @@ def from_name(name="tv2o-medium"):
            raise ValueError(f"Unknown model size {size}")


-class MIDIModel(pl.LightningModule):
+class MIDIModel(pl.LightningModule, PeftAdapterMixin):
     def __init__(self, config: MIDIModelConfig, flash=False, *args, **kwargs):
         super(MIDIModel, self).__init__()
         if flash:
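This one-line base-class change is what makes the rest of the commit work: PeftAdapterMixin, shipped with transformers (hence the version pins below), supplies load_adapter and set_adapter as used in app.py, plus add_adapter, active_adapters, peft_config and get_adapter_state_dict as used in train.py. A quick sketch of what it enables; when no name is given, the adapter name defaults to "default":

from peft import LoraConfig
from midi_model import MIDIModel, MIDIModelConfig

model = MIDIModel(config=MIDIModelConfig.from_name("tv2o-medium"))
model.add_adapter(LoraConfig(r=256, target_modules=["q_proj", "v_proj"]))  # injects LoRA layers in-place
print(model.active_adapters())  # ['default']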
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,6 +1,8 @@
 Pillow
 numpy
 torch
+safetensors
+peft>=0.13.0
 transformers>=4.36
 lightning==2.4.0
 gradio==4.43.0
137 changes: 94 additions & 43 deletions train.py
@@ -1,22 +1,25 @@
 import argparse
 import os
 import random
-from typing import Union
+from pathlib import Path
+from typing import Union, Optional

-import numpy as np
 import lightning as pl
+import numpy as np
 import torch
 import torch.nn.functional as F
 from lightning import Trainer
 from lightning.fabric.utilities import rank_zero_only
 from lightning.pytorch.callbacks import ModelCheckpoint
+from peft import LoraConfig, TaskType
+from safetensors.torch import save_file as safe_save_file
 from torch import optim
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import Dataset, DataLoader

 import MIDI
 from midi_model import MIDIModel, MIDIModelConfig, config_name_list
-from midi_tokenizer import MIDITokenizer, MIDITokenizerV1, MIDITokenizerV2
+from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2

 EXTENSION = [".mid", ".midi"]

@@ -26,7 +29,8 @@ def file_ext(fname):


 class MidiDataset(Dataset):
-    def __init__(self, midi_list, tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2], max_len=2048, min_file_size=3000, max_file_size=384000,
+    def __init__(self, midi_list, tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2], max_len=2048, min_file_size=3000,
+                 max_file_size=384000,
                  aug=True, check_quality=False, rand_start=True):

         self.tokenizer = tokenizer
@@ -73,7 +77,7 @@ def __getitem__(self, index):
             start_idx = random.choice([0, start_idx])
         else:
             max_start = max(1, mid.shape[0] - self.max_len)
-            start_idx = (index*(max_start//8)) % max_start
+            start_idx = (index * (max_start // 8)) % max_start
         mid = mid[start_idx: start_idx + self.max_len]
         mid = mid.astype(np.int64)
         mid = torch.from_numpy(mid)
@@ -103,15 +107,17 @@ def lr_lambda(current_step):
 class TrainMIDIModel(MIDIModel):
     def __init__(self, config: MIDIModelConfig, flash=False,
                  lr=2e-4, weight_decay=0.01, warmup=1e3, max_step=1e6, sample_seq=False,
-                 gen_example=True, example_batch=8):
+                 gen_example_interval=1, example_batch=8):
         super(TrainMIDIModel, self).__init__(config, flash=flash)
         self.lr = lr
         self.weight_decay = weight_decay
         self.warmup = warmup
         self.max_step = max_step
         self.sample_seq = sample_seq
-        self.gen_example = gen_example
+        self.gen_example_interval = gen_example_interval
         self.example_batch = example_batch
+        self.last_save_step = 0
+        self.gen_example_count = 0

     def configure_optimizers(self):
         param_optimizer = list(self.named_parameters())
@@ -200,41 +206,68 @@ def validation_step(self, batch, batch_idx):
         self.log_dict({"val/loss": loss, "val/acc": acc}, sync_dist=True)
         return loss

-    def on_validation_start(self):
-        torch.cuda.empty_cache()
-
-    def on_validation_end(self):
-        @rank_zero_only
-        def gen_example():
-            base_dir = f"sample/{self.global_step}"
-            if not os.path.exists(base_dir):
-                os.mkdir(base_dir)
-            midis = self.generate(batch_size=self.example_batch)
-            midis = [self.tokenizer.detokenize(midi) for midi in midis]
-            imgs = [self.tokenizer.midi2img(midi) for midi in midis]
-            for i, (img, midi) in enumerate(zip(imgs, midis)):
-                img.save(f"{base_dir}/0_{i}.png")
-                with open(f"{base_dir}/0_{i}.mid", 'wb') as f:
-                    f.write(MIDI.score2midi(midi))
-            prompt = val_dataset.load_midi(random.randint(0, len(val_dataset) - 1))
-            prompt = np.asarray(prompt, dtype=np.int16)
-            ori = prompt[:512]
-            img = self.tokenizer.midi2img(self.tokenizer.detokenize(ori))
-            img.save(f"{base_dir}/1_ori.png")
-            prompt = prompt[:256].astype(np.int64)
-            midis = self.generate(prompt, batch_size=self.example_batch)
-            midis = [self.tokenizer.detokenize(midi) for midi in midis]
-            imgs = [self.tokenizer.midi2img(midi) for midi in midis]
-            for i, (img, midi) in enumerate(zip(imgs, midis)):
-                img.save(f"{base_dir}/1_{i}.png")
-                with open(f"{base_dir}/1_{i}.mid", 'wb') as f:
-                    f.write(MIDI.score2midi(midi))
-        if self.gen_example:
+    @rank_zero_only
+    def gen_example(self, save_dir):
+        base_dir = f"{save_dir}/sample/{self.global_step}"
+        if not os.path.exists(base_dir):
+            Path(base_dir).mkdir(parents=True)
+        midis = self.generate(batch_size=self.example_batch)
+        midis = [self.tokenizer.detokenize(midi) for midi in midis]
+        imgs = [self.tokenizer.midi2img(midi) for midi in midis]
+        for i, (img, midi) in enumerate(zip(imgs, midis)):
+            img.save(f"{base_dir}/0_{i}.png")
+            with open(f"{base_dir}/0_{i}.mid", 'wb') as f:
+                f.write(MIDI.score2midi(midi))
+        prompt = val_dataset.load_midi(random.randint(0, len(val_dataset) - 1))
+        prompt = np.asarray(prompt, dtype=np.int16)
+        ori = prompt[:512]
+        img = self.tokenizer.midi2img(self.tokenizer.detokenize(ori))
+        img.save(f"{base_dir}/1_ori.png")
+        prompt = prompt[:256].astype(np.int64)
+        midis = self.generate(prompt, batch_size=self.example_batch)
+        midis = [self.tokenizer.detokenize(midi) for midi in midis]
+        imgs = [self.tokenizer.midi2img(midi) for midi in midis]
+        for i, (img, midi) in enumerate(zip(imgs, midis)):
+            img.save(f"{base_dir}/1_{i}.png")
+            with open(f"{base_dir}/1_{i}.mid", 'wb') as f:
+                f.write(MIDI.score2midi(midi))
+
+    @rank_zero_only
+    def save_peft(self, save_dir):
+        adapter_name = self.active_adapters()[0]
+        adapter_config = self.peft_config[adapter_name]
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir, exist_ok=True)
+        adapter_config.save_pretrained(save_dir)
+        adapter_state_dict = self.get_adapter_state_dict(adapter_name)
+        safe_save_file(adapter_state_dict,
+                       os.path.join(save_dir, "adapter_model.safetensors"),
+                       metadata={"format": "pt"})
+
+    def on_save_checkpoint(self, checkpoint):
+        if self.global_step == self.last_save_step:
+            return
+        self.last_save_step = self.global_step
+        trainer = self.trainer
+        if len(trainer.loggers) > 0:
+            if trainer.loggers[0].save_dir is not None:
+                save_dir = trainer.loggers[0].save_dir
+            else:
+                save_dir = trainer.default_root_dir
+            name = trainer.loggers[0].name
+            version = trainer.loggers[0].version
+            version = version if isinstance(version, str) else f"version_{version}"
+            save_dir = os.path.join(save_dir, str(name), version)
+        else:
+            save_dir = trainer.default_root_dir
+        if self._hf_peft_config_loaded:
+            self.save_peft(os.path.join(save_dir, "lora"))
+        self.gen_example_count += 1
+        if self.gen_example_interval>0 and self.gen_example_count % self.gen_example_interval == 0:
             try:
-                gen_example()
+                self.gen_example(save_dir)
             except Exception as e:
                 print(e)
         torch.cuda.empty_cache()


 def get_midi_list(path):
@@ -261,6 +294,9 @@ def get_midi_list(path):
     parser.add_argument(
         "--config", type=str, default="tv2o-medium", choices=config_name_list, help="model config"
     )
+    parser.add_argument(
+        "--task", type=str, default="train", choices=["train", "lora"], help="Full train or lora"
+    )

     # dataset args
     parser.add_argument(
@@ -293,7 +329,7 @@ def get_midi_list(path):
         "--sample-seq", action="store_true", default=False, help="sample midi seq to reduce vram"
     )
     parser.add_argument(
-        "--disable-gen-example", action="store_true", default=False, help="disable generate example on validation end"
+        "--gen-example-interval", type=int, default=1, help="generate example interval. set 0 to disable"
     )
     parser.add_argument(
         "--batch-size-train", type=int, default=2, help="batch size for training"
@@ -362,8 +398,10 @@ def get_midi_list(path):
     train_dataset_len = full_dataset_len - opt.data_val_split
     train_midi_list = midi_list[:train_dataset_len]
     val_midi_list = midi_list[train_dataset_len:]
-    train_dataset = MidiDataset(train_midi_list, tokenizer, max_len=opt.max_len, aug=True, check_quality=opt.quality, rand_start=True)
-    val_dataset = MidiDataset(val_midi_list, tokenizer, max_len=opt.max_len, aug=False, check_quality=opt.quality, rand_start=False)
+    train_dataset = MidiDataset(train_midi_list, tokenizer, max_len=opt.max_len, aug=True, check_quality=opt.quality,
+                                rand_start=True)
+    val_dataset = MidiDataset(val_midi_list, tokenizer, max_len=opt.max_len, aug=False, check_quality=opt.quality,
+                              rand_start=False)
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=opt.batch_size_train,
@@ -385,12 +423,25 @@ def get_midi_list(path):
     print(f"train: {len(train_dataset)} val: {len(val_dataset)}")
     model = TrainMIDIModel(config, flash=True, lr=opt.lr, weight_decay=opt.weight_decay,
                            warmup=opt.warmup_step, max_step=opt.max_step,
-                           sample_seq=opt.sample_seq, gen_example=not opt.disable_gen_example,
+                           sample_seq=opt.sample_seq, gen_example_interval=opt.gen_example_interval,
                            example_batch=opt.batch_size_gen_example)
     if opt.ckpt:
         ckpt = torch.load(opt.ckpt, map_location="cpu")
         state_dict = ckpt.get("state_dict", ckpt)
         model.load_state_dict(state_dict, strict=False)
+    elif opt.task == "lora":
+        raise ValueError("--ckpt must be set to train lora")
+    if opt.task == "lora":
+        model.requires_grad_(False)
+        lora_config = LoraConfig(
+            r=256,
+            target_modules=["q_proj", "v_proj"],
+            task_type=TaskType.CAUSAL_LM,
+            bias="none",
+            lora_alpha=512,
+            lora_dropout=0
+        )
+        model.add_adapter(lora_config)
     print("---start train---")
     checkpoint_callback = ModelCheckpoint(
         monitor="val/loss",
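Taken together, --task lora freezes every base weight, injects rank-256 adapters into the attention q_proj/v_proj modules, and on each checkpoint writes adapter_config.json plus adapter_model.safetensors into a lora/ directory under the logger's save dir, exactly the layout the Get Loras button in app.py scans for. A run would look something like python train.py --task lora --ckpt path/to/base.ckpt together with the usual dataset flags (paths are placeholders). The sketch below mirrors that setup and checks that only LoRA tensors stay trainable:

import torch
from peft import LoraConfig, TaskType
from midi_model import MIDIModel, MIDIModelConfig

model = MIDIModel(config=MIDIModelConfig.from_name("tv2o-medium"))
ckpt = torch.load("path/to/base.ckpt", map_location="cpu")  # placeholder: a full base checkpoint is required
model.load_state_dict(ckpt.get("state_dict", ckpt), strict=False)

model.requires_grad_(False)  # freeze the base model first
model.add_adapter(LoraConfig(r=256, target_modules=["q_proj", "v_proj"],
                             task_type=TaskType.CAUSAL_LM, bias="none",
                             lora_alpha=512, lora_dropout=0))

# Only the injected LoRA matrices should now require grad.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
assert trainable and all("lora" in n for n in trainable)
print(f"{len(trainable)} LoRA tensors will be trained")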
