Skip to content

Commit

Permalink
update seq2seq config.
Browse files (browse the repository at this point in the history)
  • Loading branch information
xuming06 committed Nov 8, 2018
1 parent 1ecae38 commit 85653d2
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ python3 infer.py
- [x] 整理中文纠错训练数据,使用seq2seq做深度中文纠错模型
- [x] 添加中文语法错误检测及纠正能力
- [x] 规则方法添加用户自定义纠错集,并将其纠错优先度调为最高
- - [ ] seq2seq_attention 添加dropout,减少过拟合
+ - [x] seq2seq_attention 添加dropout,减少过拟合


## 参考
Expand Down
1 change: 1 addition & 0 deletions pycorrector/seq2seq_attention/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
epochs = 40
rnn_hidden_dim = 128
maxlen = 400
+ dropout = 0.0
use_gpu = False

if not os.path.exists(output_dir):
Expand Down
4 changes: 2 additions & 2 deletions pycorrector/seq2seq_attention/preprocess_short_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def parse_xml_file(path):
texts = split_2_short_text(text)
corrections = split_2_short_text(correction)
if len(texts) != len(corrections):
- print('error:' + text + '\t' + correction)
+ # print('error:' + text + '\t' + correction)
continue
for i in range(len(texts)):
if len(texts[i]) > 40:
- print('error:' + texts[i] + '\t' + corrections[i])
+ # print('error:' + texts[i] + '\t' + corrections[i])
continue
source = segment(texts[i], cut_type='char')
target = segment(corrections[i], cut_type='char')
Expand Down
6 changes: 4 additions & 2 deletions pycorrector/seq2seq_attention/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_validation_data(input_texts, target_texts, char2id, maxlen=400):


def train(train_path='', test_path='', save_vocab_path='', attn_model_path='',
- batch_size=64, epochs=100, maxlen=400, hidden_dim=128, use_gpu=False):
+ batch_size=64, epochs=100, maxlen=400, hidden_dim=128, dropout=0.2, use_gpu=False):
data_reader = CGEDReader(train_path)
input_texts, target_texts = data_reader.build_dataset(train_path)
test_input_texts, test_target_texts = data_reader.build_dataset(test_path)
Expand All @@ -69,7 +69,8 @@ def train(train_path='', test_path='', save_vocab_path='', attn_model_path='',
model = Seq2seqAttnModel(chars,
attn_model_path=attn_model_path,
hidden_dim=hidden_dim,
- use_gpu=use_gpu).build_model()
+ use_gpu=use_gpu,
+ dropout=dropout).build_model()
evaluator = Evaluate(model, attn_model_path, char2id, id2char, maxlen)
model.fit_generator(data_generator(input_texts, target_texts, char2id, batch_size, maxlen),
steps_per_epoch=(len(input_texts) + batch_size - 1) // batch_size,
Expand All @@ -87,4 +88,5 @@ def train(train_path='', test_path='', save_vocab_path='', attn_model_path='',
epochs=config.epochs,
maxlen=config.maxlen,
hidden_dim=config.rnn_hidden_dim,
+ dropout=config.dropout,
use_gpu=config.use_gpu)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ scikit-learn
pypinyin
kenlm==0.0.0
jieba
- tensorflow>=1.9.0
+ tensorflow==1.9.0
keras>=2.1.5

0 comments on commit 85653d2

Please sign in to comment.