forked from SevaSk/ecoute
Commit 5b86c1e (initial commit, 0 parents)
Showing 9 changed files with 315 additions and 0 deletions.
.gitignore
@@ -0,0 +1,3 @@
__pycache__/
*.wav
keys.py
@@ -0,0 +1,82 @@
import pyaudio
import queue


def get_device_list():
    devices = []
    p = pyaudio.PyAudio()
    info = p.get_host_api_info_by_index(0)
    numdevices = info.get('deviceCount')
    for i in range(0, numdevices):
        if (p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
            devices.append(p.get_device_info_by_host_api_device_index(0, i).get('name'))
        if (p.get_device_info_by_host_api_device_index(0, i).get('maxOutputChannels')) > 0:
            devices.append(p.get_device_info_by_host_api_device_index(0, i).get('name'))
    return devices


class AudioStream(object):
    """Opens a recording stream as a generator yielding the audio chunks."""

    def __init__(self, rate, input_device_index):
        self._rate = rate
        self._chunk = int(rate / 10)  # 100ms for 16000Hz
        self.input_device_index = input_device_index
        # Create a thread-safe buffer of audio data
        self._buff = queue.Queue()
        self.closed = True

    def __enter__(self):
        self._audio_interface = pyaudio.PyAudio()
        self._audio_stream = self._audio_interface.open(
            format=pyaudio.paInt16,
            # The API currently only supports 1-channel (mono) audio
            # https://goo.gl/z757pE
            channels=1,
            rate=self._rate,
            input=True,
            frames_per_buffer=self._chunk,
            # Run the audio stream asynchronously to fill the buffer object.
            # This is necessary so that the input device's buffer doesn't
            # overflow while the calling thread makes network requests, etc.
            stream_callback=self._fill_buffer,
            input_device_index=self.input_device_index,
        )

        self.closed = False

        return self

    def __exit__(self, type, value, traceback):
        self._audio_stream.stop_stream()
        self._audio_stream.close()
        self.closed = True
        # Signal the generator to terminate so that the client's
        # streaming_recognize method will not block the process termination.
        self._buff.put(None)
        self._audio_interface.terminate()

    def _fill_buffer(self, in_data, frame_count, time_info, status_flags):
        """Continuously collect data from the audio stream, into the buffer."""
        self._buff.put(in_data)
        return None, pyaudio.paContinue

    def generator(self):
        while not self.closed:
            # Use a blocking get() to ensure there's at least one chunk of
            # data, and stop iteration if the chunk is None, indicating the
            # end of the audio stream.
            chunk = self._buff.get()
            if chunk is None:
                return
            data = [chunk]

            # Now consume whatever other data's still buffered.
            while True:
                try:
                    chunk = self._buff.get(block=False)
                    if chunk is None:
                        return
                    data.append(chunk)
                except queue.Empty:
                    break

            yield b"".join(data)
Microphone.py
@@ -0,0 +1,4 @@
class Microphone:
    def __init__(self, id: str, loop_back: bool):
        self.id = id
        self.loop_back = loop_back
audio_transcriber.py
@@ -0,0 +1,69 @@
import numpy as np
import soundcard as sc
import threading
import time
import queue
import whisper
import torch
import argparse
import wave
import os
from Microphone import Microphone

TRANSCRIPT_LIMIT = 10
RECORDING_TIME = 5


class AudioTranscriber:
    def __init__(self, lang: str, microphone: Microphone):
        self.audio_np_array_queue = queue.Queue()
        self.status = 'Running'
        self.transcript_data = []
        self.microphone = microphone
        self.lang = lang
        self.lock = threading.Lock()
        self.start_time = time.time()  # Record the start time
        parser = argparse.ArgumentParser()
        parser.add_argument("--model", default="tiny", help="Model to use",
                            choices=["tiny", "base", "small", "medium", "large"])
        parser.add_argument("--non_english", action='store_true',
                            help="Don't use the English model.")
        parser.add_argument("--energy_threshold", default=1000,
                            help="Energy level for mic to detect.", type=int)
        parser.add_argument("--record_timeout", default=2,
                            help="How real-time the recording is, in seconds.", type=float)
        parser.add_argument("--phrase_timeout", default=3,
                            help="How much empty space between recordings before we "
                                 "consider it a new line in the transcription.", type=float)
        args = parser.parse_args()
        # Load / Download model
        model = args.model
        if args.model != "large" and not args.non_english:
            model = model + ".en"
        # Expects the Whisper checkpoint (e.g. tiny.en.pt) in the current working directory.
        self.audio_model = whisper.load_model(os.path.join(os.getcwd(), model + '.pt'))

    def get_transcript(self):
        return self.transcript_data

    def record_into_queue(self):
        SAMPLE_RATE = 16000
        with sc.get_microphone(id=self.microphone.id, include_loopback=self.microphone.loop_back).recorder(samplerate=SAMPLE_RATE) as mic:
            while True:
                data = mic.record(numframes=SAMPLE_RATE * RECORDING_TIME)  # data is a frames x channels numpy array.
                self.audio_np_array_queue.put(data)

    def transcribe_from_queue(self):
        with self.lock:
            while True:
                audio_data = self.audio_np_array_queue.get()
                with wave.open(f'temp_{self.microphone.id}.wav', 'wb') as wav_file:
                    wav_file.setnchannels(audio_data.shape[1])
                    wav_file.setsampwidth(2)
                    wav_file.setframerate(16000)
                    audio_data = (audio_data * (2**15 - 1)).astype(np.int16)
                    wav_file.writeframes(audio_data.tobytes())
                result = self.audio_model.transcribe(f'temp_{self.microphone.id}.wav', fp16=torch.cuda.is_available())
                text = result['text'].strip()
                if text != '' and text.lower() != 'you':  # Whisper often emits "you" on silent/empty input
                    timestamp = int(time.time())
                    self.transcript_data.append({'utterance': text, 'timestamp': timestamp})
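record_into_queue and transcribe_from_queue are blocking loops meant to run on separate threads, which is what main.py below does for both the mic and the speaker. A minimal single-transcriber sketch, assuming the Whisper checkpoint expected by __init__ is present in the working directory:

import threading
import time
import soundcard as sc
from Microphone import Microphone
from audio_transcriber import AudioTranscriber

mic = Microphone(str(sc.default_microphone().name), loop_back=False)
transcriber = AudioTranscriber(lang='en-US', microphone=mic)
threading.Thread(target=transcriber.record_into_queue, daemon=True).start()
threading.Thread(target=transcriber.transcribe_from_queue, daemon=True).start()
time.sleep(15)  # let a few 5-second recordings land in the queue
print(transcriber.get_transcript())  # list of {'utterance': ..., 'timestamp': ...} dicts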
gpt_responder.py
@@ -0,0 +1,27 @@
import openai
from keys import OPENAI_API_KEY
from prompts import create_prompt, INITIAL_RESPONSE

openai.api_key = OPENAI_API_KEY


class GPTResponder:
    def __init__(self):
        self.last_transcript = ""
        self.last_response = INITIAL_RESPONSE

    def generate_response_from_transcript(self, transcript):
        if transcript == self.last_transcript:
            return self.last_response
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo-0301",
            messages=[{"role": "system", "content": create_prompt(transcript)}],
            temperature=0.0
        )
        full_response = response.choices[0].message.content
        try:
            # The prompt asks the model to wrap its reply in square brackets; extract it.
            conversational_response = full_response.split('[')[1].split(']')[0]
        except IndexError:
            return self.last_response
        self.last_transcript = transcript
        self.last_response = conversational_response
        return conversational_response
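A hypothetical usage sketch for the class above; it needs a valid OPENAI_API_KEY in keys.py, makes a real API call, and the transcript text is made up to match the format produced by main.py:

from gpt_responder import GPTResponder

responder = GPTResponder()
transcript = "You: [How is the weather today?]\n\nSpeaker: [Pretty sunny, actually.]\n\n"
print(responder.generate_response_from_transcript(transcript))
# Calling it again with an identical transcript returns the cached last_response without a new API call.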
main.py
@@ -0,0 +1,115 @@
# pyinstaller --onedir --add-data "C:/Users/mcfar/AppData/Local/Programs/Python/Python310/Lib/site-packages/customtkinter;customtkinter/" --noconfirm --windowed --noconsole main.py

import threading
from audio_transcriber import AudioTranscriber, TRANSCRIPT_LIMIT
from gpt_responder import GPTResponder
import customtkinter as ctk
from Microphone import Microphone
import soundcard as sc


def write_in_textbox(textbox, text):
    textbox.delete("0.0", "end")
    textbox.insert("0.0", text)


# TODO: make fast leetcode :)
def create_transcript_string(transcriber_mic, transcriber_speaker, reverse=True):
    transcript_string = ""

    mic_transcript = transcriber_mic.get_transcript()
    speaker_transcript = transcriber_speaker.get_transcript()
    total_transcript = [('You', data) for data in mic_transcript] + [('Speaker', data) for data in speaker_transcript]
    sorted_transcript = sorted(total_transcript, key=lambda x: x[1]['timestamp'], reverse=reverse)
    for source, line in sorted_transcript[:TRANSCRIPT_LIMIT]:
        transcript_string += source + ": [" + line['utterance'] + ']\n\n'
    return transcript_string


def update_transcript_UI(transcriber_mic, transcriber_thread_mic, transcriber_speaker, transcriber_thread_speaker, textbox, mic_transcription_status_label, speaker_transcription_status_label):
    transcript_string = create_transcript_string(transcriber_mic, transcriber_speaker, reverse=True)
    textbox.delete("0.0", "end")
    textbox.insert("0.0", transcript_string)
    mic_transcription_status_label.configure(text=f"Mic transcription status: {transcriber_mic.status}")
    speaker_transcription_status_label.configure(text=f"Speaker transcription status: {transcriber_speaker.status}")
    textbox.after(200, update_transcript_UI, transcriber_mic, transcriber_thread_mic, transcriber_speaker, transcriber_thread_speaker, textbox, mic_transcription_status_label, speaker_transcription_status_label)


def update_response_UI(transcriber_mic, transcriber_speaker, responder, textbox, update_interval_slider_label, update_interval_slider):
    transcript_string = create_transcript_string(transcriber_mic, transcriber_speaker, reverse=False)
    t = threading.Thread(target=lambda: responder.generate_response_from_transcript(transcript_string))
    t.start()
    textbox.configure(state="normal")
    textbox.delete("0.0", "end")
    textbox.insert("0.0", responder.last_response)
    textbox.configure(state="disabled")
    update_interval = int(update_interval_slider.get())
    update_interval_slider_label.configure(text=f"Update interval: {update_interval} seconds")
    textbox.after(int(update_interval * 1000), update_response_UI, transcriber_mic, transcriber_speaker, responder, textbox, update_interval_slider_label, update_interval_slider)


def clear_transcript_data(transcriber_mic, transcriber_speaker):
    transcriber_mic.transcript_data.clear()
    transcriber_speaker.transcript_data.clear()


if __name__ == "__main__":
    ctk.set_appearance_mode("dark")
    ctk.set_default_color_theme("dark-blue")
    root = ctk.CTk()
    root.title("Ecoute")
    root.configure(bg='#252422')
    root.geometry("1000x600")
    font_size = 20

    transcript_textbox = ctk.CTkTextbox(root, width=300, font=("Arial", font_size), text_color='#FFFCF2', wrap="word")
    transcript_textbox.grid(row=0, column=0, padx=10, pady=20, sticky="nsew")

    response_textbox = ctk.CTkTextbox(root, width=300, font=("Arial", font_size), text_color='#639cdc', wrap="word")
    response_textbox.grid(row=0, column=1, padx=10, pady=20, sticky="nsew")

    # Add the clear transcript button to the UI
    clear_transcript_button = ctk.CTkButton(root, text="Clear Transcript", command=lambda: clear_transcript_data(transcriber_mic, transcriber_speaker))
    clear_transcript_button.grid(row=1, column=0, padx=10, pady=3, sticky="nsew")
    # Empty label, necessary for proper grid spacing
    update_interval_slider_label = ctk.CTkLabel(root, text="", font=("Arial", 12), text_color="#FFFCF2")
    update_interval_slider_label.grid(row=1, column=1, padx=10, pady=3, sticky="nsew")

    # Create the update interval slider
    update_interval_slider = ctk.CTkSlider(root, from_=1, to=10, width=300, height=20, number_of_steps=9)
    update_interval_slider.set(2)
    update_interval_slider.grid(row=3, column=1, padx=10, pady=10, sticky="nsew")
    update_interval_slider_label = ctk.CTkLabel(root, text=f"Update interval: {update_interval_slider.get()} seconds", font=("Arial", 12), text_color="#FFFCF2")
    update_interval_slider_label.grid(row=2, column=1, padx=10, pady=10, sticky="nsew")

    responder = GPTResponder()

    user_microphone = Microphone(str(sc.default_microphone().name), False)
    transcriber_mic = AudioTranscriber(lang='en-US', microphone=user_microphone)
    recorder_thread_mic = threading.Thread(target=transcriber_mic.record_into_queue)
    transcriber_thread_mic = threading.Thread(target=transcriber_mic.transcribe_from_queue)
    recorder_thread_mic.start()
    transcriber_thread_mic.start()

    speaker_microphone = Microphone(str(sc.default_speaker().name), True)
    transcriber_speaker = AudioTranscriber(lang='en-US', microphone=speaker_microphone)
    recorder_thread_speaker = threading.Thread(target=transcriber_speaker.record_into_queue)
    transcriber_thread_speaker = threading.Thread(target=transcriber_speaker.transcribe_from_queue)
    recorder_thread_speaker.start()
    transcriber_thread_speaker.start()

    # Create status labels for both transcribers
    mic_transcription_status_label = ctk.CTkLabel(root, text=f"Mic transcription status: {transcriber_mic.status}", font=("Arial", 12), text_color="#FFFCF2")
    mic_transcription_status_label.grid(row=4, column=0, padx=5, pady=5, sticky="nsew")
    speaker_transcription_status_label = ctk.CTkLabel(root, text=f"Speaker transcription status: {transcriber_speaker.status}", font=("Arial", 12), text_color="#FFFCF2")
    speaker_transcription_status_label.grid(row=4, column=1, padx=5, pady=5, sticky="nsew")

    root.grid_rowconfigure(0, weight=100)
    root.grid_rowconfigure(1, weight=10)
    root.grid_rowconfigure(2, weight=1)
    root.grid_rowconfigure(3, weight=1)
    root.grid_rowconfigure(4, weight=1)
    root.grid_columnconfigure(0, weight=2)
    root.grid_columnconfigure(1, weight=1)

    update_transcript_UI(transcriber_mic, transcriber_thread_mic, transcriber_speaker, transcriber_thread_speaker, transcript_textbox, mic_transcription_status_label, speaker_transcription_status_label)
    update_response_UI(transcriber_mic, transcriber_speaker, responder, response_textbox, update_interval_slider_label, update_interval_slider)

    root.mainloop()

    transcriber_thread_mic.join()
    transcriber_thread_speaker.join()
prompts.py
@@ -0,0 +1,7 @@
INITIAL_RESPONSE = "Welcome to Ecoute 👋"


def create_prompt(transcript):
    return f"""You are a casual pal, genuinely interested in the conversation at hand. A poor transcription of conversation is given below.
{transcript}.
Please respond, in detail, to the conversation. Confidently give a straightforward response to the speaker, even if you don't understand them. Give your response in square brackets. DO NOT ask to repeat, and DO NOT ask for clarification. Just answer the speaker directly."""
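The instruction to answer inside square brackets is what the split-based parsing in gpt_responder.py relies on; a small worked example with a made-up model reply:

full_response = "Sure! [That sounds like a great plan, go for it.]"
conversational_response = full_response.split('[')[1].split(']')[0]
print(conversational_response)  # That sounds like a great plan, go for it.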
requirements.txt
@@ -0,0 +1,8 @@
numpy
soundcard
openai-whisper
torch
wave
pyaudio
openai
customtkinter
Binary file not shown.