diff --git a/tests/emb_w2v_test.py b/tests/emb_w2v_test.py index ad74090..15abf42 100644 --- a/tests/emb_w2v_test.py +++ b/tests/emb_w2v_test.py @@ -6,9 +6,19 @@ from text2vec.embeddings.word_embedding import WordEmbedding if __name__ == '__main__': + import text2vec + + char = ',' + result = text2vec.encode(char) + print(type(result)) + print(char, result) + + b = WordEmbedding() data1 = '你 好 啊'.split(' ') r = b.embed([data1], True) print(r) print(r.shape) + + diff --git a/text2vec/embeddings/word_embedding.py b/text2vec/embeddings/word_embedding.py index 230ea77..1490ab9 100644 --- a/text2vec/embeddings/word_embedding.py +++ b/text2vec/embeddings/word_embedding.py @@ -178,7 +178,10 @@ def embed(self, emb.append(self.w2v[word]) count += 1 tensor_x = np.array(emb).sum(axis=0) # 纵轴相加 - avg_tensor_x = np.divide(tensor_x, count) + if count > 0: + avg_tensor_x = np.divide(tensor_x, count) + else: + avg_tensor_x = 0.0 embeds.append(avg_tensor_x) embeds = np.array(embeds) if debug: