Skip to content

Commit

Permalink
feat: split tagger and locale into different routes
Browse files Browse the repository at this point in the history
  • Loading branch information
bigint committed Jun 20, 2023
1 parent 1175c72 commit 4c64d5c
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 74 deletions.
81 changes: 7 additions & 74 deletions packages/ai/app.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,19 @@
import torch
from flask import Flask, jsonify, request
from scipy.special import expit
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Initialize the tagger model
topic_hf = "models/tagger"
topic_tokenizer = AutoTokenizer.from_pretrained(topic_hf)
topic_model = AutoModelForSequenceClassification.from_pretrained(topic_hf)
topic_class_mapping = topic_model.config.id2label

# Initialize the locale model
locale_hf = "models/locale_detector"
locale_tokenizer = AutoTokenizer.from_pretrained(locale_hf)
locale_model = AutoModelForSequenceClassification.from_pretrained(locale_hf)
locale_class_mapping = locale_model.config.id2label
from flask import Flask
from locale_route import locale_bp
from tagger_route import tagger_bp


app = Flask(__name__)

# Register the blueprints
app.register_blueprint(locale_bp)
app.register_blueprint(tagger_bp)


@app.route("/")
def index():
return "Welcome to Lenster AI ✨"


# Health check
@app.route("/ping", methods=["GET"])
def ping():
    """Liveness probe; always answers with a fixed JSON payload."""
    payload = {"ping": "pong"}
    return jsonify(payload)


def predictTopic(text):
    """Classify *text* with the topic model and return the two top-scoring labels.

    Runs the tokenizer/model pair loaded at module level, converts the raw
    logits to independent probabilities with the sigmoid (``expit``), and
    picks the two labels with the highest scores.
    """
    encoded = topic_tokenizer(text, return_tensors="pt")
    logits = topic_model(**encoded).logits
    probs = expit(logits.detach().numpy())[0]

    # Pair every label with its score, then order best-first.
    labeled = [
        {"topic": topic_class_mapping[idx], "score": float(prob)}
        for idx, prob in enumerate(probs)
    ]
    labeled.sort(key=lambda entry: entry["score"], reverse=True)

    # Keep only the two strongest topics.
    return [entry["topic"] for entry in labeled[:2]]


# Extract topic from the text
@app.route("/tagger", methods=["POST"])
def tagger():
    """POST {"text": ...} -> {"topics": [top-2 labels]} or {"topics": null}.

    Texts of four words or fewer are treated as too short to classify and
    yield a null result instead of a prediction.
    """
    # silent=True returns None (instead of aborting with 400) on a missing
    # or malformed JSON body; the `or {}` / .get() chain then avoids the
    # TypeError/KeyError 500s the previous direct indexing could raise.
    data = request.get_json(silent=True) or {}
    text = data.get("text")
    if not isinstance(text, str):
        text = ""

    if len(text.split()) > 4:
        topic_scores = predictTopic(text)
        return jsonify({"topics": topic_scores})
    return jsonify({"topics": None})


def predictLocale(text):
    """Return the locale label the model considers most likely for *text*."""
    encoded = locale_tokenizer(text, return_tensors="pt")
    logits = locale_model(**encoded).logits
    # argmax over the raw logits picks the same index as max over their
    # softmax, since softmax is strictly monotonic.
    best_index = torch.argmax(logits, dim=-1).item()
    return locale_class_mapping[best_index]


# Extract locale from the text
@app.route("/locale", methods=["POST"])
def locale():
    """POST {"text": ...} -> {"locale": <label>} or {"locale": null}.

    Texts of four words or fewer are treated as too short to classify and
    yield a null result instead of a prediction.
    """
    # silent=True returns None (instead of aborting with 400) on a missing
    # or malformed JSON body; the `or {}` / .get() chain then avoids the
    # TypeError/KeyError 500s the previous direct indexing could raise.
    data = request.get_json(silent=True) or {}
    text = data.get("text")
    if not isinstance(text, str):
        text = ""

    if len(text.split()) > 4:
        locale_scores = predictLocale(text)
        return jsonify({"locale": locale_scores})
    return jsonify({"locale": None})


if __name__ == "__main__":
app.run(debug=False, host="0.0.0.0", port=8000)
33 changes: 33 additions & 0 deletions packages/ai/locale_route.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
from flask import Blueprint, jsonify, request
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Create a Blueprint instance
locale_bp = Blueprint("locale", __name__)

# Initialize the locale model
locale_hf = "models/locale_detector"
locale_tokenizer = AutoTokenizer.from_pretrained(locale_hf)
locale_model = AutoModelForSequenceClassification.from_pretrained(locale_hf)
locale_class_mapping = locale_model.config.id2label


def predictLocale(text):
    """Run the locale model on *text* and return the most probable label."""
    inputs = locale_tokenizer(text, return_tensors="pt")
    logits = locale_model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # The index of the largest probability selects the predicted class.
    best_index = int(torch.argmax(probabilities, dim=-1))
    return locale_class_mapping[best_index]


@locale_bp.route("/locale", methods=["POST"])
def locale():
    """POST {"text": ...} -> {"locale": <label>} or {"locale": null}.

    Texts of four words or fewer are treated as too short to classify and
    yield a null result instead of a prediction.
    """
    # silent=True returns None (instead of aborting with 400) on a missing
    # or malformed JSON body; the `or {}` / .get() chain then avoids the
    # TypeError/KeyError 500s the previous direct indexing could raise.
    data = request.get_json(silent=True) or {}
    text = data.get("text")
    if not isinstance(text, str):
        text = ""

    if len(text.split()) > 4:
        locale_scores = predictLocale(text)
        return jsonify({"locale": locale_scores})
    return jsonify({"locale": None})
41 changes: 41 additions & 0 deletions packages/ai/tagger_route.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from flask import Blueprint, jsonify, request
from scipy.special import expit
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Create a Blueprint instance
tagger_bp = Blueprint("tagger", __name__)

# Initialize the tagger model
topic_hf = "models/tagger"
topic_tokenizer = AutoTokenizer.from_pretrained(topic_hf)
topic_model = AutoModelForSequenceClassification.from_pretrained(topic_hf)
topic_class_mapping = topic_model.config.id2label


def predictTopic(text):
    """Score *text* against every topic label and return the top two.

    The sigmoid (``expit``) turns each logit into an independent
    probability; labels are then ranked by score, best first.
    """
    inputs = topic_tokenizer(text, return_tensors="pt")
    raw_scores = expit(topic_model(**inputs).logits.detach().numpy())[0]

    # Attach each label to its score.
    scored = [
        {"topic": topic_class_mapping[i], "score": float(s)}
        for i, s in enumerate(raw_scores)
    ]
    ranked = sorted(scored, key=lambda item: item["score"], reverse=True)

    # Only the two strongest topics are reported.
    return [item["topic"] for item in ranked[:2]]


@tagger_bp.route("/tagger", methods=["POST"])
def tagger():
    """POST {"text": ...} -> {"topics": [top-2 labels]} or {"topics": null}.

    Texts of four words or fewer are treated as too short to classify and
    yield a null result instead of a prediction.
    """
    # silent=True returns None (instead of aborting with 400) on a missing
    # or malformed JSON body; the `or {}` / .get() chain then avoids the
    # TypeError/KeyError 500s the previous direct indexing could raise.
    data = request.get_json(silent=True) or {}
    text = data.get("text")
    if not isinstance(text, str):
        text = ""

    if len(text.split()) > 4:
        topic_scores = predictTopic(text)
        return jsonify({"topics": topic_scores})
    return jsonify({"topics": None})

0 comments on commit 4c64d5c

Please sign in to comment.