-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtest.py
112 lines (73 loc) · 3.11 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# -*- coding: utf-8 -*-
"""
Created on Thu May 17 10:26:42 2018
@author: shen1994
"""
import gensim
import numpy as np
from encoder2decoder import build_model
from data_process import DataProcess
def data_to_padding_ids(text_list):
    """Convert raw texts into reversed, left-padded encoder id sequences.

    Each text is word-segmented, mapped to vocabulary ids (unknown words
    fall back to the UNK id), truncated to the encoder input length,
    reversed, and left-padded with zeros to a fixed width.

    Args:
        text_list: iterable of raw text strings.

    Returns:
        np.ndarray of shape (len(text_list), enc_input_length) of int ids.
    """
    data_process = DataProcess(use_word2cut=True)
    enc_vocab = data_process.read_vocabulary(data_process.enc_vocab_file)
    max_len = data_process.enc_input_length

    padded_rows = []
    for text in text_list:
        segmented = data_process.text_cut_object.cut([text.strip()])
        tokens = segmented[0].strip().split()
        ids = [enc_vocab.get(token, data_process.__UNK__) for token in tokens]
        ids = ids[:max_len]  # drop anything past the encoder window
        # Left-pad with zeros, then append ids in reverse order
        # (reversing the source sequence is a common seq2seq trick).
        row = [0] * (max_len - len(ids)) + [int(i) for i in reversed(ids)]
        padded_rows.append(np.array(row))
    return np.array(padded_rows)
def calculate_mse(src_vec, des_vec):
    """Return the sum of squared differences between src_vec and a
    z-score-normalized copy of des_vec.

    NOTE(review): despite the name, this is a *sum* of squared errors,
    not a mean; callers only use it for argmin ranking, where the
    constant scale factor is harmless. src_vec is compared as-is —
    presumably it is already normalized upstream; verify against the
    embedding pipeline.
    """
    data_process = DataProcess(use_word2cut=False)
    deviation = np.std(des_vec)
    if deviation < data_process.epsilon:
        # Near-constant vector: avoid dividing by ~0, treat it as all zeros.
        normalized = np.zeros(data_process.dec_embedding_length)
    else:
        normalized = (des_vec - np.mean(des_vec)) / deviation
    return np.sum(np.square(src_vec - normalized))
def predict_text(model, enc_embedding):
    """Decode model output vectors into words via nearest-neighbour lookup
    in the decoder word2vec space.

    For every predicted vector the closest candidate under calculate_mse
    is chosen; candidate index 0 is the zero vector, which stands in for
    the placeholder token __VOCAB__[0].

    Args:
        model: a Keras model supporting predict_on_batch.
        enc_embedding: batch of padded encoder id sequences.

    Returns:
        A list (one entry per batch item) of lists of predicted words.
    """
    data_process = DataProcess(use_word2cut=False)
    dec_vec_model = gensim.models.Word2Vec.load(r'model/decoder_vector.m')
    dec_useful_words = tuple(dec_vec_model.wv.vocab.keys())
    zero_vec = np.zeros(data_process.dec_embedding_length)

    prediction = model.predict_on_batch(enc_embedding)

    prediction_words_list = []
    for sequence in prediction:
        sequence_words = []
        for vec in sequence:
            # Distance to the zero vector first, then to every known word.
            distances = [calculate_mse(vec, zero_vec)]
            distances.extend(
                calculate_mse(vec, dec_vec_model.wv[w]) for w in dec_useful_words
            )
            best = np.argmin(distances)
            if best == 0:
                sequence_words.append(data_process.__VOCAB__[0])
            else:
                sequence_words.append(dec_useful_words[best - 1])
        prediction_words_list.append(sequence_words)
    return prediction_words_list
def load_model(model_path):
    """Build the inference-mode seq2seq model and load trained weights.

    Args:
        model_path: path to a Keras weights file (.h5).

    Returns:
        The model, ready for prediction.
    """
    inference_model = build_model(training=False)
    inference_model.load_weights(model_path)
    return inference_model
def common_prediction(model, text):
    """Run the full pipeline: raw texts -> padded ids -> predicted words.

    Args:
        model: loaded seq2seq model.
        text: list of raw input strings.

    Returns:
        Nested list of predicted words, one inner list per input text.
    """
    return predict_text(model, data_to_padding_ids(text))
def run():
    """Load the trained model and print a prediction for one demo sentence."""
    sample = [u"我真的好喜欢你,你认为呢?"]
    model = load_model("model/seq2seq_model_weights.h5")
    print(common_prediction(model, sample))
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    run()