-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.py
49 lines (35 loc) · 1.17 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
IDE: PyCharm
Project: complete-sentence-prediction
Author: Robin
Filename: server.py
Date: 18.01.2020
"""
import os
from json import dumps
import torch
import uvicorn
from fastapi import FastAPI
from simpletransformers.classification import ClassificationModel
from starlette.responses import Response
from torch.nn.functional import softmax
from static import LABELS
# load latest model from 'outputs/'
# simpletransformers defaults to use_cuda=True and raises a ValueError when
# CUDA is not available, so only request the GPU on hosts that have one.
model = ClassificationModel(
    'bert',
    'outputs/',
    num_labels=2,
    use_cuda=torch.cuda.is_available(),
)
# initialize web app
app = FastAPI()
@app.get("/api/is_complete")
def read_root(text: str):
    """Classify ``text`` and respond with both labels and their confidences.

    The module-level ``model`` is run on the single input, its raw outputs
    are turned into probabilities with a softmax over dim 1, and the result
    is serialized as a JSON list of two ``{"label", "confidence"}`` objects:
    first the entry for index ``predictions[0]``, then the entry for the
    other of the two indices (0 or 1).
    """
    predictions, raw_outputs = model.predict([text])
    probabilities = softmax(torch.from_numpy(raw_outputs).float(), dim=1)

    # Predicted index first, followed by the remaining index of the pair.
    winner = predictions[0]
    runner_up = 0 if winner != 0 else 1
    payload = [
        {"label": LABELS[index], "confidence": probabilities[0][index].item()}
        for index in (winner, runner_up)
    ]
    return Response(content=dumps(payload), media_type='application/json')
if __name__ == "__main__":
    # os.getenv returns a string whenever PORT is set in the environment, so
    # convert explicitly: uvicorn then receives an integer port both for the
    # default and for an overridden value. A non-numeric PORT raises
    # ValueError here instead of being passed on to the server.
    port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)