Skip to content

Commit

Permalink
chores (RVC-Boss#1528)
Browse files Browse the repository at this point in the history
* chores

* ...

* Add files via upload

* ...

* remove gradio warnings

* Update inference_webui.py

Fix inference_cli issue
  • Loading branch information
XXXXRT666 authored Aug 23, 2024
1 parent 2a9512a commit 7dac47c
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 6 deletions.
10 changes: 8 additions & 2 deletions GPT_SoVITS/TTS_infer_pack/TTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ def __init__(self, configs: Union[dict, str]=None):


def _load_configs(self, configs_path: str)->dict:
if os.path.exists(configs_path):
...
else:
print(i18n("路径不存在,使用默认配置"))
self.save_configs(configs_path)
with open(configs_path, 'r') as f:
configs = yaml.load(f, Loader=yaml.FullLoader)

Expand Down Expand Up @@ -748,7 +753,8 @@ def run(self, inputs:dict):
phones, bert_features, norm_text = \
self.text_preprocessor.segment_and_extract_feature_for_text(
prompt_text,
prompt_lang)
prompt_lang,
self.configs.version)
self.prompt_cache["phones"] = phones
self.prompt_cache["bert_features"] = bert_features
self.prompt_cache["norm_text"] = norm_text
Expand All @@ -760,7 +766,7 @@ def run(self, inputs:dict):
t1 = ttime()
data:list = None
if not return_fragment:
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
if len(data) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
Expand Down
6 changes: 3 additions & 3 deletions GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def __init__(self, bert_model:AutoModelForMaskedLM,
self.tokenizer = tokenizer
self.device = device

def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v1")->List[Dict]:
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
print(i18n("############ 切分文本 ############"))
text = self.replace_consecutive_punctuation(text) # 变量命名应该是写错了
text = self.replace_consecutive_punctuation(text)
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
print(i18n("############ 提取文本Bert特征 ############"))
Expand Down Expand Up @@ -204,7 +204,7 @@ def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T

def clean_text_inf(self, text:str, language:str, version:str="v1"):
def clean_text_inf(self, text:str, language:str, version:str="v2"):
phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
Expand Down
1 change: 1 addition & 0 deletions GPT_SoVITS/configs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.yaml
8 changes: 7 additions & 1 deletion GPT_SoVITS/inference_webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
import pdb
import torch

try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...

version=os.environ.get("version","v2")
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
Expand Down Expand Up @@ -392,7 +397,8 @@ def merge_short_text_in_array(texts, threshold):
##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
# cache_tokens={}#暂未实现清理机制
cache= {}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=123):
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free
=False,speed=1,if_freeze=False,inp_refs=None):
global cache
if ref_wav_path:pass
else:gr.Warning(i18n('请上传参考音频'))
Expand Down
5 changes: 5 additions & 0 deletions GPT_SoVITS/inference_webui_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
import pdb
import torch

try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...


infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
Expand Down
5 changes: 5 additions & 0 deletions tools/subfix_webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import os
import uuid

try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...

import librosa
import gradio as gr
import numpy as np
Expand Down
5 changes: 5 additions & 0 deletions tools/uvr5/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
from vr import AudioPre, AudioPreDeEcho
from bsroformer import BsRoformer_Loader

try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...

weight_uvr5_root = "tools/uvr5/uvr5_weights"
uvr5_names = []
for name in os.listdir(weight_uvr5_root):
Expand Down

0 comments on commit 7dac47c

Please sign in to comment.