
Commit

Add model evaluation computation to the ptb example code
lienhua34 committed Apr 10, 2017
1 parent 6820604 commit 25c2408
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions caicloud.tensorflow/caicloud/clever/examples/ptb/ptb_platform.py
@@ -74,6 +74,15 @@ def model_fn(sync, num_replicas):
with tf.variable_scope("Model", reuse=True, initializer=initializer):
mtest = ptb_word_lm.PTBModel(is_training=False, config=test_config, input_=test_input)

# Define how the model's evaluation metric is computed.
# The validation dataset is run through the validation model (mvalid) to measure how well training is going.
def perplexity_compute_fn(session):
valid_perplexity = ptb_word_lm.run_epoch(session, mvalid)
return valid_perplexity
model_metric_ops = {
"perplexity": perplexity_compute_fn
}

# Define the model export configuration.
# The training model's input_data has the batch size built into its shape, so it is not suitable for the final Serving.
# The test model shares variables with the training model, so we can directly use the test model's (mtest) input_data and
@@ -84,7 +93,10 @@ def model_fn(sync, num_replicas):
export_dir=FLAGS.save_path,
input_tensors={"input": mtest.input_data},
output_tensors={"logits": mtest.logits})
return dist_base.ModelFnHandler(model_export_spec=model_export_spec)

return dist_base.ModelFnHandler(
model_export_spec=model_export_spec,
model_metric_ops=model_metric_ops)

def train_step(session, model, eval_op=None, verbose=False):
"""针对每个 batch size 中被截断的序列进行一次训练操作。
@@ -143,7 +155,7 @@ def train_fn(session, num_global_step):
if _local_step % 100 == 0:
print("[Evaluation] Start to evaluate model with evaluation dataset ...")
valid_perplexity = ptb_word_lm.run_epoch(session, mvalid)
print("[Evaluation] Epoch {0}, Global step: {1},"
print("[Evaluation] Epoch {0}, Global step: {1}, "
"Valid Perplexity: {2:.3f}".format(
epoch+1, num_global_step+1, valid_perplexity))

@@ -153,7 +165,7 @@ def after_train_hook(session):
# Use the full test dataset to evaluate the performance of the trained model.
print("[Test] Start to test model with test dataset ...")
test_perplexity = ptb_word_lm.run_epoch(session, mtest)
print("[Test] Perplexity: %.3f" % test_perplexity)
print("[Test] Perplexity: {0:.3f}".format(test_perplexity))

if __name__ == '__main__':
distTfRunner = dist_base.DistTensorflowRunner(
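For readers unfamiliar with the metric hook added in this commit: the diff registers model_metric_ops with dist_base.ModelFnHandler, where each dict value is a callable that takes a tf.Session and returns a scalar. How caicloud's dist_base actually invokes these callables is not shown in the diff, so the following is only a minimal sketch of the implied calling convention; evaluate_metrics is a hypothetical helper, not part of the framework.

import tensorflow as tf

def evaluate_metrics(session, metric_ops):
    """Call every registered metric function with the current session."""
    results = {}
    for name, compute_fn in metric_ops.items():
        # For this commit, compute_fn is perplexity_compute_fn, which runs
        # a full pass over the validation model and returns its perplexity.
        results[name] = compute_fn(session)
    return results

# Hypothetical usage:
# with tf.Session() as sess:
#     print(evaluate_metrics(sess, model_metric_ops))  # {"perplexity": ...}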

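Assuming ptb_word_lm follows the standard TensorFlow PTB tutorial, the perplexity returned by run_epoch is the exponential of the average per-word cross-entropy accumulated over the epoch. A small illustration of that relationship (the function name below is ours, for illustration only):

import numpy as np

def perplexity_from_costs(total_cost, num_steps):
    """total_cost: summed cross-entropy in nats; num_steps: number of predicted words."""
    return np.exp(total_cost / num_steps)

# An average cross-entropy of about 4.6 nats per word corresponds to a
# perplexity of roughly 100:
print(perplexity_from_costs(total_cost=4.6 * 1000, num_steps=1000))  # ~99.5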

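Finally, a sketch of what consuming the exported signature might look like. The diff maps "input" to mtest.input_data and "logits" to mtest.logits; with the test configuration the batch dimension is typically 1, so a client feeds a [1, num_steps] array of word ids and reads logits back. The helper below and the graph-restore step are assumptions for illustration, not part of the caicloud serving API.

import numpy as np

def predict_logits(session, input_tensor, logits_tensor, word_ids):
    """Run one forward pass of the exported test model.

    word_ids is a [1, num_steps] array of vocabulary ids fed into the
    "input" tensor; the values of the "logits" tensor are returned.
    """
    return session.run(logits_tensor,
                       feed_dict={input_tensor: np.asarray(word_ids)})

# Hypothetical usage after restoring the exported graph and resolving the
# "input" and "logits" tensors by name:
# logits = predict_logits(sess, input_tensor, logits_tensor, [[42]])
# next_word_id = int(np.argmax(logits[-1]))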