forked from superjcd/sentimentclassification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapi.py
39 lines (31 loc) · 1.13 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import argparse
import model
from flask import Flask, request, make_response, jsonify
from model import textcnn, birnn
from dataset import TEXT, LABEL
from utils import transform_data
# Flask application; the /sentiment route is registered on this object.
app = Flask(__name__)
# Return non-ASCII text (the Chinese labels) unescaped in JSON responses.
app.config['JSON_AS_ASCII'] = False
# Flask >= 2.2 exposes this switch on the JSON provider and later releases
# ignore the config key above, so set it there as well when it exists.
_json_provider = getattr(app, 'json', None)
if _json_provider is not None:
    _json_provider.ensure_ascii = False

# Command-line options. They are parsed at import time, so they must be
# supplied when this script is launched, e.g.:
#   python api.py --model-name birnn -lmd models_storage/model_birnn.pt
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model-name',
    default='birnn',
    choices=['textcnn', 'birnn'],
    help='choose one model name for trainng',
)
parser.add_argument(
    '-lmd',
    '--load-model-dir',
    default=None,
    help='path for loadding model, default:None',
)
args = parser.parse_args()
if args.load_model_dir is None:
    # torch.load(None) would fail with an opaque error; report the missing
    # option through argparse (usage message + exit) instead.
    parser.error('-lmd/--load-model-dir is required to load trained weights')

# Look up the model class by name in the `model` module and instantiate it.
net = getattr(model, args.model_name)()
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host;
# load_state_dict then copies the values into the parameters of `net`.
net.load_state_dict(torch.load(args.load_model_dir, map_location='cpu'))
# Inference only: switch off training-time behaviour such as dropout.
net.eval()
@app.route('/sentiment')
def sentiemnt():
    """Classify the sentiment of the ``sentence`` query parameter.

    Returns:
        A JSON response with two keys: ``data`` (the label string, or an
        error message) and ``status_code``. A request that carries no
        ``sentence`` parameter is answered with HTTP 400.
    """
    # NOTE(review): the function name is misspelled, but it is also the Flask
    # endpoint name, so it is kept unchanged for compatibility.
    sentence = request.args.get('sentence')
    if sentence is None:
        # Fail fast with a client error instead of passing None downstream.
        return jsonify({'data': 'missing query parameter: sentence',
                        'status_code': 400}), 400

    record = {'data': sentence}
    # transform_data returns a pair; only the first element is fed to the model.
    data, _ = transform_data(record, TEXT, LABEL)

    # No gradients are needed for prediction, so skip autograd bookkeeping.
    with torch.no_grad():
        prediction = net(data).argmax(dim=1).item()

    # Class index 0 is reported as '积极' (positive), anything else as
    # '消极' (negative).
    result = '积极' if prediction == 0 else '消极'
    return jsonify({'data': result, 'status_code': 200})
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server with the debugger
    # off; no host or port is passed, so Flask's defaults apply.
    app.run(debug=False)