Commit 8b54a0c: modify feed_dict

InsaneLife committed May 17, 2019
1 parent 34ba787
Showing 3 changed files with 13 additions and 14 deletions.
.gitignore (3 changes: 2 additions & 1 deletion)

@@ -7,4 +7,5 @@ __pycache__/
 model/
 
 # general
-envi/
+envi/
+Summaries/
config.py (10 changes: 5 additions & 5 deletions)

@@ -19,14 +19,14 @@ def __init__(self):
 unk = '[UNK]'
 pad = '[PAD]'
 vocab_path = './data/vocab.txt'
-file_train = './data/oppo_round1_train_20180929_mini.txt'
-# file_train = './data/oppo_round1_train_20180929.txt'
-file_vali = './data/oppo_round1_vali_20180929_mini.txt'
-# file_vali = './data/oppo_round1_vali_20180929.txt'
+file_train = './data/oppo_round1_train_20180929.mini'
+file_train = './data/oppo_round1_train_20180929.txt'
+file_vali = './data/oppo_round1_vali_20180929.mini'
+file_vali = './data/oppo_round1_vali_20180929.txt'
 max_seq_len = 10
 hidden_size_rnn = 100
 use_stack_rnn = False
-learning_rate = 0.01
+learning_rate = 0.001
 # max_steps = 8000
 num_epoch = 100
 summaries_dir = './Summaries/'
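
One side effect worth calling out: the new file_train and file_vali lines are assigned twice in a row, so Python keeps the later value and the full .txt datasets are what actually load; the .mini lines remain only as an easy toggle for smoke tests. (The new Summaries/ entry in .gitignore matches summaries_dir below.) The effective values after this commit, as a sketch; the field list is copied from the diff, while the class structure implied by the hunk header is assumed:

# Effective settings after commit 8b54a0c (sketch; surrounding structure assumed).
unk = '[UNK]'
pad = '[PAD]'
vocab_path = './data/vocab.txt'
file_train = './data/oppo_round1_train_20180929.txt'  # later assignment overrides the .mini path
file_vali = './data/oppo_round1_vali_20180929.txt'    # likewise
max_seq_len = 10
hidden_size_rnn = 100
use_stack_rnn = False
learning_rate = 0.001  # lowered from 0.01
num_epoch = 100        # max_steps stays commented out
summaries_dir = './Summaries/'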
dssm_rnn.py (14 changes: 6 additions & 8 deletions)

@@ -4,9 +4,7 @@
 TensorFlow=1.2.1
 """
 
-import pandas as pd
-from scipy import sparse
-import collections
+
 import random
 import time
 import numpy as np
@@ -275,8 +273,8 @@ def pull_batch(data_map, batch_id):
     return query_in, doc_positive_in, doc_negative_in, query_len, doc_positive_len, doc_negative_len
 
 
-def feed_dict(on_training, batch_id, drop_prob):
-    query_in, doc_positive_in, doc_negative_in, query_seq_len, pos_seq_len, neg_seq_len = pull_batch(data_vali,
+def feed_dict(on_training, data_set, batch_id, drop_prob):
+    query_in, doc_positive_in, doc_negative_in, query_seq_len, pos_seq_len, neg_seq_len = pull_batch(data_set,
                                                                                                      batch_id)
     query_len = len(query_in)
     query_seq_len = [conf.max_seq_len] * query_len
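
The substance of the commit is here: the old feed_dict ignored its caller and always pulled from data_vali, so the training loop was actually being fed validation batches. Threading a data_set argument through lets each call site choose its split. A rough sketch of the whole function after the change; only the signature, the pull_batch call, and the two lines after it appear in the diff, and the placeholder names in the returned dict are assumptions:

def feed_dict(on_training, data_set, batch_id, drop_prob):
    # Pull one batch from whichever split the caller passed in.
    query_in, doc_positive_in, doc_negative_in, query_seq_len, pos_seq_len, neg_seq_len = \
        pull_batch(data_set, batch_id)
    query_len = len(query_in)
    query_seq_len = [conf.max_seq_len] * query_len  # fixed-length sequences
    # Assumed placeholder names below; the real graph variables are not in the diff.
    return {
        query_batch: query_in,
        doc_positive_batch: doc_positive_in,
        doc_negative_batch: doc_negative_in,
        query_seq_length: query_seq_len,
        on_train: on_training,
        keep_prob: drop_prob,
    }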
@@ -305,12 +303,12 @@ def feed_dict(on_training, batch_id, drop_prob):
         random.shuffle(batch_ids)
         for batch_id in batch_ids:
             # print(batch_id)
-            sess.run(train_step, feed_dict=feed_dict(True, batch_id, 0.5))
+            sess.run(train_step, feed_dict=feed_dict(True, data_train, batch_id, 0.5))
         end = time.time()
         # train loss
         epoch_loss = 0
         for i in range(train_epoch_steps):
-            loss_v = sess.run(loss, feed_dict=feed_dict(False, i, 1))
+            loss_v = sess.run(loss, feed_dict=feed_dict(False, data_train, i, 1))
             epoch_loss += loss_v
 
         epoch_loss /= (train_epoch_steps)
@@ -323,7 +321,7 @@ def feed_dict(on_training, batch_id, drop_prob):
         start = time.time()
         epoch_loss = 0
         for i in range(vali_epoch_steps):
-            loss_v = sess.run(loss, feed_dict=feed_dict(False, i, 1))
+            loss_v = sess.run(loss, feed_dict=feed_dict(False, data_vali, i, 1))
             epoch_loss += loss_v
         epoch_loss /= (vali_epoch_steps)
         test_loss = sess.run(loss_summary, feed_dict={average_loss: epoch_loss})
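
Taken together, the three call sites now select their split explicitly: parameter updates and the train-loss pass use data_train, and only the validation pass uses data_vali. A condensed sketch of the per-epoch flow implied by the diff (the session setup, summary writers, and step counts around it are assumed; drop_prob is fed 0.5 for gradient steps and 1 for loss evaluation, as in the calls above):

for epoch in range(conf.num_epoch):
    random.shuffle(batch_ids)
    for batch_id in batch_ids:
        # Gradient step on a training batch.
        sess.run(train_step, feed_dict=feed_dict(True, data_train, batch_id, 0.5))

    # Average the loss over each split once per epoch.
    train_loss = sum(sess.run(loss, feed_dict=feed_dict(False, data_train, i, 1))
                     for i in range(train_epoch_steps)) / train_epoch_steps
    vali_loss = sum(sess.run(loss, feed_dict=feed_dict(False, data_vali, i, 1))
                    for i in range(vali_epoch_steps)) / vali_epoch_steps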
Expand Down
