From cc58be254cb3957ee9c1ec91545c72364cd176d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:48:43 +0900 Subject: [PATCH] fix(webui): crash on non-ffmpeg env. (#466) --- examples/cmd/run.py | 12 +++--------- examples/web/funcs.py | 6 +++--- examples/web/webui.py | 1 - tools/audio/__init__.py | 3 +-- tools/audio/mp3.py | 20 ++++++++++++++++++++ 5 files changed, 27 insertions(+), 15 deletions(-) create mode 100644 tools/audio/mp3.py diff --git a/examples/cmd/run.py b/examples/cmd/run.py index 9579ac507..bbb715c60 100644 --- a/examples/cmd/run.py +++ b/examples/cmd/run.py @@ -12,23 +12,17 @@ import ChatTTS -from tools.audio import unsafe_float_to_int16, wav2 +from tools.audio import wav_arr_to_mp3_view from tools.logger import get_logger logger = get_logger("Command") def save_mp3_file(wav, index): - buf = BytesIO() - with wave.open(buf, "wb") as wf: - wf.setnchannels(1) # Mono channel - wf.setsampwidth(2) # Sample width in bytes - wf.setframerate(24000) # Sample rate in Hz - wf.writeframes(unsafe_float_to_int16(wav)) - buf.seek(0, 0) + data = wav_arr_to_mp3_view(wav) mp3_filename = f"output_audio_{index}.mp3" with open(mp3_filename, "wb") as f: - wav2(buf, f, "mp3") + f.write(data) logger.info(f"Audio saved to {mp3_filename}") diff --git a/examples/web/funcs.py b/examples/web/funcs.py index 4e2c2152d..f6de2e131 100644 --- a/examples/web/funcs.py +++ b/examples/web/funcs.py @@ -6,7 +6,7 @@ import gradio as gr import numpy as np -from tools.audio import unsafe_float_to_int16 +from tools.audio import wav_arr_to_mp3_view from tools.logger import get_logger logger = get_logger(" WebUI ") @@ -146,10 +146,10 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream): for gen in wav: audio = gen[0] if audio is not None and len(audio) > 0: - yield 24000, unsafe_float_to_int16(audio[0]) + yield wav_arr_to_mp3_view(audio[0]).tobytes() del audio else: - yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten()) + yield wav_arr_to_mp3_view(np.array(wav[0]).flatten()).tobytes() def interrupt_generate(): diff --git a/examples/web/webui.py b/examples/web/webui.py index 7d2023653..341e3b1a6 100644 --- a/examples/web/webui.py +++ b/examples/web/webui.py @@ -107,7 +107,6 @@ def make_audio(autoplay, stream): streaming=stream, interactive=False, show_label=True, - format="mp3", ) generate_button.click(fn=set_buttons_before_generate, inputs=[generate_button, interrupt_button], outputs=[generate_button, interrupt_button]).then( refine_text, diff --git a/tools/audio/__init__.py b/tools/audio/__init__.py index 14566107f..6d36eae3a 100644 --- a/tools/audio/__init__.py +++ b/tools/audio/__init__.py @@ -1,2 +1 @@ -from .np import unsafe_float_to_int16 -from .av import wav2 +from .mp3 import wav_arr_to_mp3_view diff --git a/tools/audio/mp3.py b/tools/audio/mp3.py new file mode 100644 index 000000000..c6acaf346 --- /dev/null +++ b/tools/audio/mp3.py @@ -0,0 +1,20 @@ +import wave +from io import BytesIO + +import numpy as np + +from .np import unsafe_float_to_int16 +from .av import wav2 + +def wav_arr_to_mp3_view(wav: np.ndarray): + buf = BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) # Mono channel + wf.setsampwidth(2) # Sample width in bytes + wf.setframerate(24000) # Sample rate in Hz + wf.writeframes(unsafe_float_to_int16(wav)) + buf.seek(0, 0) + buf2 = BytesIO() + wav2(buf, buf2, "mp3") + buf.seek(0, 0) + return buf2.getbuffer()