Commit

Merge pull request AIGC-Audio#21 from Rongjiehuang/main
clean some codes
Rongjiehuang authored Apr 12, 2023
2 parents cb62a28 + aab80e0 commit 8975378
Showing 2 changed files with 76 additions and 69 deletions.
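
The cleanup in audio-chatgpt.py follows one pattern throughout: heavyweight model imports move from the module top level into the __init__ of the tool class that uses them, and the free helpers initialize_model, initialize_model_inpaint, and select_best_audio become private methods of the classes that call them. A minimal sketch of the deferred-import idea, modeled on the ImageCaptioning class in the diff below (the Captioner name is illustrative, not from the repo):

    class Captioner:
        def __init__(self, device):
            # Deferred import: transformers loads only when the tool is built,
            # not when the module is first imported, so startup stays fast and
            # a tool with a missing dependency fails individually, not globally.
            from transformers import BlipProcessor, BlipForConditionalGeneration
            self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

    captioner = Captioner("cpu")  # the dependency import happens here

The trade-off is that an ImportError now surfaces when the tool is constructed rather than when the script starts.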
15 changes: 2 additions & 13 deletions README.md
@@ -8,8 +8,6 @@

## Capabilities

-Up-to-date link: https://cdb7b543afd1c8e8.gradio.app
-
Here we list the capabilities of AudioGPT at this time. More supported models and tasks are coming soon. For prompt examples, refer to [asset](assets/README.md).

### Speech
@@ -18,8 +16,8 @@
| Text-to-Speech | [FastSpeech](), [SyntaSpeech](), [VITS]() | Yes (WIP) |
| Style Transfer | [GenerSpeech]() | Yes |
| Speech Recognition | [whisper](), [Conformer]() | Yes |
-| Speech Enhancement | [ConvTasNet]() | WIP |
-| Speech Separation | [TF-GridNet]() | WIP |
+| Speech Enhancement | [ConvTasNet]() | Yes (WIP) |
+| Speech Separation | [TF-GridNet]() | Yes (WIP) |
| Speech Translation | [Multi-decoder]() | WIP |
| Mono-to-Binaural | [NeuralWarp]() | Yes |

@@ -46,15 +44,6 @@
|:-------------------------:|:-------------------------------:|:----------:|
| Talking Head Synthesis | [GeneFace]() | Yes (WIP) |

-## Internal Version Updates
-4.6 Support Sound Extraction/Detection\
-4.3 Support huggingface demo space\
-4.1 Support Audio inpainting and clean codes\
-3.27 Support Style Transfer/Talking head Synthesis\
-3.23 Support Text-to-Sing\
-3.21 Support Image-to-Audio\
-3.19 Support Speech Recognition\
-3.17 Support Text-to-Audio

## Todo
- [x] clean text to sing/speech code
130 changes: 74 additions & 56 deletions audio-chatgpt.py
@@ -11,9 +11,7 @@
import gradio as gr
import matplotlib
import librosa
-from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
-from diffusers import StableDiffusionPipeline
from langchain.agents.initialize import initialize_agent
from langchain.agents.tools import Tool
from langchain.chains.conversation.memory import ConversationBufferMemory
@@ -24,32 +22,18 @@
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
-from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration
from einops import repeat
from ldm.util import instantiate_from_config
from ldm.data.extract_mel_spectrogram import TRANSFORMS_16000
from vocoder.bigvgan.models import VocoderBigVGAN
from ldm.models.diffusion.ddim import DDIMSampler
-from wav_evaluation.models.CLAPWrapper import CLAPWrapper
-from audio_to_text.inference_waveform import AudioCapModel
import whisper
-from inference.svs.ds_e2e import DiffSingerE2EInfer
-from inference.tts.GenerSpeech import GenerSpeechInfer
-from inference.tts.PortaSpeech import TTSInference
from utils.hparams import set_hparams
from utils.hparams import hparams as hp
import scipy.io.wavfile as wavfile
import librosa
from audio_infer.utils import config as detection_config
from audio_infer.pytorch.models import PVT
-from src.models import BinauralNetwork
-from sound_extraction.model.LASSNet import LASSNet
-from sound_extraction.utils.stft import STFT
-from sound_extraction.utils.wav_io import load_wav, save_wav
-from target_sound_detection.src import models as tsd_models
-from target_sound_detection.src.models import event_labels
-from target_sound_detection.src.utils import median_filter, decode_with_timestamps
-from espnet2.bin.svs_inference import SingingGenerate
import clip
import numpy as np
AUDIO_CHATGPT_PREFIX = """AudioGPT
@@ -107,41 +91,6 @@ def cut_dialogue_history(history_memory, keep_last_n_words = 500):
    return '\n' + '\n'.join(paragraphs)


-
-def initialize_model(config, ckpt, device):
-    config = OmegaConf.load(config)
-    model = instantiate_from_config(config.model)
-    model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
-
-    model = model.to(device)
-    model.cond_stage_model.to(model.device)
-    model.cond_stage_model.device = model.device
-    sampler = DDIMSampler(model)
-    return sampler
-
-def initialize_model_inpaint(config, ckpt):
-    config = OmegaConf.load(config)
-    model = instantiate_from_config(config.model)
-    model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.to(device)
-    print(model.device,device,model.cond_stage_model.device)
-    sampler = DDIMSampler(model)
-    return sampler
-
-def select_best_audio(prompt,wav_list):
-    clap_model = CLAPWrapper('useful_ckpts/CLAP/CLAP_weights_2022.pth','useful_ckpts/CLAP/config.yml',use_cuda=torch.cuda.is_available())
-    text_embeddings = clap_model.get_text_embeddings([prompt])
-    score_list = []
-    for data in wav_list:
-        sr,wav = data
-        audio_embeddings = clap_model.get_audio_embeddings([(torch.FloatTensor(wav),sr)], resample=True)
-        score = clap_model.compute_similarity(audio_embeddings, text_embeddings,use_logit_scale=False).squeeze().cpu().numpy()
-        score_list.append(score)
-    max_index = np.array(score_list).argmax()
-    print(score_list,max_index)
-    return wav_list[max_index]
-
def merge_audio(audio_path_1, audio_path_2):
    merged_signal = []
    sr_1, signal_1 = wavfile.read(audio_path_1)
@@ -156,6 +105,9 @@ def merge_audio(audio_path_1, audio_path_2):

class T2I:
    def __init__(self, device):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from diffusers import StableDiffusionPipeline
+        from transformers import pipeline
        print("Initializing T2I to %s" % device)
        self.device = device
        self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
@@ -175,6 +127,7 @@ def inference(self, text):

class ImageCaptioning:
    def __init__(self, device):
+        from transformers import BlipProcessor, BlipForConditionalGeneration
        print("Initializing ImageCaptioning to %s" % device)
        self.device = device
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -190,9 +143,20 @@ class T2A:
    def __init__(self, device):
        print("Initializing Make-An-Audio to %s" % device)
        self.device = device
-        self.sampler = initialize_model('text_to_audio/Make_An_Audio/configs/text_to_audio/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio/useful_ckpts/ta40multi_epoch=000085.ckpt', device=device)
+        self.sampler = self._initialize_model('text_to_audio/Make_An_Audio/configs/text_to_audio/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio/useful_ckpts/ta40multi_epoch=000085.ckpt', device=device)
        self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio/vocoder/logs/bigv16k53w',device=device)

+    def _initialize_model(self, config, ckpt, device):
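+        # Load the latent-diffusion config, restore checkpoint weights
+        # (strict=False tolerates key mismatches), and wrap the model in a
+        # DDIM sampler for generation.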
+        config = OmegaConf.load(config)
+        model = instantiate_from_config(config.model)
+        model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
+
+        model = model.to(device)
+        model.cond_stage_model.to(model.device)
+        model.cond_stage_model.device = model.device
+        sampler = DDIMSampler(model)
+        return sampler
+
    def txt2audio(self, text, seed = 55, scale = 1.5, ddim_steps = 100, n_samples = 3, W = 624, H = 80):
        SAMPLE_RATE = 16000
        prng = np.random.RandomState(seed)
@@ -217,9 +181,25 @@ def txt2audio(self, text, seed = 55, scale = 1.5, ddim_steps = 100, n_samples =
        for idx,spec in enumerate(x_samples_ddim):
            wav = self.vocoder.vocode(spec)
            wav_list.append((SAMPLE_RATE,wav))
-        best_wav = select_best_audio(text, wav_list)
+        best_wav = self.select_best_audio(text, wav_list)
        return best_wav

+    def select_best_audio(self, prompt, wav_list):
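+        # Rank the candidate waveforms by CLAP text-audio similarity and
+        # return the (sample_rate, wav) pair that best matches the prompt.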
+        from wav_evaluation.models.CLAPWrapper import CLAPWrapper
+        clap_model = CLAPWrapper('useful_ckpts/CLAP/CLAP_weights_2022.pth', 'useful_ckpts/CLAP/config.yml',
+                                 use_cuda=torch.cuda.is_available())
+        text_embeddings = clap_model.get_text_embeddings([prompt])
+        score_list = []
+        for data in wav_list:
+            sr, wav = data
+            audio_embeddings = clap_model.get_audio_embeddings([(torch.FloatTensor(wav), sr)], resample=True)
+            score = clap_model.compute_similarity(audio_embeddings, text_embeddings,
+                                                  use_logit_scale=False).squeeze().cpu().numpy()
+            score_list.append(score)
+        max_index = np.array(score_list).argmax()
+        print(score_list, max_index)
+        return wav_list[max_index]

    def inference(self, text, seed = 55, scale = 1.5, ddim_steps = 100, n_samples = 3, W = 624, H = 80):
        melbins,mel_len = 80,624
        with torch.no_grad():
@@ -237,8 +217,20 @@ class I2A:
    def __init__(self, device):
        print("Initializing Make-An-Audio-Image to %s" % device)
        self.device = device
-        self.sampler = initialize_model('text_to_audio/Make_An_Audio_img/configs/img_to_audio/img2audio_args.yaml', 'text_to_audio/Make_An_Audio_img/useful_ckpts/ta54_epoch=000216.ckpt', device=device)
+        self.sampler = self._initialize_model('text_to_audio/Make_An_Audio_img/configs/img_to_audio/img2audio_args.yaml', 'text_to_audio/Make_An_Audio_img/useful_ckpts/ta54_epoch=000216.ckpt', device=device)
        self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio_img/vocoder/logs/bigv16k53w',device=device)

+    def _initialize_model(self, config, ckpt, device):
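+        # Same LDM loading recipe as T2A._initialize_model above, kept local
+        # to the class so each tool remains self-contained.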
+        config = OmegaConf.load(config)
+        model = instantiate_from_config(config.model)
+        model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
+
+        model = model.to(device)
+        model.cond_stage_model.to(model.device)
+        model.cond_stage_model.device = model.device
+        sampler = DDIMSampler(model)
+        return sampler
+
    def img2audio(self, image, seed = 55, scale = 3, ddim_steps = 100, W = 624, H = 80):
        SAMPLE_RATE = 16000
        n_samples = 1 # only support 1 sample
@@ -284,6 +276,7 @@ def inference(self, image, seed = 55, scale = 3, ddim_steps = 100, W = 624, H =

class TTS:
    def __init__(self, device=None):
+        from inference.tts.PortaSpeech import TTSInference
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Initializing PortaSpeech to %s" % device)
@@ -306,6 +299,7 @@ def inference(self, text):

class T2S:
    def __init__(self, device= None):
+        from inference.svs.ds_e2e import DiffSingerE2EInfer
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Initializing DiffSinger to %s" % device)
Expand Down Expand Up @@ -348,6 +342,7 @@ def inference(self, inputs):

class t2s_VISinger:
def __init__(self, device=None):
from espnet2.bin.svs_inference import SingingGenerate
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Initializing VISingere to %s" % device)
@@ -389,6 +384,7 @@ def inference(self, inputs):

class TTS_OOD:
    def __init__(self, device):
+        from inference.tts.GenerSpeech import GenerSpeechInfer
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Initializing GenerSpeech to %s" % device)
@@ -425,9 +421,20 @@ class Inpaint:
    def __init__(self, device):
        print("Initializing Make-An-Audio-inpaint to %s" % device)
        self.device = device
-        self.sampler = initialize_model_inpaint('text_to_audio/Make_An_Audio_inpaint/configs/inpaint/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio_inpaint/useful_ckpts/inpaint7_epoch00047.ckpt')
+        self.sampler = self._initialize_model_inpaint('text_to_audio/Make_An_Audio_inpaint/configs/inpaint/txt2audio_args.yaml', 'text_to_audio/Make_An_Audio_inpaint/useful_ckpts/inpaint7_epoch00047.ckpt')
        self.vocoder = VocoderBigVGAN('./vocoder/logs/bigv16k53w',device=device)
        self.cmap_transform = matplotlib.cm.viridis

+    def _initialize_model_inpaint(self, config, ckpt):
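+        # Same loading recipe as _initialize_model in T2A/I2A, except the
+        # device is chosen here rather than passed in by the caller.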
+        config = OmegaConf.load(config)
+        model = instantiate_from_config(config.model)
+        model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        model = model.to(device)
+        print(model.device, device, model.cond_stage_model.device)
+        sampler = DDIMSampler(model)
+        return sampler
+
    def make_batch_sd(self, mel, mask, num_samples=1):

        mel = torch.from_numpy(mel)[None,None,...].to(dtype=torch.float32)
@@ -572,6 +579,7 @@ def translate_english(self, audio_path):

class A2T:
    def __init__(self, device):
+        from audio_to_text.inference_waveform import AudioCapModel
        print("Initializing Audio-To-Text Model to %s" % device)
        self.device = device
        self.model = AudioCapModel("audio_to_text/audiocaps_cntrstv_cnn14rnn_trm")
@@ -668,17 +676,21 @@ def inference(self, audio_path):

class SoundExtraction:
    def __init__(self, device):
+        from sound_extraction.model.LASSNet import LASSNet
+        from sound_extraction.utils.stft import STFT
+        import torch.nn as nn
        self.device = device
        self.model_file = 'sound_extraction/useful_ckpts/LASSNet.pt'
        self.stft = STFT()
-        import torch.nn as nn
+
        self.model = nn.DataParallel(LASSNet(device)).to(device)
        checkpoint = torch.load(self.model_file)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

    def inference(self, inputs):
        #key = ['ref_audio', 'text']
+        from sound_extraction.utils.wav_io import load_wav, save_wav
        val = inputs.split(",")
        audio_path = val[0] # audio_path, text
        text = val[1]
@@ -702,6 +714,7 @@ def inference(self, inputs):

class Binaural:
    def __init__(self, device):
+        from src.models import BinauralNetwork
        self.device = device
        self.model_file = 'mono2binaural/useful_ckpts/m2b/binaural_network.net'
        self.position_file = ['mono2binaural/useful_ckpts/m2b/tx_positions.txt',
@@ -763,6 +776,9 @@ def inference(self, audio_path):

class TargetSoundDetection:
    def __init__(self, device):
+        from target_sound_detection.src import models as tsd_models
+        from target_sound_detection.src.models import event_labels
+
        self.device = device
        self.MEL_ARGS = {
            'n_mels': 64,
@@ -817,6 +833,8 @@ def cal_similarity(self, target, retrievals):
        return ans.index(max(ans))

    def inference(self, text, audio_path):
+        from target_sound_detection.src.utils import median_filter, decode_with_timestamps
+
        target_emb = self.build_clip(text) # torch type
        idx = self.cal_similarity(target_emb, self.re_embeds)
        target_event = self.id_to_event[idx]
