Skip to content

Commit

Permalink
optimized pipeline.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bhargavshirin committed Oct 24, 2023
1 parent 1af820d commit d89993f
Showing 1 changed file with 10 additions and 24 deletions.
34 changes: 10 additions & 24 deletions pipeline.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
import yaml
from typing import Dict, List
from typing import List
import torch
import torch.nn as nn
import numpy as np
import librosa
from scipy.io.wavfile import write
from utils import ignore_warnings; ignore_warnings()
from utils import parse_yaml, load_ss_model
from utils import ignore_warnings, parse_yaml, load_ss_model
from models.clap_encoder import CLAP_Encoder


def build_audiosep(config_yaml, checkpoint_path, device):
    """Build and return an AudioSep source-separation model ready for inference.

    Args:
        config_yaml: Path to the YAML configuration file describing the model.
        checkpoint_path: Path to the trained checkpoint to load weights from.
        device: Torch device (e.g. ``'cuda'`` or ``'cpu'``) the model is moved to.

    Returns:
        The separation model in eval mode, placed on ``device``.
    """
    ignore_warnings()
    configs = parse_yaml(config_yaml)

    # Text/audio query encoder; eval() disables dropout/batch-norm updates.
    query_encoder = CLAP_Encoder().eval()

    # NOTE: the diff carried both the old multi-line call and the new one-liner;
    # keep a single call (the post-commit behavior), formatted for readability.
    model = load_ss_model(
        configs=configs,
        checkpoint_path=checkpoint_path,
        query_encoder=query_encoder,
    ).eval().to(device)

    print(f'Loaded AudioSep model from [{checkpoint_path}]')
    return model


def inference(model, audio_file, text, output_file, device='cuda', use_chunk=False):
print(f'Separate audio from [{audio_file}] with textual query [{text}]')
def separate_audio(model, audio_file, text, output_file, device='cuda', use_chunk=False):
print(f'Separating audio from [{audio_file}] with textual query: [{text}]')
mixture, fs = librosa.load(audio_file, sr=32000, mono=True)
with torch.no_grad():
text = [text]
Expand All @@ -49,8 +42,7 @@ def inference(model, audio_file, text, output_file, device='cuda', use_chunk=Fal
sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()

write(output_file, 32000, np.round(sep_segment * 32767).astype(np.int16))
print(f'Write separated audio to [{output_file}]')

print(f'Separated audio written to [{output_file}]')

if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand All @@ -61,12 +53,6 @@ def inference(model, audio_file, text, output_file, device='cuda', use_chunk=Fal

audio_file = '/mnt/bn/data-xubo/project/AudioShop/YT_audios/Y3VHpLxtd498.wav'
text = 'pigeons are cooing in the background'
output_file='separated_audio.wav'

inference(model, audio_file, text, output_file, device)

output_file = 'separated_audio.wav'





separate_audio(model, audio_file, text, output_file, device)

0 comments on commit d89993f

Please sign in to comment.