Enhanced codebase clarity and adhered to best practices by significantly reducing the reliance on plain strings. Implemented data type indicators in functions handling audio data, incorporated assertions for audio data types, and transformed ClientState into an enum for improved representation. Whisper model sizes are now exclusively utilized as enum elements. Additionally, introduced a Whisper model cache to enable swift transcription startup, eliminating the requirement for model reloading. Fixed a bug where the client's Whisper model size list didn't match the server's supported model sizes list. Fixed a bug where the "No more clients allowed" error didn't appear correctly in the client.
ethanzrd committed Aug 15, 2023
1 parent 33534eb commit c825d5c
Showing 12 changed files with 177 additions and 128 deletions.
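
For orientation, the reworked pieces fit together roughly as follows. This is a minimal sketch, not part of the diff: sid and sio stand for the Socket.IO session id and server, the config values are illustrative, and the await line assumes an async context such as create_new_client.

from clients.utils import get_client_class

config = {"transcriptionMethod": "sequential", "model": "small",
          "language": "english", "beamSize": "1", "transcribeTimeout": "5"}

client_class = get_client_class(config)   # RealTimeClient or SequentialClient
client = client_class(sid, sio, config)   # state == ClientState.NOT_INITIALIZED
client.initialize_client()                # builds the WhisperTranscriber; state == ClientState.INITIALIZED
await client.start_transcribing()         # emits "whisperingStarted" and begins streaming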
21 changes: 9 additions & 12 deletions backend/client_manager.py
@@ -1,8 +1,9 @@
import logging
import asyncio
import threading
from clients.utils import initialize_client
from clients.utils import get_client_class
from utils import cleanup
from config import ClientState


class ClientManager:
@@ -11,8 +12,9 @@ def __init__(self):
self.clients = {}

async def create_new_client(self, sid, sio, config):
self.clients[sid] = "initializing"
new_client = initialize_client(sid, sio, config)
new_client = get_client_class(config)(sid, sio, config)
self.clients[sid] = new_client
new_client.initialize_client()
if self.clients.get(sid):
self.clients[sid] = new_client
await new_client.start_transcribing()
@@ -23,9 +25,6 @@ async def start_stream(self, sid, sio, config):
if not self.clients:
if sid not in self.clients.keys():
threading.Thread(target=asyncio.run, args=(self.create_new_client(sid, sio, config),)).start()
else:
logging.warning("A streaming client tried to initiate another stream")
await sio.emit("clientAlreadyStreaming")
else:
logging.warning("A new client tried to start streaming when there is already a client streaming")
await sio.emit("noMoreClientsAllowed")
@@ -48,12 +47,10 @@ async def end_stream(self, sid):
def disconnect_from_stream(self, sid):
if sid in self.clients.keys():
client = self.clients[sid]
cleanup_needed = False
if client != "initializing":
client.handle_disconnection()
cleanup_needed = client.cleanup_needed
# No error if client is still not an object, it won't get to that point
if client == "initializing" or not client.is_ending_stream():
client.handle_disconnection()
cleanup_needed = client.cleanup_needed
client_state = client.get_state()
if client_state == ClientState.NOT_INITIALIZED or not client_state == ClientState.ENDING_STREAM:
try:
self.clients.pop(sid)
except KeyError:
44 changes: 32 additions & 12 deletions backend/clients/Client.py
@@ -2,40 +2,60 @@
from queue import Queue
from silero_vad import silero_vad
from diart.utils import decode_audio
from utils import get_transcriber_information
from transcription.whisper_transcriber import WhisperTranscriber
from abc import abstractmethod
from config import ClientState


class Client:

def __init__(self, sid, socket, transcriber, transcription_timeout):
def __init__(self, sid, socket, config):
self.sid = sid
self.config = config
self.diarization_pipeline = None
self.transcriber = transcriber
self.transcription_timeout = transcription_timeout
self.transcriber = None
self.transcription_timeout = None
self.socket = socket
self.audio_chunks = Queue()
self.transcription_thread = None
self.disconnected = False
self.ending_stream = False
self.cleanup_needed = False
self.state = ClientState.NOT_INITIALIZED

def initialize_client(self):
whisper_model_size, language_code = get_transcriber_information(self.config)
try:
beam_size = int(self.config.get("beamSize", 1))
except (TypeError, ValueError):  # beamSize may be None or a non-numeric string
logging.warning(f"Invalid beam size {self.config.get('beamSize')}, defaulting to 1")
beam_size = 1
self.transcriber = WhisperTranscriber(model_size=whisper_model_size, language_code=language_code,
beam_size=beam_size)
self.transcription_timeout = int(self.config.get("transcribeTimeout", 5))
self.state = ClientState.INITIALIZED

@abstractmethod
async def start_transcribing(self):
pass
if self.transcriber is None:
raise ValueError("The transcriber must be defined before using this method")

async def stop_transcribing(self):
self.ending_stream = True
self.state = ClientState.ENDING_STREAM
self.transcription_thread.join()
logging.info("Transcription thread closed due to completion (stream ended)")
await self.socket.emit("whisperingStopped")
logging.info("Stream end signaled to client")

def handle_disconnection(self):
logging.info("Starting disconnection process, no longer sending transcriptions to client")
self.disconnected = True
if not self.ending_stream:
if self.state not in [ClientState.ENDING_STREAM, ClientState.NOT_INITIALIZED]:
self.state = ClientState.DISCONNECTED
self.transcription_thread.join()
logging.info("Transcription thread closed due to disconnection")

async def send_transcription(self, transcription):
logging.info(f"Transcription generated: {transcription}")
if not self.disconnected:
if self.state != ClientState.DISCONNECTED:
await self.socket.emit("transcriptionDataAvailable", transcription)
logging.info("Transcription sent")
else:
@@ -47,5 +67,5 @@ def handle_chunk(self, chunk):
self.audio_chunks.put(chunk)
logging.debug("Chunk added")

def is_ending_stream(self):
return self.ending_stream
def get_state(self):
return self.state
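
Taken together, the methods above imply the following lifecycle (a summary sketch, not part of the diff):

# __init__               -> ClientState.NOT_INITIALIZED
# initialize_client()    -> ClientState.INITIALIZED
# stop_transcribing()    -> ClientState.ENDING_STREAM  (graceful end; joins the transcription thread)
# handle_disconnection() -> ClientState.DISCONNECTED   (unless already ENDING_STREAM or NOT_INITIALIZED)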
13 changes: 6 additions & 7 deletions backend/clients/RealTimeClient.py
@@ -1,7 +1,6 @@
import logging

from diart import OnlineSpeakerDiarization
from config import DIARIZATION_PIPELINE_CONFIG
from config import DIARIZATION_PIPELINE_CONFIG, ClientState
import asyncio
import diart.operators as dops
import rx.operators as ops
@@ -13,8 +12,8 @@

class RealTimeClient(Client):

def __init__(self, sid, socket, transcriber, transcription_timeout):
super().__init__(sid, socket, transcriber, transcription_timeout)
def __init__(self, sid, socket, config):
super().__init__(sid, socket, config)
self.pipeline_config = DIARIZATION_PIPELINE_CONFIG
self.diarization_pipeline = OnlineSpeakerDiarization(self.pipeline_config)
self.chunk_receiving_thread = None
@@ -28,7 +27,7 @@ async def start_transcribing(self):
await self.socket.emit("whisperingStarted")
logging.info("Stream start signaled to client")

def receive_chunk(self, chunk):
def receive_chunk(self, chunk: str):
self.source.receive_chunk(chunk)

def complete_stream(self):
@@ -38,7 +37,7 @@ def complete_stream(self):
def receive_chunks(self):
logging.info("New chunks handler started")
while True:
if self.disconnected:
if self.state == ClientState.DISCONNECTED:
logging.info("Client disconnected, ending transcription...")
self.complete_stream()
return
@@ -47,7 +46,7 @@ def receive_chunks(self):
# not a heavy operation but a blocking one during pipeline execution, shouldn't block the main thread thanks to threading
self.source.receive_chunk(current_chunk)
else:
if self.ending_stream:
if self.state == ClientState.ENDING_STREAM:
logging.info("No more chunks, preparing for a final transcription...")
self.complete_stream()
return
34 changes: 18 additions & 16 deletions backend/clients/SequentialClient.py
@@ -4,15 +4,15 @@
from diart.utils import decode_audio
from utils import save_batch_to_wav
import numpy as np
from config import STEP, TEMP_FILE_PATH
from config import STEP, TEMP_FILE_PATH, REQUIRED_AUDIO_TYPE, ClientState
from pyannote.audio import Pipeline
from clients.Client import Client


class SequentialClient(Client):

def __init__(self, sid, socket, transcriber, transcription_timeout):
super().__init__(sid, socket, transcriber, transcription_timeout)
def __init__(self, sid, socket, config):
super().__init__(sid, socket, config)
self.diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")
self.cleanup_needed = True

@@ -22,22 +22,24 @@ async def start_transcribing(self):
await self.socket.emit("whisperingStarted")
logging.info("Stream start signaled to client")

def get_diarization(self, audio):
save_batch_to_wav(audio, TEMP_FILE_PATH)
def get_diarization(self, buffer: np.ndarray):
assert buffer.dtype == REQUIRED_AUDIO_TYPE, f"audio array data type must be {REQUIRED_AUDIO_TYPE}"
save_batch_to_wav(buffer, TEMP_FILE_PATH)
diarization = self.diarization_pipeline(TEMP_FILE_PATH)
return diarization

def transcribe_buffer(self, audio):
diarization = self.get_diarization(audio)
result = self.transcriber.sequential_transcription(audio, diarization)
def transcribe_buffer(self, buffer: np.ndarray):
assert buffer.dtype == REQUIRED_AUDIO_TYPE, f"audio array data type must be {REQUIRED_AUDIO_TYPE}"
diarization = self.get_diarization(buffer)
result = self.transcriber.sequential_transcription(buffer, diarization)
asyncio.run(self.send_transcription(result))

@staticmethod
def convert_buffer_to_audio(buffer):
def convert_buffer_to_float32(buffer: np.ndarray):
return buffer.astype("float32").reshape(-1)

@staticmethod
def modify_buffer(chunk, buffer):
def modify_buffer(chunk: str, buffer: np.ndarray):
decoded_chunk = decode_audio(chunk)
buffer = decoded_chunk if buffer is None else np.concatenate([buffer, decoded_chunk], axis=1)
return buffer
@@ -50,13 +52,13 @@ def stream_sequential_transcription(self):
assert batch_size > 0, "batch size must be above 0"

while True:
if self.disconnected:
if self.state == ClientState.DISCONNECTED:
logging.info("Client disconnected, ending transcription...")
break
if not self.ending_stream:
if not self.state == ClientState.ENDING_STREAM:
if chunk_counter >= batch_size:
buffer_audio = self.convert_buffer_to_audio(buffer)
self.transcribe_buffer(buffer_audio)
buffer_float32 = self.convert_buffer_to_float32(buffer)
self.transcribe_buffer(buffer_float32)
chunk_counter = 0

if not self.audio_chunks.empty():
Expand All @@ -70,6 +72,6 @@ def stream_sequential_transcription(self):
buffer = self.modify_buffer(current_chunk, buffer)
chunk_counter += 1
if chunk_counter > 0:
buffer_audio = self.convert_buffer_to_audio(buffer)
self.transcribe_buffer(buffer_audio)
buffer_float32 = self.convert_buffer_to_float32(buffer)
self.transcribe_buffer(buffer_float32)
break
42 changes: 3 additions & 39 deletions backend/clients/utils.py
@@ -3,10 +3,6 @@
from clients.SequentialClient import SequentialClient
from enum import Enum

from transcription.whisper_transcriber import WhisperTranscriber
from utils import format_whisper_model_name
from config import WhisperModelSize, LANGUAGE_MAPPING


class TranscriptionMethod(Enum):
REAL_TIME = RealTimeClient
@@ -20,40 +16,8 @@ def format_transcription_method_name(transcription_method):
def get_client_class(config):
transcription_method_name = format_transcription_method_name(config.get("transcriptionMethod"))
try:
client_class = getattr(TranscriptionMethod, transcription_method_name).value
client_class = getattr(TranscriptionMethod, transcription_method_name)
except AttributeError:
logging.warning(f"Invalid transcription method {transcription_method_name}, defaulting to sequential.")
client_class = TranscriptionMethod.SEQUENTIAL.value
return client_class


def get_whisper_model_name(config):
# Format the name received from the client to match the enum members
whisper_model_name = format_whisper_model_name(config.get("model", "small"))
try:
# Retrieve the corresponding enum member
whisper_model = getattr(WhisperModelSize, whisper_model_name)
except AttributeError:
logging.warning(f"Invalid model size {whisper_model_name}, defaulting to small")
whisper_model = WhisperModelSize.SMALL
language = config.get("language", "english")
try:
language_code = LANGUAGE_MAPPING[language.lower()]
except KeyError:
logging.warning(f"Language {language} not supported, defaulting to English")
language_code = "en"
return whisper_model, language_code


def initialize_client(sid, socket, config):
client_class = get_client_class(config)
whisper_model, language_code = get_whisper_model_name(config)
try:
beam_size = int(config.get("beamSize", 1))
except TypeError:
logging.warning(f"Invalid beam size {config.get('beamSize')}, defaulting to 1")
beam_size = 1
transcriber = WhisperTranscriber(model_name=whisper_model.value, language_code=language_code, beam_size=beam_size)
transcription_timeout = int(config.get("transcribeTimeout", 5))
new_client = client_class(sid, socket, transcriber, transcription_timeout)
return new_client
client_class = TranscriptionMethod.SEQUENTIAL
return client_class.value
17 changes: 16 additions & 1 deletion backend/config.py
@@ -34,19 +34,34 @@
SPEECH_CONFIDENCE_THRESHOLD = 0.3 # The minimal amount of confidence to determine speech presence in batch (e.g. 0.5 means 50% chance at minimum)

SAMPLE_RATE = 16000
NON_ENGLISH_SPECIFIC_MODELS = ["large", "large-v1", "large-v2"] # Models that don't have an English-only version

TEMP_FILE_PATH = "temp/batch.wav" # Path to the temporary file used for batch transcription in SequentialClient


class ClientState(Enum):
NOT_INITIALIZED = "not_initialized"
INITIALIZED = "initialized"
ENDING_STREAM = "ending_stream"
DISCONNECTED = "disconnected"


class WhisperModelSize(Enum):
TINY = 'tiny'
TINY_ENGLISH = 'tiny.en'
BASE = 'base'
BASE_ENGLISH = 'base.en'
SMALL = 'small'
SMALL_ENGLISH = 'small.en'
MEDIUM = 'medium'
MEDIUM_ENGLISH = 'medium.en'
LARGE_V1 = 'large-v1'
LARGE_V2 = 'large-v2'


NON_ENGLISH_SPECIFIC_MODELS = [WhisperModelSize.LARGE_V1, WhisperModelSize.LARGE_V2] # Models that don't have an English-only version

REQUIRED_AUDIO_TYPE = "float32"

# Language code mapping
LANGUAGE_MAPPING = {
"english": "en",
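
With WhisperModelSize as the single source of truth, the model size list served to clients can be derived directly from the enum. A one-line sketch; the endpoint that actually serves the list is not among the shown hunks:

supported_model_sizes = [size.value for size in WhisperModelSize]
# ['tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1', 'large-v2']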
6 changes: 4 additions & 2 deletions backend/silero_vad.py
@@ -1,5 +1,6 @@
import torch
from config import SPEECH_CONFIDENCE_THRESHOLD
from config import SPEECH_CONFIDENCE_THRESHOLD, REQUIRED_AUDIO_TYPE
import numpy as np


class SileroVAD:
@@ -13,7 +14,8 @@ def __init__(self):
(self.get_speech_timestamps, self.save_audio, self.read_audio, self.VADIterator,
self.collect_chunks) = self.utils

def __call__(self, audio):
def __call__(self, audio: np.ndarray):
assert audio.dtype == REQUIRED_AUDIO_TYPE, f"audio array data type must be {REQUIRED_AUDIO_TYPE}"
confidence = self.model(torch.from_numpy(audio), 16000).item()
return confidence >= SPEECH_CONFIDENCE_THRESHOLD, confidence

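
A minimal usage sketch of the tightened __call__ signature, assuming the hidden __init__ lines load the Silero model as the unpacked utils tuple suggests:

import numpy as np
from silero_vad import SileroVAD

vad = SileroVAD()
audio = np.zeros(16000, dtype="float32")   # one second of silence at SAMPLE_RATE
is_speech, confidence = vad(audio)         # passes the REQUIRED_AUDIO_TYPE assertion
# vad(audio.astype("float64")) would now fail fast with an AssertionError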
25 changes: 25 additions & 0 deletions backend/transcription/model_cache.py
@@ -0,0 +1,25 @@
from config import WhisperModelSize
from faster_whisper import WhisperModel
import logging


class ModelCache:
_downloaded_models = {}

@classmethod
def add_downloaded_model(cls, model_size: WhisperModelSize, model: WhisperModel):
cls._downloaded_models[model_size] = model
logging.info(f"{model_size} added to cache")

@classmethod
def is_model_downloaded(cls, model_size: WhisperModelSize):
return model_size in cls._downloaded_models.keys()

@classmethod
def get_model(cls, model_size: WhisperModelSize):
try:
model = cls._downloaded_models[model_size]
logging.info(f"{model_size} retrieved from cache")
return model
except KeyError:
return None
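
A sketch of how a transcriber could consult this cache; the actual call site inside WhisperTranscriber is not among the shown hunks:

from faster_whisper import WhisperModel

from config import WhisperModelSize
from transcription.model_cache import ModelCache

size = WhisperModelSize.SMALL
model = ModelCache.get_model(size)
if model is None:                          # first request: load once, then cache
    model = WhisperModel(size.value)
    ModelCache.add_downloaded_model(size, model)
# later requests reuse the cached instance, so transcription starts without reloading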
(The remaining 4 changed files are not shown.)
