add enh / ss
simpleoier committed Apr 11, 2023
1 parent e2b06d3 commit 181bcee
Showing 2 changed files with 82 additions and 0 deletions.
80 changes: 80 additions & 0 deletions audio-chatgpt.py
@@ -821,6 +821,86 @@ def inference(self, text, audio_path):
#print(ans)
return ans

class Speech_Enh_SS_SC:
"""Speech Enhancement or Separation in single-channel
Example usage:
enh_model = Speech_Enh_SS("cuda")
enh_wav = enh_model.inference("./test_chime4_audio_M05_440C0213_PED_REAL.wav")
"""
def __init__(self, device="cuda", model_name="lichenda/chime4_fasnet_dprnn_tac"):
self.model_name = model_name
self.device = device
print("Initializing ESPnet Enh to %s" % device)
self._initialize_model()

def _initialize_model(self):
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.enh_inference import SeparateSpeech

d = ModelDownloader()

cfg = d.download_and_unpack(self.model_name)
self.separate_speech = SeparateSpeech(
train_config=cfg["train_config"],
model_file=cfg["model_file"],
            # segment-wise processing for long speech
segment_size=2.4,
hop_size=0.8,
normalize_segment_scale=False,
show_progressbar=True,
ref_channel=None,
normalize_output_wav=True,
device=self.device,
)

def inference(self, speech_path, ref_channel=0):
speech, sr = soundfile.read(speech_path)
        if speech.ndim == 2:
            # soundfile returns (num_samples, num_channels) for multi-channel files
            speech = speech[:, ref_channel]
        assert speech.ndim == 1

        # add the leading batch axis expected by SeparateSpeech
        enh_speech = self.separate_speech(speech[None, ...], fs=sr)
        # the model returns a list of waveforms; unwrap it when there is a single output
        if len(enh_speech) == 1:
            return enh_speech[0]
        return enh_speech

class Speech_Enh_SS_MC:
"""Speech Enhancement or Separation in multi-channel"""
def __init__(self, device="cuda", model_name=None, ref_channel=4):
self.model_name = model_name
self.ref_channel = ref_channel
self.device = device
print("Initializing ESPnet Enh to %s" % device)
self._initialize_model()

def _initialize_model(self):
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.enh_inference import SeparateSpeech

d = ModelDownloader()

cfg = d.download_and_unpack(self.model_name)
self.separate_speech = SeparateSpeech(
train_config=cfg["train_config"],
model_file=cfg["model_file"],
            # segment-wise processing for long speech
segment_size=2.4,
hop_size=0.8,
normalize_segment_scale=False,
show_progressbar=True,
ref_channel=self.ref_channel,
normalize_output_wav=True,
device=self.device,
)

def inference(self, speech_path):
        speech, sr = soundfile.read(speech_path)  # (num_samples, num_channels)

        # SeparateSpeech expects multi-channel input shaped (batch, num_samples, num_channels),
        # which already matches soundfile's layout, so only a batch axis is added
        enh_speech = self.separate_speech(speech[None, ...], fs=sr)
if len(enh_speech) == 1:
return enh_speech[0]
return enh_speech

class ConversationBot:
def __init__(self):
print("Initializing AudioGPT")
2 changes: 2 additions & 0 deletions requirements.txt
@@ -8,6 +8,8 @@ beautifulsoup4==4.10.0
Cython==0.29.24
diffusers
einops==0.3.0
espnet
espnet_model_zoo
g2p-en==2.1.0
google==3.0.0
gradio

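A minimal usage sketch of the two new classes (assuming they are in scope, e.g. inside audio-chatgpt.py, and that the espnet and espnet_model_zoo packages added to requirements.txt are installed; the multi-channel model name and the multi-channel audio path are placeholders, not values shipped in this commit):

    # single-channel enhancement with the default CHiME-4 model
    enh_sc = Speech_Enh_SS_SC(device="cuda")
    enh_wav = enh_sc.inference("./test_chime4_audio_M05_440C0213_PED_REAL.wav")

    # multi-channel enhancement/separation; a multi-channel ESPnet-Enh model must be supplied
    enh_mc = Speech_Enh_SS_MC(device="cuda", model_name="<espnet-enh-multichannel-model>", ref_channel=4)
    enh_wav_mc = enh_mc.inference("./multi_channel_audio.wav")

Each inference call returns a single enhanced waveform, or a list of waveforms when the model separates multiple sources.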