From 0c90c9e4775352f093b747783eb5dc8491c94541 Mon Sep 17 00:00:00 2001
From: xuming624 <xuming624@qq.com>
Date: Fri, 27 Dec 2019 18:47:00 +0800
Subject: [PATCH] update sim result.

---
 README.md                   | 10 ++++++----
 examples/similarity_demo.py |  5 +++--
 text2vec/similarity.py      | 18 +++++++++++-------
 3 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index 3caf716..f65cf45 100644
--- a/README.md
+++ b/README.md
@@ -230,16 +230,18 @@ b = '花呗更改绑定银行卡'
 c = '我什么时候开通了花呗'
 
 corpus = [a, b, c]
+print(corpus)
 search_sim = SearchSimilarity(corpus=corpus)
 
-print(search_sim.get_scores(query=a))
-print(search_sim.get_similarities(query=a))
+print(a, 'scores:', search_sim.get_scores(query=a))
+print(a, 'rank similarities:', search_sim.get_similarities(query=a))
 ```
 
 output:
 ```
-[ 0.9527457  -0.07449248 -0.03204909]
-[['如何', '更换', '花', '呗', '绑定', '银行卡'], ['我', '什么', '时候', '开通', '了', '花', '呗'], ['花', '呗', '更改', '绑定', '银行卡']]
+['如何更换花呗绑定银行卡', '花呗更改绑定银行卡', '我什么时候开通了花呗']
+如何更换花呗绑定银行卡 scores: [ 0.9527457  -0.07449248 -0.03204909]
+如何更换花呗绑定银行卡 rank similarities: ['如何更换花呗绑定银行卡', '我什么时候开通了花呗', '花呗更改绑定银行卡']
 ```
 
 
diff --git a/examples/similarity_demo.py b/examples/similarity_demo.py
index 3c3f77e..fdaffbb 100644
--- a/examples/similarity_demo.py
+++ b/examples/similarity_demo.py
@@ -23,7 +23,8 @@
 from text2vec import SearchSimilarity
 
 corpus = [a, b, c]
+print(corpus)
 search_sim = SearchSimilarity(corpus=corpus)
 
-print(search_sim.get_scores(query=a))
-print(search_sim.get_similarities(query=a))
+print(a, 'scores:', search_sim.get_scores(query=a))
+print(a, 'rank similarities:', search_sim.get_similarities(query=a))
diff --git a/text2vec/similarity.py b/text2vec/similarity.py
index 053032b..de6f136 100644
--- a/text2vec/similarity.py
+++ b/text2vec/similarity.py
@@ -4,6 +4,8 @@
 @description: 
 """
 
+import numpy as np
+
 from text2vec.algorithm.distance import cosine_distance
 from text2vec.algorithm.rank_bm25 import BM25Okapi
 from text2vec.utils.logger import get_logger
@@ -87,24 +89,26 @@ def init(self):
             if isinstance(self.corpus, str):
                 self.corpus = [self.corpus]
 
-            self.corpus_seg = [self.tokenizer.tokenize(i) for i in self.corpus]
-            self.bm25_instance = BM25Okapi(corpus=self.corpus_seg)
+            self.corpus_seg = {k: self.tokenizer.tokenize(k) for k in self.corpus}
+            self.bm25_instance = BM25Okapi(corpus=list(self.corpus_seg.values()))
 
     def get_similarities(self, query, n=5):
         """
         Get similarity between `query` and this docs.
         :param query: str
         :param n: int, num_best
-        :return: float scores
+        :return: result, dict, float scores, docs rank
         """
-        self.init()
-        tokens = self.tokenizer.tokenize(query)
-        return self.bm25_instance.get_top_n(tokens, self.corpus_seg, n=n)
+        scores = self.get_scores(query)
+        rank_n = np.argsort(scores)[::-1]
+        if n > 0:
+            rank_n = rank_n[:n]
+        return [self.corpus[i] for i in rank_n]
 
     def get_scores(self, query):
         """
         Get scores between query and docs
-        :param query:
+        :param query: input str
         :return: numpy array, scores for query between docs
         """
         self.init()