Skip to content

Commit

Permalink
优化tts_config代码逻辑 (RVC-Boss#1538)
Browse files Browse the repository at this point in the history
* 优化tts_config

* fix

* 优化报错提示

* 优化报错提示
  • Loading branch information
ChasonJiang authored Aug 28, 2024
1 parent 7dac47c commit f35f6e9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
11 changes: 8 additions & 3 deletions GPT_SoVITS/TTS_infer_pack/TTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ def update_configs(self):
"cnhuhbert_base_path": self.cnhuhbert_base_path,
}
return self.config

def update_version(self, version:str)->None:
self.version = version
self.languages = self.v2_languages if self.version=="v2" else self.v1_languages

def __str__(self):
self.configs = self.update_configs()
Expand Down Expand Up @@ -300,13 +304,14 @@ def init_bert_weights(self, base_path: str):
def init_vits_weights(self, weights_path: str):
print(f"Loading VITS weights from {weights_path}")
self.configs.vits_weights_path = weights_path
self.configs.save_configs()
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.configs.version = "v1"
self.configs.update_version("v1")
else:
self.configs.version = "v2"
self.configs.update_version("v2")
self.configs.save_configs()

hps["model"]["version"] = self.configs.version
self.configs.filter_length = hps["data"]["filter_length"]
self.configs.segment_size = hps["train"]["segment_size"]
Expand Down
6 changes: 3 additions & 3 deletions api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,13 @@ def check_params(req:dict):
if (text_lang in [None, ""]) :
return JSONResponse(status_code=400, content={"message": "text_lang is required"})
elif text_lang.lower() not in tts_config.languages:
return JSONResponse(status_code=400, content={"message": "text_lang is not supported"})
return JSONResponse(status_code=400, content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"})
if (prompt_lang in [None, ""]) :
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
elif prompt_lang.lower() not in tts_config.languages:
return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"})
return JSONResponse(status_code=400, content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"})
if media_type not in ["wav", "raw", "ogg", "aac"]:
return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
elif media_type == "ogg" and not streaming_mode:
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})

Expand Down

0 comments on commit f35f6e9

Please sign in to comment.