Skip to content

Commit

Permalink
Merge pull request caicloud#20 from lienhua34/tf-run-config
Browse files Browse the repository at this point in the history
Base 框架添加关闭自动计算 summary 机制的方法;MNIST 样例自定义初始化函数添加接收 tf.train.Scaffold 对象参数。
  • Loading branch information
perhapszzy authored Apr 26, 2017
2 parents dcecdf8 + fed3c68 commit 936db43
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 74 deletions.
2 changes: 1 addition & 1 deletion caicloud.tensorflow/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.1
2.0.2
169 changes: 105 additions & 64 deletions caicloud.tensorflow/caicloud/clever/examples/mnist/mnist_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,41 @@

FLAGS = tf.app.flags.FLAGS

# Build model ...
mnist = read_data_sets(FLAGS.data_dir, one_hot=True)

local_step = 0
input_images = None
lables = None
loss = None
optimizer = None
train_op = None
global_step = None
_mnist = read_data_sets(FLAGS.data_dir, one_hot=True)

_local_step = 0
_input_images = None
_labels = None
_loss = None
_train_op = None
_global_step = None
_accuracy = None
_summary_op = None
_summary_writer = None

def model_fn(sync, num_replicas):
# 这些变量在后续的训练操作函数 train_fn() 中会使用到,
# 所以这里使用了 global 变量。
global input_images, loss, labels, optimizer, train_op, accuracy
global mnist, global_step
global _input_images, _loss, _labels, _train_op, _accuracy
global _mnist, _global_step, _summary_op, _summary_writer

# 构建推理模型
input_images = tf.placeholder(tf.float32, [None, 784], name='image')
_input_images = tf.placeholder(tf.float32, [None, 784], name='image')
W = tf.Variable(tf.zeros([784, 10]), name='weights')
tf.summary.histogram("weights", W)
b = tf.Variable(tf.zeros([10]), name='bias')
tf.summary.histogram("bias", b)
logits = tf.matmul(input_images, W) + b
logits = tf.matmul(_input_images, W) + b

global_step = tf.Variable(0, name='global_step', trainable=False)
_global_step = tf.Variable(0, name='global_step', trainable=False)

# Define loss and optimizer
labels = tf.placeholder(tf.float32, [None, 10], name='labels')
_labels = tf.placeholder(tf.float32, [None, 10], name='labels')
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
loss = tf.reduce_mean(cross_entropy, name='loss')
tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=_labels))
tf.summary.scalar("cross_entropy", cross_entropy)
_loss = tf.reduce_mean(cross_entropy, name='loss')
tf.add_to_collection(tf.GraphKeys.LOSSES, _loss)

# Create optimizer to compute gradient
optimizer = tf.train.AdagradOptimizer(0.01);
Expand All @@ -67,34 +69,45 @@ def model_fn(sync, num_replicas):
total_num_replicas=num_workers,
name="mnist_sync_replicas")

train_op = optimizer.minimize(cross_entropy, global_step=global_step)
_train_op = optimizer.minimize(cross_entropy, global_step=_global_step)

# 自定义计算模型 summary 信息的 Operation,
# 并定义一个 FileWriter 用于保存模型 summary 信息。
# 其中 dist_base.cfg.logdir 是 TaaS 平台上设置的训练日志路径参数。
_summary_op = tf.summary.merge_all()
_summary_writer = tf.summary.FileWriter(dist_base.cfg.logdir)

# Test trained model
correct_prediction = tf.equal(tf.argmax(logits, 1),
tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.argmax(_labels, 1))
_accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
def accuracy_evalute_fn(session):
return session.run(accuracy,
return session.run(_accuracy,
feed_dict={
input_images: mnist.validation.images,
labels: mnist.validation.labels})
_input_images: _mnist.validation.images,
_labels: _mnist.validation.labels})

# 定义模型导出配置
model_export_spec = model_exporter.ModelExportSpec(
export_dir=FLAGS.export_dir,
input_tensors={"image": input_images},
input_tensors={"image": _input_images},
output_tensors={"logits": logits})

# 定义模型评测(准确率)的计算方法
model_metric_ops = {
"accuracy": accuracy_evalute_fn
}


# 因为模型中需要计算 tf.summary.scalar(cross_entropy),而该 summary 的计算需要
# feed 设置 _input_images 和 _labels,所以这里将 summary_op 设置成 None,将关闭
# TaaS 的自动计算和保存模型 summary 信息机制。在 train_op 函数中自己来计算并收集
# 模型 Graph 的 summary 信息。
return dist_base.ModelFnHandler(
global_step=global_step,
global_step=_global_step,
optimizer=optimizer,
model_metric_ops = model_metric_ops,
model_export_spec=model_export_spec)
model_metric_ops=model_metric_ops,
model_export_spec=model_export_spec,
summary_op=None)

def gen_init_fn():
"""获取自定义初始化函数。
Expand All @@ -116,75 +129,103 @@ def gen_init_fn():
checkpoint_path = checkpoint_path
print('warm-start from checkpoint {0}'.format(checkpoint_path))

# Create an initial assignment function.
# 定义 tf.train.Saver 会修改 TensorFlow 的 Graph 结构,
# 而当 Base 框架调用自定义初始化函数 init_from_checkpoint 的时候,
# TensorFlow 模型的 Graph 结构已经变成 finalized,不再允许修改 Graph 结构。
# 所以,这个定义必须放在 init_from_checkpoint 函数外面。
saver = tf.train.Saver(tf.trainable_variables())
def InitAssignFn(sess):

def init_from_checkpoint(scaffold, sess):
"""执行自定义初始化的函数。
TaaS 平台会优先从设置的日志保存路径中获取最新的 checkpoint 来 restore 模型参数,
如果日志保存路径中找不到 checkpoint 文件,才会调用本函数来进行模型初始化。
本函数必须接收两个参数:
- scaffold: tf.train.Scaffold 对象;
- sess: tf.Session 对象。
"""
saver.restore(sess, checkpoint_path)
print('Accuracy for restored model:')
compute_accuracy(sess)
return InitAssignFn

return init_from_checkpoint

_last_summary_step = 0
def train_fn(session, num_global_step):
global local_step, input_images, labels, accuracy
global mnist, train_op, loss, global_step
global local_step
"""每一轮模型训练操作。"""
global _local_step, _input_images, _labels, _accuracy
global _mnist, _train_op, _loss, _global_step
global _summary_op, _summary_writer, _last_summary_step

start_time = time.time()
local_step += 1
batch_xs, batch_ys = mnist.train.next_batch(100)
feed_dict = {input_images: batch_xs,
labels: batch_ys}
_, loss_value, np_global_step = session.run(
[train_op, loss, global_step],
_local_step += 1
batch_xs, batch_ys = _mnist.train.next_batch(100)
feed_dict = {_input_images: batch_xs,
_labels: batch_ys}
_, loss_value, np_global_step, summary_str = session.run(
[_train_op, _loss, _global_step, _summary_op],
feed_dict=feed_dict)
duration = time.time() - start_time
if local_step % 50 == 0:

if _local_step%50 == 0:
print('Step {0}: loss = {1:0.2f} ({2:0.3f} sec), global step: {3}.'.format(
local_step, loss_value, duration, np_global_step))
if local_step % 1000 == 0:
_local_step, loss_value, duration, np_global_step))

# 每隔固定训练轮数计算保存一次模型 summary 信息。
# 通过 dist_base.cfg.save_summaries_steps 获取在 TaaS 平台上设置的
# "自动保存 summaries 日志间隔"参数值。
if (np_global_step - _last_summary_step >= dist_base.cfg.save_summaries_steps):
_summary_writer.add_summary(summary_str, np_global_step)
_summary_writer.flush()
_last_summary_step = np_global_step

if _local_step%1000 == 0:
print("Accuracy for validation data: {0:0.3f}".format(
session.run(
accuracy,
_accuracy,
feed_dict={
input_images: mnist.validation.images,
labels: mnist.validation.labels})))
_input_images: _mnist.validation.images,
_labels: _mnist.validation.labels})))

return False


def after_train_hook(session):
global _summary_writer
_summary_writer.close()

print("Train done.")
print("Accuracy for test data: {0:0.3f}".format(
session.run(
accuracy,
_accuracy,
feed_dict={
input_images: mnist.test.images,
labels: mnist.test.labels})))
_input_images: _mnist.test.images,
_labels: _mnist.test.labels})))

def compute_accuracy(session):
print("Accuracy:")
print("\tTraining Data: {0:0.3f}".format(
session.run(
accuracy,
_accuracy,
feed_dict={
input_images: mnist.train.images,
labels: mnist.train.labels})))
_input_images: _mnist.train.images,
_labels: _mnist.train.labels})))
print("\tValidation Data: {0:0.3f}".format(
session.run(
accuracy,
_accuracy,
feed_dict={
input_images: mnist.validation.images,
labels: mnist.validation.labels})))
_input_images: _mnist.validation.images,
_labels: _mnist.validation.labels})))
print("\tTest Data: {0:0.3f}".format(
session.run(
accuracy,
_accuracy,
feed_dict={
input_images: mnist.test.images,
labels: mnist.test.labels})))
_input_images: _mnist.test.images,
_labels: _mnist.test.labels})))

if __name__ == '__main__':
distTfRunner = dist_base.DistTensorflowRunner(
model_fn = model_fn,
after_train_hook = after_train_hook,
gen_init_fn = gen_init_fn)
model_fn=model_fn,
after_train_hook=after_train_hook,
gen_init_fn=gen_init_fn)
distTfRunner.run(train_fn)
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ rm -rf /tmp/caicloud-dist-tf
rm -rf /tmp/saved_model/mnist

export TF_MAX_STEPS=3000
export TF_LOGDIR=/tmp/mnist-log
export TF_LOGDIR=/tmp/mnist-log1
export TF_SAVE_CHECKPOINTS_SECS=60
export TF_SAVE_SUMMARIES_STEPS=10
python mnist_export.py --data_dir="/tmp/mnist-data" --export_dir="/tmp/saved_model/mnist"
python mnist_export.py --checkpoint_dir=/tmp/mnist-log --data_dir="/tmp/mnist-data" --export_dir="/tmp/saved_model/mnist"
32 changes: 25 additions & 7 deletions caicloud.tensorflow/caicloud/clever/tensorflow/dist_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from caicloud.clever.tensorflow import model_exporter
import os

_USE_DEFAULT = 0

class RunConfig(object):
def __init__(self):
self.is_chief = True
Expand All @@ -41,7 +43,8 @@ def __init__(self,
global_step=None,
optimizer=None,
model_export_spec=None,
model_metric_ops=None):
model_metric_ops=None,
summary_op=_USE_DEFAULT):
"""创建一个 ModelFnHandler 对象。
分布式 TensorFlow 运行器 DistTensorflowRunner 的模型构建方法 model_fn 的返回值
Expand Down Expand Up @@ -73,6 +76,8 @@ def __init__(self,
if (model_metric_ops is not None) and (not isinstance(model_metric_ops, dict)):
raise ValueError('model_export_spec must be a None or dict.')
self._model_metric_ops = model_metric_ops

self._summary_op = summary_op

@property
def global_step(self):
Expand All @@ -90,6 +95,10 @@ def model_export_spec(self):
def model_metric_ops(self):
return self._model_metric_ops

@property
def summary_op(self):
return self._summary_op


class DistTensorflowRunner(object):
"""分布式 TensorFlow 运行器。"""
Expand Down Expand Up @@ -170,6 +179,8 @@ def _call_model_fn(self):
if model_fn_handler.model_export_spec is not None:
self._model_exporter = model_exporter.ModelExporter(model_fn_handler.model_export_spec)

return model_fn_handler


def run(self, train_fn):
"""执行分布式 TensorFlow 训练。
Expand All @@ -186,19 +197,27 @@ def run(self, train_fn):

g = tf.Graph()
with g.as_default():
self._call_model_fn()
model_fn_handler = self._call_model_fn()


saver = tf.train.Saver()
summary_op = tf.summary.merge_all()
summary_op = model_fn_handler.summary_op
if summary_op == _USE_DEFAULT:
summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()

init_fn = None
if self._gen_init_fn is not None:
init_fn = self._gen_init_fn()
customed_init_fn = self._gen_init_fn()
def init_fn(sess):
scaffold = tf.train.Scaffold(
init_op = init_op,
saver = saver
)
customed_init_fn(scaffold, sess)

logdir = cfg.logdir
sv = tf.train.Supervisor(
logdir=logdir,
logdir=cfg.logdir,
graph=g,
init_op=init_op,
summary_op=summary_op,
Expand All @@ -214,7 +233,6 @@ def run(self, train_fn):
print("Training begins @ {0}".format(str(datetime.now())))

# Use the session to train the graph.
sess.run(init_op)
step = 0
while not sv.should_stop():
step = sess.run(self._global_step)
Expand Down

0 comments on commit 936db43

Please sign in to comment.