Commit cfeed78
add ngram model and gradio.
shibing624 committed Feb 16, 2022
1 parent a6129de commit cfeed78
Showing 6 changed files with 170 additions and 10 deletions.
41 changes: 41 additions & 0 deletions examples/gradio_demo.py
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: pip install gradio
"""

import gradio as gr
from text2vec import Similarity

# Chinese sentence embedding model (CoSENT)
sim_model = Similarity(model_name_or_path='shibing624/text2vec-base-chinese',
                       similarity_type='cosine', embedding_type='sbert')


def ai_text(sentence1, sentence2):
    score = sim_model.get_score(sentence1, sentence2)
    print("{} \t\t {} \t\t Score: {:.4f}".format(sentence1, sentence2, score))

    return score


if __name__ == '__main__':
    examples = [
        ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡'],
        ['我在北京打篮球', '我是北京人,我喜欢篮球'],
        ['一个女人在看书。', '一个女人在揉面团'],
        ['一个男人在车库里举重。', '一个人在举重。'],
    ]
    input1 = gr.inputs.Textbox(lines=2, placeholder="Enter First Sentence")
    input2 = gr.inputs.Textbox(lines=2, placeholder="Enter Second Sentence")

    output_text = gr.outputs.Textbox()
    gr.Interface(ai_text,
                 inputs=[input1, input2],
                 outputs=[output_text],
                 theme="grass",
                 title="Chinese Text to Vector Model shibing624/text2vec-base-chinese",
                 description="Copy or input Chinese text here. Submit and the machine will calculate the cosine score.",
                 article="Link to <a href='https://github.com/shibing624/text2vec' style='color:blue;' target='_blank'>Github REPO</a>",
                 examples=examples
                 ).launch()
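
For a quick check without launching the web UI, the same scoring call can be driven directly from Python; a minimal sketch reusing the demo's own model setup (the score is a cosine similarity, so near-paraphrases land close to 1.0):

from text2vec import Similarity

sim_model = Similarity(model_name_or_path='shibing624/text2vec-base-chinese',
                       similarity_type='cosine', embedding_type='sbert')
# Paraphrase pair taken from the demo examples; expect a high cosine score.
print(sim_model.get_score('如何更换花呗绑定银行卡', '花呗更改绑定银行卡'))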
36 changes: 36 additions & 0 deletions examples/ngram_demo.py
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
"""

import sys

sys.path.append('..')
from text2vec.ngram import NGram


def compute_emb(model):
    # Embed a list of sentences
    sentences = ['卡',
                 '银行卡',
                 '如何更换花呗绑定银行卡',
                 '花呗更改绑定银行卡',
                 ]
    sentence_embeddings = model.encode(sentences)

    print(type(sentence_embeddings), sentence_embeddings.shape)

    # The result is a list of sentence embeddings as numpy arrays
    for sentence, embedding in zip(sentences, sentence_embeddings):
        print("Sentence:", sentence)
        print("Embedding shape:", embedding.shape)
        print()


if __name__ == "__main__":
    ngram_model = NGram()
    r = ngram_model.encode('兄弟们冲呀')
    print(type(r), r.shape, r)

    compute_emb(ngram_model)
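
Expected output, as a sketch inferred from the encode implementation in text2vec/ngram.py below: scores are computed per character, so '兄弟们冲呀' (5 characters) yields a numpy array of shape (5,), and the four sentences in compute_emb yield embeddings of shapes (1,), (3,), (11,) and (9,), collected in an object-dtype batch array because the lengths differ.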
10 changes: 3 additions & 7 deletions requirements.txt
@@ -1,13 +1,9 @@
 jieba>=0.39
-gensim>=4.0.0
 loguru
-transformers>=4.6.0,<5.0.0
 tokenizers>=0.10.3
 tqdm
 torch>=1.6.0
-scikit-learn
-#sentencepiece
-#huggingface-hub
+transformers>=4.6.0,<5.0.0
 numpy
-#sentence-transformers
-#pyemd
+scikit-learn
+gensim>=4.0.0
1 change: 1 addition & 0 deletions text2vec/__init__.py
@@ -9,3 +9,4 @@
 from text2vec.sbert import SBert, semantic_search, cos_sim
 from text2vec.bm25 import BM25
 from text2vec.similarity import Similarity, SimType, EmbType
+from text2vec.ngram import NGram
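
With this export the model is importable from the package root, equivalently to the submodule import used in examples/ngram_demo.py:

from text2vec import NGram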
86 changes: 86 additions & 0 deletions text2vec/ngram.py
@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
"""
import os
from typing import List, Union
from loguru import logger
import numpy as np
from text2vec.utils.get_file import get_file


class NGram:
    def __init__(self, model_name_or_path=None, cache_folder=os.path.expanduser('~/.pycorrector/datasets/')):
        if model_name_or_path and os.path.exists(model_name_or_path):
            logger.info('Load kenlm language model:{}'.format(model_name_or_path))
            language_model_path = model_name_or_path
        else:
            # Language model, 2.95 GB
            get_file(
                'zh_giga.no_cna_cmn.prune01244.klm',
                'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
                extract=True,
                cache_subdir=cache_folder,
                verbose=1)
            language_model_path = os.path.join(cache_folder, 'zh_giga.no_cna_cmn.prune01244.klm')
        try:
            import kenlm
        except ImportError:
            raise ImportError('Kenlm not installed, use "pip install kenlm".')
        self.lm = kenlm.Model(language_model_path)
        logger.debug('Loaded language model: %s.' % language_model_path)

    def ngram_score(self, sentence: str):
        """
        Get the n-gram language model score.
        :param sentence: str, the input sentence
        :return:
        """
        return self.lm.score(' '.join(sentence), bos=False, eos=False)

    def perplexity(self, sentence: str):
        """
        Get the language model perplexity; the lower the score, the more fluent the sentence.
        :param sentence: str, the input sentence
        :return:
        """
        return self.lm.perplexity(' '.join(sentence))

    def encode(self, sentences: Union[List[str], str]):
        """
        Convert sentences into n-gram feature vectors.
        """
        if self.lm is None:
            raise ValueError('No model for embed sentence')

        input_is_string = False
        if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
            sentences = [sentences]
            input_is_string = True

        all_embeddings = []
        for sentence in sentences:
            ngram_avg_scores = []
            for n in [2, 3]:
                scores = []
                for i in range(len(sentence) - n + 1):
                    word = sentence[i:i + n]
                    score = self.ngram_score(word)
                    scores.append(score)
                if scores:
                    # Pad the score list at both ends so every character
                    # position gets a full sliding window.
                    for _ in range(n - 1):
                        scores.insert(0, scores[0])
                        scores.append(scores[-1])
                    avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]
                else:
                    avg_scores = np.zeros(len(sentence), dtype=float)
                ngram_avg_scores.append(avg_scores)
            # Average the per-character scores across the n-gram orders
            sent_scores = np.average(np.array(ngram_avg_scores), axis=0)
            all_embeddings.append(sent_scores)
        all_embeddings = np.asarray(all_embeddings, dtype=object)
        if input_is_string:
            all_embeddings = all_embeddings[0]
        return all_embeddings
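
The two scoring helpers lend themselves to a quick fluency check; a minimal sketch, assuming the roughly 2.95 GB kenlm model has already been fetched on first use:

from text2vec.ngram import NGram

lm = NGram()
# Lower perplexity means a more fluent sentence, so the natural word
# order should score below its scrambled permutation.
print(lm.perplexity('我喜欢北京'))
print(lm.perplexity('京我喜欢北'))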
6 changes: 3 additions & 3 deletions text2vec/word2vec.py
@@ -17,7 +17,6 @@
 pwd_path = os.path.abspath(os.path.dirname(__file__))
 default_stopwords_file = os.path.join(pwd_path, 'data/stopwords.txt')
 USER_DATA_DIR = os.path.expanduser('~/.text2vec/datasets/')
-os.makedirs(USER_DATA_DIR, exist_ok=True)


 def load_stopwords(file_path):
@@ -73,9 +72,10 @@ def __init__(self, model_name_or_path: str = 'w2v-light-tencent-chinese',
             untar_filename = model_dict.get('untar_filename')
             model_path = os.path.join(cache_folder, untar_filename)
             if not os.path.exists(model_path):
+                os.makedirs(cache_folder, exist_ok=True)
                 get_file(tar_filename, url, extract=True,
-                         cache_dir=USER_DATA_DIR,
-                         cache_subdir=USER_DATA_DIR,
+                         cache_dir=cache_folder,
+                         cache_subdir=cache_folder,
                          verbose=1)
             t0 = time.time()
             w2v = KeyedVectors.load_word2vec_format(model_path, **self.w2v_kwargs)
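The effect of this change is that downloads now honor the caller's cache_folder instead of the hard-coded USER_DATA_DIR, and the directory is created only when a download is actually needed. A sketch, assuming cache_folder is the __init__ parameter the hunk suggests:

from text2vec.word2vec import Word2Vec

# Model files land under the custom folder rather than ~/.text2vec/datasets/.
w2v = Word2Vec('w2v-light-tencent-chinese', cache_folder='/tmp/text2vec_models')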
