forked from shibing624/text2vec
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfastapi_server_demo.py
56 lines (47 loc) · 1.46 KB
/
fastapi_server_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: FastAPI demo server that serves text2vec sentence embeddings. Requires: pip install fastapi uvicorn
"""
import argparse
import uvicorn
import sys
import os
from fastapi import FastAPI, Query
from starlette.middleware.cors import CORSMiddleware
import torch
from loguru import logger
sys.path.append('..')
from text2vec import SentenceModel
# Absolute path of the directory that contains this script.
pwd_path = os.path.abspath(os.path.dirname(__file__))
use_cuda = torch.cuda.is_available()
logger.info(f'use_cuda:{use_cuda}')

# Model selection: a fine-tuned model directory or a model name, overridable
# from the command line via --model_name_or_path.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_name_or_path",
    type=str,
    default="shibing624/text2vec-base-chinese",
    help="Model save dir or model name",
)
# Parsing happens at import time, so sys.argv may belong to a different
# launcher (for example the `uvicorn` command line). parse_known_args() keeps
# such foreign options from aborting the import with SystemExit; anything not
# recognized is reported instead of silently dropped.
args, unknown_args = parser.parse_known_args()
if unknown_args:
    logger.warning(f'Ignoring unrecognized command-line arguments: {unknown_args}')
s_model = SentenceModel(args.model_name_or_path)
# Build the FastAPI application and register CORS handling on it.
app = FastAPI()
# NOTE(review): wildcard origins/methods/headers are combined with
# allow_credentials=True here; the FastAPI CORS documentation advises against
# that combination -- confirm whether credentialed requests are really needed.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.get('/')
async def index():
    """Root route: return a short message that points at the docs URL."""
    greeting = "index, docs url: /docs"
    return {"message": greeting}
@app.get('/emb')
async def emb(q: str = Query(..., min_length=1, max_length=512, title='query')):
    """Return the sentence embedding of the query text.

    Args:
        q: Required query-string parameter holding the text to embed;
            declared with a length constraint of 1 to 512 characters.

    Returns:
        ``{'emb': [...]}`` with the embedding converted to a list on success.
        If encoding fails, a JSON response with HTTP status 400 and the body
        ``{'status': False, 'msg': <error text>}``.
    """
    # Imported locally so this handler does not rely on changes to the
    # module-level import block.
    from fastapi.responses import JSONResponse

    try:
        embeddings = s_model.encode(q)
        result_dict = {'emb': embeddings.tolist()}
    except Exception as e:
        logger.error(e)
        # A `(body, 400)` tuple is a Flask idiom: FastAPI would serialize the
        # tuple itself as the response body and still answer with status 200.
        # Build the response explicitly so the client really gets a 400, and
        # send the exception as text so the message survives JSON encoding.
        return JSONResponse(
            status_code=400,
            content={'status': False, 'msg': str(e)},
        )
    else:
        logger.debug(f"Successfully get sentence embeddings, q:{q}")
        return result_dict
if __name__ == '__main__':
    # When executed as a script, start uvicorn on host 0.0.0.0, port 8001.
    listen_host, listen_port = '0.0.0.0', 8001
    uvicorn.run(app=app, host=listen_host, port=listen_port)