forked from shibing624/text2vec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a6129de
commit cfeed78
Showing
6 changed files
with
170 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author:XuMing([email protected]) | ||
@description: pip install gradio | ||
""" | ||
|
||
import gradio as gr | ||
from text2vec import Similarity | ||
|
||
# 中文句向量模型(CoSENT) | ||
sim_model = Similarity(model_name_or_path='shibing624/text2vec-base-chinese', | ||
similarity_type='cosine', embedding_type='sbert') | ||
|
||
|
||
def ai_text(sentence1, sentence2): | ||
score = sim_model.get_score(sentence1, sentence2) | ||
print("{} \t\t {} \t\t Score: {:.4f}".format(sentence1, sentence2, score)) | ||
|
||
return score | ||
|
||
|
||
if __name__ == '__main__': | ||
examples = [ | ||
['如何更换花呗绑定银行卡', '花呗更改绑定银行卡'], | ||
['我在北京打篮球', '我是北京人,我喜欢篮球'], | ||
['一个女人在看书。', '一个女人在揉面团'], | ||
['一个男人在车库里举重。', '一个人在举重。'], | ||
] | ||
input1 = gr.inputs.Textbox(lines=2, placeholder="Enter First Sentence") | ||
input2 = gr.inputs.Textbox(lines=2, placeholder="Enter Second Sentence") | ||
|
||
output_text = gr.outputs.Textbox() | ||
gr.Interface(ai_text, | ||
inputs=[input1, input2], | ||
outputs=[output_text], | ||
theme="grass", | ||
title="Chinese Text to Vector Model shibing624/text2vec-base-chinese", | ||
description="Copy or input Chinese text here. Submit and the machine will calculate the cosine score.", | ||
article="Link to <a href='https://github.com/shibing624/text2vec' style='color:blue;' target='_blank\'>Github REPO</a>", | ||
examples=examples | ||
).launch() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author:XuMing([email protected]) | ||
@description: | ||
""" | ||
|
||
import sys | ||
|
||
sys.path.append('..') | ||
from text2vec.ngram import NGram | ||
|
||
|
||
def compute_emb(model): | ||
# Embed a list of sentences | ||
sentences = ['卡', | ||
'银行卡', | ||
'如何更换花呗绑定银行卡', | ||
'花呗更改绑定银行卡', | ||
] | ||
sentence_embeddings = model.encode(sentences) | ||
|
||
print(type(sentence_embeddings), sentence_embeddings.shape) | ||
|
||
# The result is a list of sentence embeddings as numpy arrays | ||
for sentence, embedding in zip(sentences, sentence_embeddings): | ||
print("Sentence:", sentence) | ||
print("Embedding shape:", embedding.shape) | ||
print() | ||
|
||
|
||
if __name__ == "__main__": | ||
ngram_model = NGram() | ||
r = ngram_model.encode('兄弟们冲呀') | ||
print(type(r), r.shape, r) | ||
|
||
compute_emb(ngram_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,9 @@ | ||
jieba>=0.39 | ||
gensim>=4.0.0 | ||
loguru | ||
transformers>=4.6.0,<5.0.0 | ||
tokenizers>=0.10.3 | ||
tqdm | ||
torch>=1.6.0 | ||
scikit-learn | ||
#sentencepiece | ||
#huggingface-hub | ||
transformers>=4.6.0,<5.0.0 | ||
numpy | ||
#sentence-transformers | ||
#pyemd | ||
scikit-learn | ||
gensim>=4.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author:XuMing([email protected]) | ||
@description: | ||
""" | ||
import os | ||
from typing import List, Union | ||
from loguru import logger | ||
import numpy as np | ||
from text2vec.utils.get_file import get_file | ||
|
||
|
||
class NGram: | ||
def __init__(self, model_name_or_path=None, cache_folder=os.path.expanduser('~/.pycorrector/datasets/')): | ||
if model_name_or_path and os.path.exists(model_name_or_path): | ||
logger.info('Load kenlm language model:{}'.format(model_name_or_path)) | ||
language_model_path = model_name_or_path | ||
else: | ||
# 语言模型 2.95GB | ||
get_file( | ||
'zh_giga.no_cna_cmn.prune01244.klm', | ||
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', | ||
extract=True, | ||
cache_subdir=cache_folder, | ||
verbose=1) | ||
language_model_path = os.path.join(cache_folder, 'zh_giga.no_cna_cmn.prune01244.klm') | ||
try: | ||
import kenlm | ||
except ImportError: | ||
raise ImportError('Kenlm not installed, use "pip install kenlm".') | ||
self.lm = kenlm.Model(language_model_path) | ||
logger.debug('Loaded language model: %s.' % language_model_path) | ||
|
||
def ngram_score(self, sentence: str): | ||
""" | ||
取n元文法得分 | ||
:param sentence: str, 输入的句子 | ||
:return: | ||
""" | ||
return self.lm.score(' '.join(sentence), bos=False, eos=False) | ||
|
||
def perplexity(self, sentence: str): | ||
""" | ||
取语言模型困惑度得分,越小句子越通顺 | ||
:param sentence: str, 输入的句子 | ||
:return: | ||
""" | ||
return self.lm.perplexity(' '.join(sentence)) | ||
|
||
def encode(self, sentences: Union[List[str], str]): | ||
""" | ||
将句子转换成ngram特征向量 | ||
""" | ||
if self.lm is None: | ||
raise ValueError('No model for embed sentence') | ||
|
||
input_is_string = False | ||
if isinstance(sentences, str) or not hasattr(sentences, '__len__'): | ||
sentences = [sentences] | ||
input_is_string = True | ||
|
||
all_embeddings = [] | ||
for sentence in sentences: | ||
ngram_avg_scores = [] | ||
for n in [2, 3]: | ||
scores = [] | ||
for i in range(len(sentence) - n + 1): | ||
word = sentence[i:i + n] | ||
score = self.ngram_score(word) | ||
scores.append(score) | ||
if scores: | ||
# 移动窗口补全得分 | ||
for _ in range(n - 1): | ||
scores.insert(0, scores[0]) | ||
scores.append(scores[-1]) | ||
avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))] | ||
else: | ||
avg_scores = np.zeros(len(sentence), dtype=float) | ||
ngram_avg_scores.append(avg_scores) | ||
# 取拼接后的n-gram平均得分 | ||
sent_scores = np.average(np.array(ngram_avg_scores), axis=0) | ||
all_embeddings.append(sent_scores) | ||
all_embeddings = np.asarray(all_embeddings, dtype=object) | ||
if input_is_string: | ||
all_embeddings = all_embeddings[0] | ||
return all_embeddings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters