Skip to content

Commit

Permalink
prepare train 40 epoch with this params.json
Browse files Browse the repository at this point in the history
  • Loading branch information
ximinng committed May 1, 2019
1 parent 322d410 commit aafcf7c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 15 deletions.
8 changes: 4 additions & 4 deletions src/params.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
{
"bidirectional": true,
"use_residual": false,
"use_dropout": false,
"time_major": false,
"use_residual": true,
"use_dropout": true,
"time_major": true,
"cell_type": "lstm",
"depth": 2,
"depth": 4,
"attention_type": "Bahdanau",
"hidden_units": 128,
"optimizer": "adam",
Expand Down
4 changes: 2 additions & 2 deletions src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def test(params):
)

# 读取模型路径
# save_path = './model/s2ss_chatbot_anti.ckpt'
save_path = './model/s2ss_chatbot.ckpt'
save_path = './model/s2ss_chatbot_anti.ckpt'
# save_path = './model/s2ss_chatbot.ckpt'

tf.reset_default_graph()
model_pred = SequenceToSequence(
Expand Down
2 changes: 1 addition & 1 deletion src/train_anti.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test(params):
x_data, y_data = pickle.load(open('chatbot.pkl', 'rb'))
ws = pickle.load(open('ws.pkl', 'rb'))

n_epoch = 2 # 训练轮次
n_epoch = 40 # 训练轮次
batch_size = 128
steps = int(len(x_data) / batch_size) + 1

Expand Down
9 changes: 1 addition & 8 deletions src/web.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
Description : train_anti 测试
Description : web interface
Author : xxm
"""
import sys
Expand All @@ -18,9 +18,6 @@ def test(params, infos):
x_data, _ = pickle.load(open('chatbot.pkl', 'rb'))
ws = pickle.load(open('ws.pkl', 'rb'))

for x in x_data[:5]:
print(' '.join(x))

config = tf.ConfigProto(
device_count={'CPU': 1, 'GPU': 0},
allow_soft_placement=True,
Expand Down Expand Up @@ -50,17 +47,13 @@ def test(params, infos):
x, xl = next(bar)
x = np.flip(x, axis=1)

print(x, xl)

pred = model_pred.predict(
sess,
np.array(x),
np.array(xl)
)
print(pred)

print(ws.inverse_transform(x[0]))

for p in pred:
ans = ws.inverse_transform(p)
print(ans)
Expand Down

0 comments on commit aafcf7c

Please sign in to comment.