Skip to content

Commit

Permalink
Update config.py
Browse files Browse the repository at this point in the history
  • Loading branch information
localminimum committed Apr 24, 2018
1 parent 917c110 commit 3665bc6
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ python config.py --mode debug/train/test/demo

To evaluate the model with official code, run
```bash
python evaluate-v1.1.py ~/data/squad/dev-v1.1.json train/answer/answer.json
python evaluate-v1.1.py ~/data/squad/dev-v1.1.json train/{model_name}/answer/answer.json
```

The default directory for tensorboard log file is `train/{model_name}/event`
Expand Down
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
flags.DEFINE_integer("capacity", 15000, "Batch size of dataset shuffle")
flags.DEFINE_integer("num_threads", 4, "Number of threads in input pipeline")
flags.DEFINE_boolean("is_bucket", False, "build bucket batch iterator or not")
flags.DEFINE_integer("bucket_range", [40, 401, 40], "the range of bucket")
flags.DEFINE_list("bucket_range", [40, 401, 40], "the range of bucket")

flags.DEFINE_integer("batch_size", 32, "Batch size")
flags.DEFINE_integer("num_steps", 60000, "Number of steps")
Expand Down
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, model, config):
while 1:
sleep(.1)
except KeyboardInterrupt:
print "Closing server..."
print("Closing server...")
run_event.clear()

def demo_backend(self, model, config, run_event):
Expand Down
5 changes: 4 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@ def train(config):
saver = tf.train.Saver()
train_handle = sess.run(train_iterator.string_handle())
dev_handle = sess.run(dev_iterator.string_handle())
if os.path.exists(os.path.join(config.save_dir, "checkpoint")):
saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))
global_step = max(sess.run(model.global_step), 1)

for _ in tqdm(range(1, config.num_steps + 1)):
for _ in tqdm(range(global_step, config.num_steps + 1)):
global_step = sess.run(model.global_step) + 1
loss, train_op = sess.run([model.loss, model.train_op], feed_dict={
handle: train_handle, model.dropout: config.dropout})
Expand Down

0 comments on commit 3665bc6

Please sign in to comment.