Tower-local variable support for DistributionStrategy. Each tower has
its own variable, but fetch() and checkpoint apply a reduction to get
a single value.

PiperOrigin-RevId: 190853123
tensorflower-gardener committed Mar 28, 2018
1 parent 830c19c commit 390e19a
Showing 1 changed file with 53 additions and 6 deletions.
59 changes: 53 additions & 6 deletions tensorflow/python/training/distribute.py
@@ -126,16 +126,18 @@ def __exit__(self, exception_type, exception_value, traceback):


def get_tower_context():
- """Returns the current TowerContext or None.
+ """Returns the current TowerContext or None if in a cross-tower context.

Note that execution:
- 1. starts in the default (single-tower) tower context;
- 2. switches to cross-tower context when entering a
-    `with DistributionStrategy.scope():` block;
+ 1. starts in the default (single-tower) tower context (this function
+    will return the default TowerContext object);
+ 2. switches to cross-tower context (in which case this will return
+    None) when entering a `with DistributionStrategy.scope():` block;
3. switches to a (non-default) tower context inside
`call_for_each_tower(fn, ...)`;
4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-tower context.
+ inside `merge_fn` you are back in the cross-tower context (and again
+ this function will return None).

Note that you can also go directly from step 1 to 4 to switch to a
cross-tower context for the default `DistributionStrategy`. You may
@@ -188,6 +190,9 @@ def get_cross_tower_context():
def get_distribution_strategy():
"""Returns the current `DistributionStrategy` object.
Prefer to use `get_tower_context()` or `get_cross_tower_context()`
instead when possible.
Returns:
A `DistributionStrategy` object. Inside a
`with distribution_strategy.scope()` block, it returns
@@ -526,7 +531,6 @@ class DistributionStrategy(object):
# TODO(josh11b): ClusterSpec/ClusterResolver
# TODO(josh11b): Partitioned computations, state; sharding
# TODO(josh11b): Model parallelism: "towers" with multiple devices; shuffling
- # TODO(josh11b): Tower-local variables
# TODO(josh11b): List of towers with their worker and parameter devices
# (where the parameter devices may overlap in the ps case).

@@ -556,6 +560,43 @@ def _create_variable(self, next_creator, *args, **kwargs):
  # Note: should support "colocate_with" argument.
  raise NotImplementedError("must be implemented in descendants")

def tower_local_var_scope(self, reduce_method):
  """Inside this scope, new variables will not be mirrored.

  There will still be one component variable per tower, but there is
  no requirement that they stay in sync. Instead, when saving them
  or calling `fetch()`, we use the value that results when calling
  `reduce()` on all the towers' variables.

  Note: tower-local implies not trainable. Instead, it is expected
  that each tower will directly update (using `assign_add()` or
  whatever) its local variable instance but only the aggregated
  value (accessible using `fetch()`) will be exported from the
  model. When it is acceptable to only aggregate on export, we
  greatly reduce communication overhead by using tower-local
  variables.

  Note: All component variables will be initialized to the same
  value, using the initialization expression from the first tower.
  The values will match even if the initialization expression uses
  random numbers.

  Args:
    reduce_method: String used as a `method_string` to `reduce()`
      to get the value to save when checkpointing.

  Returns:
    A context manager.
  """
  def create_tower_local_variable(next_creator, *args, **kwargs):
    _require_distribution_strategy_scope(self)
    kwargs["use_resource"] = True
    kwargs["tower_local_reduce_method"] = reduce_method
    return next_creator(*args, **kwargs)

  _require_distribution_strategy_scope(self)
  return variable_scope.variable_creator_scope(create_tower_local_variable)

def colocate_vars_with(self, colocate_with_variable):
  """Scope that controls which devices variables will be created on.
@@ -984,6 +1025,10 @@ def _merge_call(self, merge_fn, *args, **kwargs):
  finally:
    _pop_per_thread_mode()

def tower_local_var_scope(self, reduce_method):
  """Alias for distribution_strategy.tower_local_var_scope()."""
  return self._distribution_strategy.tower_local_var_scope(reduce_method)

@property
def is_single_tower(self):
  """Returns whether there is a single tower or multiple."""
@@ -1030,6 +1075,8 @@ def scope(self):

def creator(next_creator, *args, **kwargs):
  _require_distribution_strategy_scope(self)
  if kwargs.pop("tower_local_reduce_method", None) is not None:
    kwargs["trainable"] = False
  return next_creator(*args, **kwargs)

return _CurrentDistributionContext(
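
Illustration (not part of this commit): the four context transitions described in the updated get_tower_context() docstring can be sketched as below. `MyStrategy` stands in for some concrete `DistributionStrategy` subclass (hypothetical here); everything else is a module-level helper or method referenced in distribute.py. merge_fn is assumed to receive the `DistributionStrategy` as its first argument.

# Sketch only; MyStrategy is hypothetical and must implement the abstract methods.
from tensorflow.python.training import distribute

my_strategy = MyStrategy()

# 1. Execution starts in the default (single-tower) tower context.
assert distribute.get_tower_context() is not None

with my_strategy.scope():
  # 2. Inside the strategy's scope we are in a cross-tower context,
  #    so get_tower_context() returns None.
  assert distribute.get_tower_context() is None

  def tower_fn():
    # 3. call_for_each_tower(fn, ...) runs fn in a (non-default) tower context.
    ctx = distribute.get_tower_context()

    def merge_fn(distribution, *args):
      # 4. Inside merge_fn we are back in the cross-tower context
      #    (and get_tower_context() again returns None).
      assert distribute.get_tower_context() is None

    ctx.merge_call(merge_fn)

  my_strategy.call_for_each_tower(tower_fn)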

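A usage sketch for the new tower_local_var_scope(), also illustrative only: `MyStrategy` is hypothetical, "sum" is assumed to be an accepted `method_string` for reduce(), and variables are assumed to be created via variable_scope.variable() so they pass through the variable-creator stack this change hooks into.

# Sketch only; names called out as hypothetical above are not part of this commit.
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import distribute

my_strategy = MyStrategy()

def tower_fn():
  ctx = distribute.get_tower_context()
  with ctx.tower_local_var_scope("sum"):
    # One non-trainable component variable per tower; the towers' values
    # are not kept in sync with each other.
    examples_seen = variable_scope.variable(0.0, name="examples_seen")
  # Each tower updates only its own component variable...
  return examples_seen.assign_add(1.0)

with my_strategy.scope():
  update_op = my_strategy.call_for_each_tower(tower_fn)
  # ...and fetch()/checkpointing apply the "sum" reduction across towers
  # to recover a single aggregated value.

Note how the two variable creators compose: the creator added by tower_local_var_scope() tags the variable with use_resource=True and a tower_local_reduce_method, and the creator installed by the strategy's scope() pops that tag and forces trainable=False, which is how "tower-local implies not trainable" is enforced.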