Skip to content

Commit

Permalink
Expose an axis argument for VocabInfo, which allows for warm-starting…
Browse files Browse the repository at this point in the history
… of the second axis of Tensors through tf.train.warm_start. Note that the underlying initializer already has this functionality (for example, for output layers).

PiperOrigin-RevId: 211709879
  • Loading branch information
eddie-zhou authored and tensorflower-gardener committed Sep 5, 2018
1 parent ebf6d25 commit 47b1af2
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 26 deletions.
2 changes: 1 addition & 1 deletion tensorflow/python/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2056,7 +2056,7 @@ class WarmStartSettings(
var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
`tf.estimator.VocabInfo`. The variable names should be "full" variables,
not the names of the partitions. If not explicitly provided, the variable
is assumed to have no vocabulary.
is assumed to have no (changes to) vocabulary.
var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
name of the previously-trained variable in `ckpt_to_initialize_from`. If
not explicitly provided, the name of the variable is assumed to be same
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/python/training/checkpoint_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ def _load_and_remap_matrix_initializer(ckpt_path,
vocab files are the same, and no column remapping is done.
The returned initializer only supports div-partitioning along the row axis. It
does not support partitioning along the column axis or mod-partitioning.
does not support partitioning along the column axis (as this is not common in
practice) or mod-partitioning.
NOTE: When this is used to warm-start variables, client code should use
`tf.lookup.index_table_from_tensor()` like
Expand Down
100 changes: 87 additions & 13 deletions tensorflow/python/training/warm_starting_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class VocabInfo(
"old_vocab",
"old_vocab_size",
"backup_initializer",
"axis",
])):
"""Vocabulary information for warm-starting.
Expand All @@ -62,6 +63,42 @@ class VocabInfo(
backup_initializer: [Optional] A variable initializer used for variables
corresponding to new vocabulary entries and OOV. If not provided, these
entries will be zero-initialized.
axis: [Optional] Denotes what axis the vocabulary corresponds to. The
default, 0, corresponds to the most common use case (embeddings or
linear weights for binary classification / regression). An axis of 1
could be used for warm-starting output layers with class vocabularies.
For example:
embeddings_vocab_info = tf.estimator.VocabInfo(
new_vocab='embeddings_vocab',
new_vocab_size=100,
num_oov_buckets=1,
old_vocab='pretrained_embeddings_vocab',
old_vocab_size=10000,
backup_initializer=tf.truncated_normal_initializer(
mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
axis=0)
softmax_output_layer_kernel_vocab_info = tf.estimator.VocabInfo(
new_vocab='class_vocab',
new_vocab_size=5,
num_oov_buckets=0, # No OOV for classes.
old_vocab='old_class_vocab',
old_vocab_size=8,
backup_initializer=tf.glorot_uniform_initializer(),
axis=1)
softmax_output_layer_bias_vocab_info = tf.estimator.VocabInfo(
new_vocab='class_vocab',
new_vocab_size=5,
num_oov_buckets=0, # No OOV for classes.
old_vocab='old_class_vocab',
old_vocab_size=8,
backup_initializer=tf.zeros_initializer(),
axis=0)
Currently, only axis=0 and axis=1 are supported.
"""

def __new__(cls,
Expand All @@ -70,7 +107,12 @@ def __new__(cls,
num_oov_buckets,
old_vocab,
old_vocab_size=-1,
backup_initializer=None):
backup_initializer=None,
axis=0):
if axis != 0 and axis != 1:
raise ValueError("The only supported values for the axis argument are 0 "
"and 1. Provided axis: {}".format(axis))

return super(VocabInfo, cls).__new__(
cls,
new_vocab,
Expand All @@ -79,6 +121,7 @@ def __new__(cls,
old_vocab,
old_vocab_size,
backup_initializer,
axis,
)


Expand Down Expand Up @@ -149,7 +192,8 @@ def _warm_start_var_with_vocab(var,
previous_vocab_size=-1,
current_oov_buckets=0,
prev_tensor_name=None,
initializer=None):
initializer=None,
axis=0):
"""Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
Use this method when the `var` is backed by vocabulary. This method stitches
Expand Down Expand Up @@ -180,6 +224,7 @@ def _warm_start_var_with_vocab(var,
None, we lookup tensor with same name as given `var`.
initializer: Variable initializer to be used for missing entries. If None,
missing entries will be zero-initialized.
axis: Axis of the variable that the provided vocabulary corresponds to.
Raises:
ValueError: If required args are not provided.
Expand All @@ -204,6 +249,8 @@ def _warm_start_var_with_vocab(var,
# Assume tensor name remains the same.
prev_tensor_name = _infer_var_name(var)

# TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var])
for v in var:
v_shape = v.get_shape().as_list()
slice_info = v._get_save_slice_info()
Expand All @@ -213,19 +260,45 @@ def _warm_start_var_with_vocab(var,
full_shape=slice_info.full_shape,
var_offset=slice_info.var_offset)

# TODO(eddz): Support cases where class vocabularies need remapping too.
if axis == 0:
new_row_vocab_size = current_vocab_size
new_col_vocab_size = v_shape[1]
old_row_vocab_size = previous_vocab_size
old_row_vocab_file = prev_vocab_path
new_row_vocab_file = current_vocab_path
old_col_vocab_file = None
new_col_vocab_file = None
num_row_oov_buckets = current_oov_buckets
num_col_oov_buckets = 0
elif axis == 1:
# Note that we must compute this value across all partitions, whereas
# in the axis = 0 case, we can simply use v_shape[1] because we don't
# allow partitioning across axis = 1.
new_row_vocab_size = total_v_first_axis
new_col_vocab_size = current_vocab_size
old_row_vocab_size = -1
old_row_vocab_file = None
new_row_vocab_file = None
old_col_vocab_file = prev_vocab_path
new_col_vocab_file = current_vocab_path
num_row_oov_buckets = 0
num_col_oov_buckets = current_oov_buckets
else:
raise ValueError("The only supported values for the axis argument are 0 "
"and 1. Provided axis: {}".format(axis))

init = checkpoint_ops._load_and_remap_matrix_initializer(
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
old_tensor_name=prev_tensor_name,
new_row_vocab_size=current_vocab_size,
new_col_vocab_size=v_shape[1],
old_row_vocab_size=previous_vocab_size,
old_row_vocab_file=prev_vocab_path,
new_row_vocab_file=current_vocab_path,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=current_oov_buckets,
num_col_oov_buckets=0,
new_row_vocab_size=new_row_vocab_size,
new_col_vocab_size=new_col_vocab_size,
old_row_vocab_size=old_row_vocab_size,
old_row_vocab_file=old_row_vocab_file,
new_row_vocab_file=new_row_vocab_file,
old_col_vocab_file=old_col_vocab_file,
new_col_vocab_file=new_col_vocab_file,
num_row_oov_buckets=num_row_oov_buckets,
num_col_oov_buckets=num_col_oov_buckets,
initializer=initializer)
new_init_val = ops.convert_to_tensor(
init(shape=v_shape, partition_info=partition_info))
Expand Down Expand Up @@ -374,7 +447,8 @@ def warm_start(ckpt_to_initialize_from,
previous_vocab_size=vocab_info.old_vocab_size,
current_oov_buckets=vocab_info.num_oov_buckets,
prev_tensor_name=prev_var_name,
initializer=vocab_info.backup_initializer)
initializer=vocab_info.backup_initializer,
axis=vocab_info.axis)
else:
# For the special value of vars_to_warm_start = None,
# we only warm-start variables with explicitly specified vocabularies.
Expand Down
140 changes: 129 additions & 11 deletions tensorflow/python/training/warm_starting_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def testWarmStartVar(self):
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
self.assertAllEqual(prev_val, fruit_weights.eval(sess))
self.assertAllClose(prev_val, fruit_weights.eval(sess))

def testWarmStartVarPrevVarPartitioned(self):
_, weights = self._create_prev_run_var(
Expand All @@ -123,7 +123,7 @@ def testWarmStartVarPrevVarPartitioned(self):
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
self.assertAllEqual(prev_val, fruit_weights.eval(sess))
self.assertAllClose(prev_val, fruit_weights.eval(sess))

def testWarmStartVarCurrentVarPartitioned(self):
_, prev_val = self._create_prev_run_var(
Expand All @@ -143,7 +143,7 @@ def testWarmStartVarCurrentVarPartitioned(self):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
self.assertAllEqual(prev_val, new_val)
self.assertAllClose(prev_val, new_val)

def testWarmStartVarBothVarsPartitioned(self):
_, weights = self._create_prev_run_var(
Expand All @@ -170,7 +170,7 @@ def testWarmStartVarBothVarsPartitioned(self):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
self.assertAllEqual(prev_val, new_val)
self.assertAllClose(prev_val, new_val)

def testWarmStartVarWithVocab(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
Expand All @@ -189,9 +189,34 @@ def testWarmStartVarWithVocab(self):
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))

def testWarmStartVarWithColumnVocab(self):
# Exercises _warm_start_var_with_vocab with axis=1: the vocabulary is matched
# against the columns of the variable instead of its rows.
prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
# Previous-run value has four rows and one column per old-vocab entry
# (apple, orange).
# NOTE(review): _create_prev_run_var is defined outside this view; presumably
# it saves the value to a checkpoint under get_temp_dir() -- confirm.
self._create_prev_run_var(
"fruit_output_layer",
initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])

# New vocab with elements in reverse order and one new element.
new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Zero-initialized 4x3 variable: same four rows, one column per new-vocab
# entry.
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]])
# No `initializer` argument is passed, so the column for the vocabulary
# entry that has no old counterpart is expected to be zero (asserted below).
ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
current_vocab_size=3,
prev_ckpt=self.get_temp_dir(),
prev_vocab_path=prev_vocab_path,
axis=1)
sess.run(variables.global_variables_initializer())
# Expected: old columns reordered to (orange, apple); the new "banana"
# column is all zeros.
self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
[2.3, 2., 0.]], fruit_output_layer.eval(sess))

def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
Expand All @@ -215,7 +240,7 @@ def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
previous_vocab_size=2)
sess.run(variables.global_variables_initializer())
# Old vocabulary limited to ['apple', 'banana'].
self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]],
self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
fruit_weights.eval(sess))

def testWarmStartVarWithVocabPrevVarPartitioned(self):
Expand All @@ -238,9 +263,36 @@ def testWarmStartVarWithVocabPrevVarPartitioned(self):
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))

def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
# Column-vocabulary (axis=1) warm-start where the previous-run variable was
# created with a partitioner and the current variable is not partitioned.
prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
# Previous-run value: 4x2, one column per old-vocab entry (apple, orange).
# NOTE(review): the partitioner returns [2, 1]; presumably this means two
# shards along axis 0 and one along axis 1 -- confirm against the
# get_variable partitioner contract.
self._create_prev_run_var(
"fruit_output_layer",
shape=[4, 2],
initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
partitioner=lambda shape, dtype: [2, 1])

# New vocab with elements in reverse order and one new element.
new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Current variable: zero-initialized 4x3, created without a partitioner.
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]])
# axis=1 makes the vocab files describe columns; no `initializer` argument
# is passed for entries missing from the old vocab.
ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
current_vocab_size=3,
prev_ckpt=self.get_temp_dir(),
prev_vocab_path=prev_vocab_path,
axis=1)
sess.run(variables.global_variables_initializer())
# Expected: old columns reordered to (orange, apple); the new "banana"
# column is all zeros.
self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
[2.3, 2., 0.]], fruit_output_layer.eval(sess))

def testWarmStartVarWithVocabCurrentVarPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
Expand Down Expand Up @@ -269,11 +321,43 @@ def testWarmStartVarWithVocabCurrentVarPartitioned(self):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
self.assertAllEqual([[2.], [1.5], [1.]],
self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
self.assertAllEqual([[0.5], [0.], [0.]],
self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))

def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
# Column-vocabulary (axis=1) warm-start where only the current variable is
# created with a partitioner; the previous-run variable is not.
prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
# Previous-run value has four rows and one column per old-vocab entry
# (apple, orange).
self._create_prev_run_var(
"fruit_output_layer",
initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])

# New vocab with elements in reverse order and one new element.
new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Current variable: zero-initialized 4x3 with a partitioner returning
# [2, 1]; the assertions below expect two partitions of two rows each.
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
shape=[4, 3],
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]],
partitioner=lambda shape, dtype: [2, 1])
# axis=1 makes the vocab files describe columns; no `initializer` argument
# is passed for entries missing from the old vocab.
ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
current_vocab_size=3,
prev_ckpt=self.get_temp_dir(),
prev_vocab_path=prev_vocab_path,
axis=1)
sess.run(variables.global_variables_initializer())
self.assertTrue(
isinstance(fruit_output_layer, variables.PartitionedVariable))
fruit_output_layer_vars = fruit_output_layer._get_variable_list()
# First partition: rows 0-1, columns reordered to (orange, apple) with a
# zero "banana" column.
self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
fruit_output_layer_vars[0].eval(sess))
# Second partition: rows 2-3, same column layout.
self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
fruit_output_layer_vars[1].eval(sess))

def testWarmStartVarWithVocabBothVarsPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
Expand Down Expand Up @@ -301,11 +385,45 @@ def testWarmStartVarWithVocabBothVarsPartitioned(self):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
self.assertAllEqual([[2.], [1.5], [1.]],
self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
self.assertAllEqual([[0.5], [0.], [0.]],
self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))

def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
# Column-vocabulary (axis=1) warm-start where both the previous-run variable
# and the current variable are created with a partitioner returning [2, 1].
prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
# Previous-run value: 4x2, one column per old-vocab entry (apple, orange).
# NOTE(review): presumably the [2, 1] partitioner yields two shards along
# axis 0 -- confirm against the get_variable partitioner contract.
self._create_prev_run_var(
"fruit_output_layer",
shape=[4, 2],
initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
partitioner=lambda shape, dtype: [2, 1])

# New vocab with elements in reverse order and one new element.
new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Current variable: zero-initialized 4x3; the assertions below expect two
# partitions of two rows each.
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
shape=[4, 3],
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]],
partitioner=lambda shape, dtype: [2, 1])
# axis=1 makes the vocab files describe columns; no `initializer` argument
# is passed for entries missing from the old vocab.
ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
current_vocab_size=3,
prev_ckpt=self.get_temp_dir(),
prev_vocab_path=prev_vocab_path,
axis=1)
sess.run(variables.global_variables_initializer())
self.assertTrue(
isinstance(fruit_output_layer, variables.PartitionedVariable))
fruit_output_layer_vars = fruit_output_layer._get_variable_list()
# First partition: rows 0-1, columns reordered to (orange, apple) with a
# zero "banana" column.
self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
fruit_output_layer_vars[0].eval(sess))
# Second partition: rows 2-3, same column layout.
self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
fruit_output_layer_vars[1].eval(sess))

def testWarmStart_ListOfVariables(self):
# Save checkpoint from which to warm-start.
_, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
Expand Down
Loading

0 comments on commit 47b1af2

Please sign in to comment.