Skip to content

Commit

Permalink
make add_moving_summary use local variables, so they are not broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Aug 25, 2020
1 parent 2d661d6 commit 02e53f7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
3 changes: 2 additions & 1 deletion tensorpack/tfutils/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,11 @@ def add_moving_summary(*args, **kwargs):
assert x.get_shape().ndims == 0, \
"add_moving_summary() only accepts scalar tensor! Got one with {}".format(x.get_shape())

from ..graph_builder.utils import override_to_local_variable
ema_ops = []
for c in args:
name = re.sub('tower[0-9]+/', '', c.op.name)
with tf.name_scope(None):
with tf.name_scope(None), override_to_local_variable(True):
if not c.dtype.is_floating:
c = tf.cast(c, tf.float32)
# assign_moving_average creates variables with op names, therefore clear ns first.
Expand Down
8 changes: 5 additions & 3 deletions tensorpack/train/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def _setup_graph(self, input, get_cost_fn, get_opt_fn):
return [cb]

def broadcast(self, _):
logger.info("Running broadcast ...")
logger.info("Broadcasting {} global variables ...".format(self._num_global_variables))
# the op will be created in initialize()
self.sess.run(self._broadcast_op)

Expand All @@ -483,6 +483,7 @@ def initialize(self, session_creator, session_init):
# broadcast_op should be the last setup_graph: it needs to be created
# "right before" the graph is finalized,
# because it needs to capture all the variables (which may be created by callbacks).
self._num_global_variables = len(tf.global_variables())
self._broadcast_op = self.hvd.broadcast_global_variables(0)

# it's important that our NewSessionCreator does not finalize the graph
Expand All @@ -504,9 +505,10 @@ def initialize(self, session_creator, session_init):
# 1. an allgather helper to concat strings
# 2. check variables on each rank match each other, print warnings, and broadcast the common set.
if self.is_chief:
logger.info("Broadcasting initialized variables ...")
logger.info("Broadcasting initialization of {} global variables ...".format(self._num_global_variables))
else:
logger.info("Rank {} waiting for initialization broadcasting ...".format(self._rank))
logger.info("Rank {} waiting for initialization of {} variables ...".format(
self._rank, self._num_global_variables))
self.sess.run(self._broadcast_op)


Expand Down

0 comments on commit 02e53f7

Please sign in to comment.