Skip to content

Commit

Permalink
added --api flag
Browse files Browse the repository at this point in the history
  • Loading branch information
SevaSk committed May 30, 2023
1 parent e0c72ff commit c0b43f3
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 15 deletions.
10 changes: 3 additions & 7 deletions AudioTranscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
MAX_PHRASES = 10

class AudioTranscriber:
def __init__(self, mic_source, speaker_source):
def __init__(self, mic_source, speaker_source, model):
self.transcript_data = {"You": [], "Speaker": []}
self.transcript_changed_event = threading.Event()
self.audio_model = whisper.load_model(os.path.join(os.getcwd(), 'tiny.en.pt'))
self.audio_model = model
self.audio_sources = {
"You": {
"sample_rate": mic_source.SAMPLE_RATE,
Expand Down Expand Up @@ -46,7 +46,7 @@ def transcribe_audio_queue(self, audio_queue):
self.update_last_sample_and_phrase_status(who_spoke, data, time_spoken)
source_info = self.audio_sources[who_spoke]
temp_file = source_info["process_data_func"](source_info["last_sample"])
text = self.get_transcription(temp_file)
text = self.audio_model.get_transcription(temp_file)

if text != '' and text.lower() != 'you':
self.update_transcript(who_spoke, text, time_spoken)
Expand Down Expand Up @@ -81,10 +81,6 @@ def process_speaker_data(self, data):
wf.writeframes(data)
return temp_file

def get_transcription(self, file_path):
result = self.audio_model.transcribe(file_path, fp16=torch.cuda.is_available())
return result['text'].strip()

def update_transcript(self, who_spoke, text, time_spoken):
source_info = self.audio_sources[who_spoke]
transcript = self.transcript_data[who_spoke]
Expand Down
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,15 @@ Run the main script:
python main.py
```

Now, Ecoute will start transcribing your microphone input and speaker output in real-time, and provide a suggested response based on the conversation. It may take a couple of seconds to warm up before the transcription becomes real-time.
For a better and faster version, use:

```
python main.py --api
```

Upon initiation, Ecoute will begin transcribing your microphone input and speaker output in real-time, generating a suggested response based on the conversation. Please note that it might take a few seconds for the system to warm up before the transcription becomes real-time.

The `--api` flag significantly improves transcription speed and accuracy, and it is expected to become the default option in future releases. Keep in mind, however, that the Whisper API consumes OpenAI credits, unlike the free local model. Despite the added cost, the considerable gains in speed and accuracy may make it a worthwhile trade-off for your use case.

### ⚠️ Limitations

Expand Down
34 changes: 34 additions & 0 deletions TranscriberModels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import openai
import whisper
import os
import torch

def get_model(use_api):
    """Factory for a transcriber backend.

    Returns an APIWhisperTranscriber when *use_api* is truthy (OpenAI
    Whisper API), otherwise a local WhisperTranscriber.
    """
    return APIWhisperTranscriber() if use_api else WhisperTranscriber()

class WhisperTranscriber:
    """Transcribes WAV files locally with the Whisper 'tiny.en' model."""

    def __init__(self):
        # Model weights are expected at ./tiny.en.pt relative to the
        # current working directory.
        self.audio_model = whisper.load_model(os.path.join(os.getcwd(), 'tiny.en.pt'))
        print(f"[INFO] Whisper using GPU: {torch.cuda.is_available()}")

    def get_transcription(self, wav_file_path):
        """Return the stripped transcription text for *wav_file_path*.

        Returns '' when transcription fails; the caller treats an empty
        string as "nothing transcribed". (The original fell through to
        `result['text']` after an exception, raising NameError on the
        unbound `result` and masking the real error.)
        """
        try:
            result = self.audio_model.transcribe(wav_file_path, fp16=torch.cuda.is_available())
        except Exception as e:
            print(e)
            return ''
        return result['text'].strip()

class APIWhisperTranscriber:
    """Transcribes WAV files via the OpenAI Whisper API ("whisper-1")."""

    def get_transcription(self, wav_file_path):
        """Upload the audio file to the Whisper API and return the stripped text.

        Returns '' when the API call fails; the caller treats an empty
        string as "nothing transcribed". (The original fell through to
        `result['text']` after an exception, raising NameError on the
        unbound `result`, and also leaked the open file handle.)
        """
        # The API infers the audio format from the file extension, so make
        # sure the temp file carries a .wav suffix before uploading.
        new_file_path = wav_file_path + '.wav'
        os.rename(wav_file_path, new_file_path)
        try:
            # Context manager closes the handle even if the API call raises.
            with open(new_file_path, "rb") as audio_file:
                result = openai.Audio.translate("whisper-1", audio_file)
        except Exception as e:
            print(e)
            return ''
        return result['text'].strip()
16 changes: 9 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import queue
import time
import torch
import sys
import TranscriberModels

def write_in_textbox(textbox, text):
textbox.delete("0.0", "end")
Expand Down Expand Up @@ -76,15 +78,15 @@ def main():
speaker_audio_recorder = AudioRecorder.DefaultSpeakerRecorder()
speaker_audio_recorder.record_into_queue(audio_queue)

global_transcriber = AudioTranscriber(user_audio_recorder.source, speaker_audio_recorder.source)
transcribe = threading.Thread(target=global_transcriber.transcribe_audio_queue, args=(audio_queue,))
model = TranscriberModels.get_model('--api' in sys.argv)

transcriber = AudioTranscriber(user_audio_recorder.source, speaker_audio_recorder.source, model)
transcribe = threading.Thread(target=transcriber.transcribe_audio_queue, args=(audio_queue,))
transcribe.daemon = True
transcribe.start()

print(f"[INFO] Whisper using GPU: " + str(torch.cuda.is_available()))

responder = GPTResponder()
respond = threading.Thread(target=responder.respond_to_transcriber, args=(global_transcriber,))
respond = threading.Thread(target=responder.respond_to_transcriber, args=(transcriber,))
respond.daemon = True
respond.start()

Expand All @@ -98,7 +100,7 @@ def main():
root.grid_columnconfigure(1, weight=1)

# Add the clear transcript button to the UI
clear_transcript_button = ctk.CTkButton(root, text="Clear Transcript", command=lambda: clear_context(global_transcriber, audio_queue, ))
clear_transcript_button = ctk.CTkButton(root, text="Clear Transcript", command=lambda: clear_context(transcriber, audio_queue, ))
clear_transcript_button.grid(row=1, column=0, padx=10, pady=3, sticky="nsew")

freeze_state = [False] # Using list to be able to change its content inside inner functions
Expand All @@ -110,7 +112,7 @@ def freeze_unfreeze():

update_interval_slider_label.configure(text=f"Update interval: {update_interval_slider.get()} seconds")

update_transcript_UI(global_transcriber, transcript_textbox)
update_transcript_UI(transcriber, transcript_textbox)
update_response_UI(responder, response_textbox, update_interval_slider_label, update_interval_slider, freeze_state)

root.mainloop()
Expand Down

0 comments on commit c0b43f3

Please sign in to comment.