Skip to content

Commit

Permalink
TaaS 模型训练的运行参数改为通过环境变量传递。
Browse files Browse the repository at this point in the history
TaaS 平台模型训练的 Base 框架调整后,用户的代码文件作为主入口文件。如果用户使用了 argparse 来
解析命令行参数,将会出现无法处理 TaaS 运行参数的问题,因此改为通过环境变量的方式提供 TaaS 运行参数。
  • Loading branch information
lienhua34 committed Apr 20, 2017
1 parent 3a2c199 commit 1182134
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*~
*.pyc
.DS_Store

2 changes: 1 addition & 1 deletion caicloud.tensorflow/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.0
2.0.1
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@
rm -rf /tmp/caicloud-dist-tf
rm -rf /tmp/saved_model/mnist

python mnist_export.py --max_steps=2000 --data_dir="/tmp/mnist-data" --export_dir="/tmp/saved_model/mnist"
export TF_MAX_STEPS=3000
export TF_LOGDIR=/tmp/mnist-log
export TF_SAVE_CHECKPOINTS_SECS=60
export TF_SAVE_SUMMARIES_STEPS=10
python mnist_export.py --data_dir="/tmp/mnist-data" --export_dir="/tmp/saved_model/mnist"
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
rm -rf /tmp/ptb_saved_model
rm -rf /tmp/ptb

export TF_MAX_STEPS=10000
export TF_LOGDIR=/tmp/ptb
export TF_SAVE_CHECKPOINTS_SECS=60
export TF_SAVE_SUMMARIES_STEPS=10
python ptb_caicloud_taas.py \
--max_steps=10000 \
--save_checkpoints_secs=3 \
--save_summaries_secs=3 \
--logdir=/tmp/ptb \
--data_path=./simple-examples/data \
--save_path=/tmp/ptb_saved_model
30 changes: 15 additions & 15 deletions caicloud.tensorflow/caicloud/clever/tensorflow/dist_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@
from datetime import datetime
import tensorflow as tf
from caicloud.clever.tensorflow import model_exporter
import os

tf.app.flags.DEFINE_integer("max_steps",
1,
"maximum train steps.")
tf.app.flags.DEFINE_string("logdir",
"/tmp/caicloud-dist-tf",
"saves checkpoints and summaries directory path.")
tf.app.flags.DEFINE_integer("save_checkpoints_secs", 60,
"save checkpoints after every special seconds")
tf.app.flags.DEFINE_integer("save_summaries_secs", 120,
"save summaries after every special seconds")
class RunConfig(object):
    """Run parameters for a training job, read from environment variables.

    The parameters come from the environment rather than from command-line
    flags so that they do not interfere with user code that parses its own
    command line (for example with argparse).

    Environment variables read, with the default used when unset:
        TF_MAX_STEPS              -> max_steps              (1)
        TF_LOGDIR                 -> logdir                 ("/tmp/caicloud-dist-tf")
        TF_SAVE_CHECKPOINTS_SECS  -> save_checkpoints_secs  (600)
        TF_SAVE_SUMMARIES_STEPS   -> save_summaries_steps   (100)

    Raises:
        ValueError: if one of the integer variables is set to a value that
            is not a valid integer.
    """

    @staticmethod
    def _int_from_env(name, default):
        """Return environment variable `name` parsed as an int.

        Returns `default` when the variable is unset or contains only
        whitespace. Raises ValueError naming the variable when the value
        cannot be parsed as an integer.
        """
        raw = os.getenv(name)
        # A variable that is exported but blank is treated like an unset one
        # instead of failing with "invalid literal for int()".
        if raw is None or not raw.strip():
            return default
        try:
            return int(raw)
        except ValueError:
            raise ValueError(
                "environment variable {0} must be an integer, got {1!r}".format(
                    name, raw))

    def __init__(self):
        self.is_chief = True
        self.use_gpu = False
        self.sync = False
        self.max_steps = self._int_from_env("TF_MAX_STEPS", 1)
        self.logdir = os.getenv("TF_LOGDIR", "/tmp/caicloud-dist-tf")
        self.save_checkpoints_secs = self._int_from_env(
            "TF_SAVE_CHECKPOINTS_SECS", 600)
        # NOTE(review): the visible caller passes this value to
        # tf.train.Supervisor as `save_summaries_secs`, i.e. it is used as a
        # number of seconds despite the "steps" name -- confirm the intent.
        self.save_summaries_steps = self._int_from_env(
            "TF_SAVE_SUMMARIES_STEPS", 100)

FLAGS = tf.app.flags.FLAGS
cfg = RunConfig()

class ModelFnHandler(object):
"""model_fn 函数返回的一些模型信息。"""
Expand Down Expand Up @@ -196,16 +196,16 @@ def run(self, train_fn):
if self._gen_init_fn is not None:
init_fn = self._gen_init_fn()

logdir = FLAGS.logdir
logdir = cfg.logdir
sv = tf.train.Supervisor(
logdir=logdir,
graph=g,
init_op=init_op,
summary_op=summary_op,
saver=saver,
global_step=self._global_step,
save_model_secs=FLAGS.save_checkpoints_secs,
save_summaries_secs=FLAGS.save_summaries_secs,
save_model_secs=cfg.save_checkpoints_secs,
save_summaries_secs=cfg.save_summaries_steps,
init_fn=init_fn)
# Get a TensorFlow session managed by the supervisor.
sess = sv.prepare_or_wait_for_session('')
Expand All @@ -218,7 +218,7 @@ def run(self, train_fn):
step = 0
while not sv.should_stop():
step = sess.run(self._global_step)
if step > FLAGS.max_steps:
if step > cfg.max_steps:
break
should_stop = train_fn(sess, step)
if should_stop:
Expand Down

0 comments on commit 1182134

Please sign in to comment.