From 74e38ba61151ceb3b0f9449a1bb09faac857c218 Mon Sep 17 00:00:00 2001 From: shibing624 Date: Tue, 1 Mar 2022 12:35:08 +0800 Subject: [PATCH] update dataset.. --- CITATION.cff | 1 + README.md | 35 ++++++++++++++++++++++++++++++++++- text2vec/similarity.py | 4 ++-- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 11f4771..0e830c1 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,6 +3,7 @@ message: "If you use this software, please cite it as below." authors: - family-names: "Xu" given-names: "Ming" + orcid: "https://orcid.org/0000-0003-3402-7159" title: "Text2vec: Text to vector toolkit" url: "https://github.com/shibing624/text2vec" data-released: 2022-02-27 diff --git a/README.md b/README.md index 630ce98..a1264a7 100644 --- a/README.md +++ b/README.md @@ -179,9 +179,42 @@ python3 setup.py install ``` ### 数据集 +中文语义匹配数据集已经上传到huggingface datasets [https://huggingface.co/datasets/shibing624/nli_zh](https://huggingface.co/datasets/shibing624/nli_zh) + +数据集使用示例: +```shell +pip3 install datasets +``` + +```python +from datasets import load_dataset + +dataset = load_dataset("shibing624/nli_zh", "STS-B") +print(dataset) +print(dataset['test'][0]) +``` + +output: +```shell +DatasetDict({ + train: Dataset({ + features: ['sentence1', 'sentence2', 'label'], + num_rows: 5231 + }) + validation: Dataset({ + features: ['sentence1', 'sentence2', 'label'], + num_rows: 1458 + }) + test: Dataset({ + features: ['sentence1', 'sentence2', 'label'], + num_rows: 1361 + }) +}) +{'sentence1': '一个女孩在给她的头发做发型。', 'sentence2': '一个女孩在梳头。', 'label': 2} +``` + 常见中文语义匹配数据集,包含[ATEC](https://github.com/IceFlameWorm/NLP_Datasets/tree/master/ATEC)、[BQ](http://icrc.hitsz.edu.cn/info/1037/1162.htm)、[LCQMC](http://icrc.hitsz.edu.cn/Article/show/171.html)、[PAWSX](https://arxiv.org/abs/1908.11828)、[STS-B](https://github.com/pluto-junzeng/CNSD)共5个任务。 可以从数据集对应的链接自行下载,也可以从[百度网盘(提取码:qkt6)](https://pan.baidu.com/s/1d6jSiU1wHQAEMWJi7JJWCQ)下载。 - 其中senteval_cn目录是评测数据集汇总,senteval_cn.zip是senteval目录的打包,两者下其一就好。 # Usage diff --git a/text2vec/similarity.py b/text2vec/similarity.py index 3101601..1361d15 100644 --- a/text2vec/similarity.py +++ b/text2vec/similarity.py @@ -102,7 +102,7 @@ def get_score(self, sentence1: str, sentence2: str) -> float: def get_scores( self, sentences1: List[str], sentences2: List[str], only_aligned: bool = False - ) -> Union[List[Tensor], ndarray, Tensor, None]: + ) -> ndarray: """ Get similarity scores between sentences1 and sentences2 :param sentences1: list, sentence1 list @@ -111,7 +111,7 @@ def get_scores( :return: return: Matrix with res[i][j] = cos_sim(a[i], b[j]) """ if not sentences1 or not sentences2: - return None + return np.array([]) if only_aligned and len(sentences1) != len(sentences2): logger.warning('Sentences size not equal, auto set is_aligned=False') only_aligned = False