From df1c2570861ce0b1b5b93366474d920d2678aae8 Mon Sep 17 00:00:00 2001 From: Minh Date: Sat, 11 May 2024 04:54:26 +0700 Subject: [PATCH] feat(speech): enable local inference instead of calling api --- app/.env.example | 3 ++- app/app.py | 2 +- app/inference.py | 4 ++-- app/local_inference.py | 18 ++++++++++++++++++ 4 files changed, 23 insertions(+), 4 deletions(-) create mode 100644 app/local_inference.py diff --git a/app/.env.example b/app/.env.example index ff318e6..2d1e293 100644 --- a/app/.env.example +++ b/app/.env.example @@ -4,4 +4,5 @@ MQTT_PORT=1883 ADAFRUIT_USER=tuankiet0303 ADAFRUIT_KEY= MONGO_USER=root -MONGO_PASSWORD=example \ No newline at end of file +MONGO_PASSWORD=example +HUGGINGFACE_KEY= \ No newline at end of file diff --git a/app/app.py b/app/app.py index c65926f..49ec01c 100644 --- a/app/app.py +++ b/app/app.py @@ -11,7 +11,7 @@ from flask import Flask, Response, render_template, request from flask_cors import CORS from flask_socketio import SocketIO, emit -from inference import query +from local_inference import query from myMqtt import * from pymongo import MongoClient from werkzeug.utils import secure_filename diff --git a/app/inference.py b/app/inference.py index c40c155..af06981 100644 --- a/app/inference.py +++ b/app/inference.py @@ -1,8 +1,8 @@ import asyncio - +import os from aiohttp import ClientSession -headers = {"Authorization": f"Bearer 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"} +headers = {"Authorization": f"Bearer {os.environ.get("HUGGINGFACE_KEY")}"} API_URL = "https://api-inference.huggingface.co/models/vinai/PhoWhisper-base" diff --git a/app/local_inference.py b/app/local_inference.py new file mode 100644 index 0000000..bb7b674 --- /dev/null +++ b/app/local_inference.py @@ -0,0 +1,18 @@ +import torch +from transformers import pipeline +import asyncio + +device = "cuda:0" if torch.cuda.is_available() else "cpu" + +model_id = "vinai/PhoWhisper-base" + +transcriber = pipeline("automatic-speech-recognition", model="vinai/PhoWhisper-base",device=device) +async def query(filename): + return transcriber(filename) + +async def main(): + output = await query("uploads/audio.flac") + print(output) + +if __name__ == "__main__": + asyncio.run(main())