Skip to content

Commit

Permalink
ptb_platform.py 重命名为 ptb_caicloud_taas.py,并且添加了相关的注释说明。
Browse files Browse the repository at this point in the history
  • Loading branch information
lienhua34 committed Apr 12, 2017
1 parent 25c2408 commit 38dec3f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*~
*.pyc
*.pyc
.DS_Store
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ def perplexity_compute_fn(session):
def train_step(session, model, eval_op=None, verbose=False):
"""针对每个 batch size 中被截断的序列进行一次训练操作。
"""

# 注:
# 与原生的 ptb_word_lm.py 中的模型训练的区别:
# 这里每次对 batch size 中的截断序列进行模型训练的时候会进行一次初始化状态。
# 而 ptb_word_lm.py 中的 run_epoch 函数是在每 epoch 开始训练之前进行一次状态
# 初始化。
state = session.run(model.initial_state)

fetches = {
Expand All @@ -110,12 +116,7 @@ def train_step(session, model, eval_op=None, verbose=False):
if eval_op is not None:
fetches["eval_op"] = eval_op

feed_dict = {}
for i, (c, h) in enumerate(model.initial_state):
feed_dict[c] = state[i].c
feed_dict[h] = state[i].h

vals = session.run(fetches, feed_dict)
vals = session.run(fetches)
cost = vals["cost"]
state = vals["final_state"]
iters = model.input.num_steps
Expand All @@ -140,6 +141,13 @@ def train_fn(session, num_global_step):
print("Epoch: {0}, Global step: {1}, Learning rate: {2:.3f}".format(epoch + 1, num_global_step+1, session.run(mtrain.lr)))

start_time = time.time()

# 注:
# 这里调用 train_step 函数来对每个 batch size 被截断的序列进行模型训练。
# 我们也可以直接调用 ptb_word_lm.py 中的 run_epoch 函数来执行一个 epoch 的模型训练,
# 但是,执行一个 epoch 的模型训练,global_step 会变化很多。于是,在启动分布式模型训练
# 任务的时候设置的最大训练轮数 max_steps 将不能正常起作用,最终停止模型训练的时候,
# 可能实际的 global_step 已经超过 max_steps 很多。
train_perplexity = train_step(session, mtrain, eval_op=mtrain.train_op)

# 计算训练速度(每秒处理多少个单词)。
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
rm -rf /tmp/ptb_saved_model
rm -rf /tmp/ptb

python ptb_platform.py \
python ptb_caicloud_taas.py \
--max_steps=10000 \
--save_checkpoints_secs=3 \
--save_summaries_secs=3 \
Expand Down

0 comments on commit 38dec3f

Please sign in to comment.