Skip to content

Commit

Permalink
update gradio and midi visvalizer
Browse files Browse the repository at this point in the history
  • Loading branch information
SkyTNT committed Jun 9, 2024
1 parent 499912b commit 000f1f7
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 41 deletions.
27 changes: 13 additions & 14 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
i = 0
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
patches = {}
if instruments is None:
instruments = []
for instr in instruments:
patches[i] = patch2number[instr]
i = (i + 1) if i != 8 else 10
Expand Down Expand Up @@ -154,7 +156,7 @@ def load_model(path):

def get_model_path():
model_paths = sorted(glob.glob("**/*.ckpt", recursive=True))
return model_path_input.update(choices=model_paths)
return gr.Dropdown(choices=model_paths)


def load_javascript(dir="javascript"):
Expand All @@ -174,20 +176,17 @@ def template_response(*args, **kwargs):

gr.routes.templates.TemplateResponse = template_response

# JSMsgReceiver
HTML_postprocess_ori = gr.HTML.postprocess

class JSMsgReceiver(gr.HTML):

def __init__(self, **kwargs):
super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
def JSMsgReceiver_postprocess(self, y):
if self.elem_id == "msg_receiver" and y:
y = f"<p>{json.dumps(y)}</p>"
return HTML_postprocess_ori(self, y)

def postprocess(self, y):
if y:
y = f"<p>{json.dumps(y)}</p>"
return super().postprocess(y)

def get_block_name(self) -> str:
return "html"

gr.HTML.postprocess = JSMsgReceiver_postprocess

number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
40: "Blush", 48: "Orchestra"}
Expand All @@ -206,7 +205,7 @@ def get_block_name(self) -> str:
load_javascript()
app = gr.Blocks()
with app:
js_msg = JSMsgReceiver()
js_msg = gr.HTML(elem_id="msg_receiver", visible=False)
with gr.Accordion(label="Model option", open=False):
load_model_path_btn = gr.Button("Get Models")
model_path_input = gr.Dropdown(label="model")
Expand All @@ -216,7 +215,7 @@ def get_block_name(self) -> str:
load_model_btn.click(
load_model, model_path_input, model_msg
)
tab_select = gr.Variable(value=0)
tab_select = gr.State(value=0)
with gr.Tabs():
with gr.TabItem("instrument prompt") as tab1:
input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
Expand Down Expand Up @@ -252,7 +251,7 @@ def get_block_name(self) -> str:
example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
run_btn = gr.Button("generate", variant="primary")
stop_btn = gr.Button("stop and output")
output_midi_seq = gr.Variable()
output_midi_seq = gr.State()
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
output_midi = gr.File(label="output midi", file_types=[".mid"])
Expand Down
25 changes: 13 additions & 12 deletions app_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
i = 0
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
patches = {}
if instruments is None:
instruments = []
for instr in instruments:
patches[i] = patch2number[instr]
i = (i + 1) if i != 8 else 10
Expand Down Expand Up @@ -208,18 +210,17 @@ def template_response(*args, **kwargs):
gr.routes.templates.TemplateResponse = template_response


class JSMsgReceiver(gr.HTML):
# JSMsgReceiver
HTML_postprocess_ori = gr.HTML.postprocess

def __init__(self, **kwargs):
super().__init__(elem_id="msg_receiver", visible=False, **kwargs)

def postprocess(self, y):
if y:
y = f"<p>{json.dumps(y)}</p>"
return super().postprocess(y)
def JSMsgReceiver_postprocess(self, y):
if self.elem_id == "msg_receiver" and y:
y = f"<p>{json.dumps(y)}</p>"
return HTML_postprocess_ori(self, y)

def get_block_name(self) -> str:
return "html"

gr.HTML.postprocess = JSMsgReceiver_postprocess


number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
Expand Down Expand Up @@ -276,8 +277,8 @@ def get_block_name(self) -> str:
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
" for faster running"
)
js_msg = JSMsgReceiver()
tab_select = gr.Variable(value=0)
js_msg = gr.HTML(elem_id="msg_receiver", visible=False)
tab_select = gr.State(value=0)
with gr.Tabs():
with gr.TabItem("instrument prompt") as tab1:
input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
Expand Down Expand Up @@ -313,7 +314,7 @@ def get_block_name(self) -> str:
example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
run_btn = gr.Button("generate", variant="primary")
stop_btn = gr.Button("stop and output")
output_midi_seq = gr.Variable()
output_midi_seq = gr.State()
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
output_midi = gr.File(label="output midi", file_types=[".mid"])
Expand Down
95 changes: 81 additions & 14 deletions javascript/app.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,46 @@
/**
* 自动绕过 shadowRoot 的 querySelector
* @param {string} selector - 要查询的 CSS 选择器
* @returns {Element|null} - 匹配的元素或 null 如果未找到
*/
function deepQuerySelector(selector) {
/**
* 在指定的根元素或文档对象下深度查询元素
* @param {Element|Document} root - 要开始搜索的根元素或文档对象
* @param {string} selector - 要查询的 CSS 选择器
* @returns {Element|null} - 匹配的元素或 null 如果未找到
*/
function deepSearch(root, selector) {
// 在当前根元素下查找
let element = root.querySelector(selector);
if (element) {
return element;
}

// 如果未找到,递归检查 shadow DOM
const shadowHosts = root.querySelectorAll('*');

for (let i = 0; i < shadowHosts.length; i++) {
const host = shadowHosts[i];

// 检查当前元素是否有 shadowRoot
if (host.shadowRoot) {
element = deepSearch(host.shadowRoot, selector);
if (element) {
return element;
}
}
}
// 未找到元素
return null;
}

return deepSearch(this, selector);
}

Element.prototype.deepQuerySelector = deepQuerySelector;
Document.prototype.deepQuerySelector = deepQuerySelector;

function gradioApp() {
const elems = document.getElementsByTagName('gradio-app')
const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
Expand Down Expand Up @@ -98,6 +141,7 @@ class MidiVisualizer extends HTMLElement{
this.timePreBeat = 16
this.svgWidth = 0;
this.t1 = 0;
this.totalTimeMs = 0
this.playTime = 0
this.playTimeMs = 0
this.colorMap = new Map();
Expand Down Expand Up @@ -137,6 +181,7 @@ class MidiVisualizer extends HTMLElement{
this.t1 = 0
this.colorMap.clear()
this.setPlayTime(0);
this.totalTimeMs = 0;
this.playTimeMs = 0
this.svgWidth = 0
this.svg.innerHTML = ''
Expand Down Expand Up @@ -215,6 +260,9 @@ class MidiVisualizer extends HTMLElement{
tempo = (60 / midiEvent[3]) * 10 ** 3
this.midiTimes.push({ms:ms, t: t, tempo: tempo})
}
if(midiEvent[0]==="note"){
this.totalTimeMs = ms + (midiEvent[3]/ this.timePreBeat)*tempo
}
lastT = t
})
}
Expand Down Expand Up @@ -277,16 +325,10 @@ class MidiVisualizer extends HTMLElement{

play(){
this.playing = true;
this.timer = setInterval(() => {
this.setPlayTimeMs(this.playTimeMs + 10)
}, 10);
}

pause(){
if(!!this.timer)
clearInterval(this.timer)
this.removeActiveNotes(this.activeNotes)
this.timer = null;
this.playing = false;
}

Expand All @@ -299,17 +341,34 @@ class MidiVisualizer extends HTMLElement{
audio.addEventListener("pause", (event)=>{
this.pause()
})
audio.addEventListener("timeupdate", (event)=>{
this.setPlayTimeMs(event.target.currentTime*10**3)
})
}

bindWaveformCursor(cursor){
let self = this;
const callback = function(mutationsList, observer) {
for(let mutation of mutationsList) {
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
let progress = parseFloat(mutation.target.style.left.slice(0,-1))*0.01;
if(!isNaN(progress)){
self.setPlayTimeMs(progress*self.totalTimeMs);
}
}
}
};
const observer = new MutationObserver(callback);
observer.observe(cursor, {
attributes: true,
attributeFilter: ['style']
});
}
}

customElements.define('midi-visualizer', MidiVisualizer);

(()=>{
let midi_visualizer_container_inited = null
let midi_audio_inited = null;
let midi_audio_audio_inited = null;
let midi_audio_cursor_inited = null;
let midi_visualizer = document.createElement('midi-visualizer')
onUiUpdate((m)=>{
let app = gradioApp()
Expand All @@ -318,10 +377,18 @@ customElements.define('midi-visualizer', MidiVisualizer);
midi_visualizer_container.appendChild(midi_visualizer)
midi_visualizer_container_inited = midi_visualizer_container;
}
let midi_audio = app.querySelector("#midi_audio > audio");
if(!!midi_audio && midi_audio_inited!==midi_audio){
midi_visualizer.bindAudioPlayer(midi_audio)
midi_audio_inited = midi_audio
let midi_audio = app.querySelector("#midi_audio");
if (!!midi_audio){
let midi_audio_cursor = midi_audio.deepQuerySelector(".cursor");
if(!!midi_audio_cursor && midi_audio_cursor_inited!==midi_audio_cursor){
midi_visualizer.bindWaveformCursor(midi_audio_cursor)
midi_audio_cursor_inited = midi_audio_cursor
}
let midi_audio_audio = midi_audio.deepQuerySelector("audio");
if(!!midi_audio_audio && midi_audio_audio_inited!==midi_audio_audio){
midi_visualizer.bindAudioPlayer(midi_audio_audio)
midi_audio_audio_inited = midi_audio_audio
}
}
})

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ torch
transformers>=4.32.1
optimum>=1.12.0
pytorch_lightning
gradio==3.41.2
gradio==4.36.0
pyfluidsynth

0 comments on commit 000f1f7

Please sign in to comment.