Skip to content

Commit

Permalink
add the selector for the other models in app_onnx.py, close SkyTNT#19
Browse files Browse the repository at this point in the history
  • Loading branch information
SkyTNT committed Oct 6, 2024
1 parent 78c4dab commit 45ac53e
Showing 1 changed file with 78 additions and 6 deletions.
84 changes: 78 additions & 6 deletions app_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os.path
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from sys import exit

import gradio as gr
Expand Down Expand Up @@ -143,9 +144,32 @@ def send_msgs(msgs):
return json.dumps(msgs)


def run(tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig, key_sig, mid,
midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
gen_events, temp, top_p, top_k, allow_cc):
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
global current_model, model_base, model_token, tokenizer
if current_model != model_name:
gr.Info("Loading model...")
model_info = models_info[model_name]
model_config = model_info[0]
model_base_path, model_base_url = model_info[1]
model_token_path, model_token_url = model_info[2]
try:
download_if_not_exit(model_base_url, model_base_path)
download_if_not_exit(model_token_url, model_token_path)
except Exception as e:
print(e)
raise gr.Error("Failed to download files.")
try:
model_base = rt.InferenceSession(model_base_path, providers=providers)
model_token = rt.InferenceSession(model_token_path, providers=providers)
tokenizer = get_tokenizer(model_config)
current_model = model_name
gr.Info("Model loaded")
except Exception as e:
print(e)
raise gr.Error("Failed to load models, maybe you need to delete them and re-download it.")

bpm = int(bpm)
if time_sig == "auto":
time_sig = None
Expand Down Expand Up @@ -326,6 +350,7 @@ def download_if_not_exit(url, output_file):
if os.path.exists(output_file):
return
try:
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
download(url, output_file)
except Exception as e:
print(f"Failed to download {output_file} from {url}")
Expand Down Expand Up @@ -381,8 +406,10 @@ def get_tokenizer(config_name):
parser.add_argument("--max-gen", type=int, default=4096, help="max")
parser.add_argument("--soundfont-path", type=str, default="soundfont.sf2", help="soundfont")
parser.add_argument("--model-config", type=str, default="tv2o-medium", help="model config name")
parser.add_argument("--model-base-path", type=str, default="model_base.onnx", help="model path")
parser.add_argument("--model-token-path", type=str, default="model_token.onnx", help="model path")
parser.add_argument("--model-base-path", type=str,
default="save_models/default/model_base.onnx", help="model path")
parser.add_argument("--model-token-path", type=str,
default="save_models/default/model_token.onnx", help="model path")
parser.add_argument("--soundfont-url", type=str,
default="https://huggingface.co/skytnt/midi-model/resolve/main/soundfont.sf2",
help="download soundfont to soundfont-path if file not exist")
Expand All @@ -394,6 +421,49 @@ def get_tokenizer(config_name):
help="download model-token to model-token-path if file not exist")
opt = parser.parse_args()
OUTPUT_BATCH_SIZE = opt.batch
models_info = {
"generic pretrain model (tv2o-medium) by skytnt (default)": [
opt.model_config,
[opt.model_base_path, opt.model_base_url],
[opt.model_token_path, opt.model_token_url]
],
"generic pretrain model (tv2o-medium) by skytnt with jpop lora": [
"tv2o-medium",
["save_models/tv2om_skytnt_jpop_lora/model_base.onnx",
"https://huggingface.co/skytnt/midi-model-tv2om-jpop-lora/resolve/main/onnx/model_base.onnx"],
["save_models/tv2om_skytnt_jpop_lora/model_token.onnx",
"https://huggingface.co/skytnt/midi-model-tv2om-jpop-lora/resolve/main/onnx/model_token.onnx"]
],
"generic pretrain model (tv2o-medium) by skytnt with touhou lora": [
"tv2o-medium",
["save_models/tv2om_skytnt_touhou_lora/model_base.onnx",
"https://huggingface.co/skytnt/midi-model-tv2om-touhou-lora/resolve/main/onnx/model_base.onnx"],
["save_models/tv2om_skytnt_touhou_lora/model_token.onnx",
"https://huggingface.co/skytnt/midi-model-tv2om-touhou-lora/resolve/main/onnx/model_token.onnx"]
],
"generic pretrain model (tv2o-large) by asigalov61": [
"tv2o-large",
["save_models/tv2ol_asigalov61/model_base.onnx",
"https://huggingface.co/asigalov61/Music-Llama/resolve/main/onnx/model_base.onnx"],
["save_models/tv2ol_asigalov61/model_token.onnx",
"https://huggingface.co/asigalov61/Music-Llama/resolve/main/onnx/model_token.onnx"]
],
"generic pretrain model (tv2o-medium) by asigalov61": [
"tv2o-medium",
["save_models/tv2om_asigalov61/model_base.onnx",
"https://huggingface.co/asigalov61/Music-Llama-Medium/resolve/main/onnx/model_base.onnx"],
["save_models/tv2om_asigalov61/model_token.onnx",
"https://huggingface.co/asigalov61/Music-Llama-Medium/resolve/main/onnx/model_token.onnx"]
],
"generic pretrain model (tv1-medium) by skytnt": [
"tv1-medium",
["save_models/tv1m_skytnt/model_base.onnx",
"https://huggingface.co/skytnt/midi-model/resolve/main/onnx/model_base.onnx"],
["save_models/tv1m_skytnt/model_token.onnx",
"https://huggingface.co/skytnt/midi-model/resolve/main/onnx/model_token.onnx"]
]
}
current_model = list(models_info.keys())[0]
try:
download_if_not_exit(opt.soundfont_url, opt.soundfont_path)
download_if_not_exit(opt.model_base_url, opt.model_base_path)
Expand Down Expand Up @@ -435,6 +505,8 @@ def get_tokenizer(config_name):
return [];
}
""")
input_model = gr.Dropdown(label="select model", choices=list(models_info.keys()),
type="value", value=list(models_info.keys())[0])
tab_select = gr.State(value=0)
with gr.Tabs():
with gr.TabItem("custom prompt") as tab1:
Expand Down Expand Up @@ -519,7 +591,7 @@ def get_tokenizer(config_name):
output_midi = gr.File(label="output midi", file_types=[".mid"])
midi_outputs.append(output_midi)
audio_outputs.append(output_audio)
run_event = run_btn.click(run, [tab_select, output_midi_seq, output_continuation_state,
run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
input_continuation_select, input_instruments, input_drum_kit, input_bpm,
input_time_sig, input_key_sig, input_midi, input_midi_events,
input_reduce_cc_st, input_remap_track_channel, input_add_default_instr,
Expand Down

0 comments on commit 45ac53e

Please sign in to comment.