Commit 1c76475
Adding checkpoints library for listing checkpoint folder, loading tensors and loading tensors into existing variables.

Example usage for init_from_checkpoints:

    # Create variables.
    with tf.variable_scope('test'):
      m = tf.get_variable('my_var')
    with tf.variable_scope('test2'):
      m = tf.get_variable('my_var')
    ...
    # Specify which variables to initialize from checkpoint.
    init_from_checkpoint(checkpoint_dir, {
      'test/my_var': 'some_var',
      'test2/': 'some_scope/'})
    ...
    # Initialize variables as usual.
    session.run(tf.initialize_all_variables())
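
The library also exposes list_variables and load_variable for inspecting
checkpoints directly; a minimal sketch (checkpoint_dir and the variable
name are placeholders):

    # List all (name, shape) pairs in the latest checkpoint.
    for name, shape in list_variables(checkpoint_dir):
      print(name, shape)
    # Load the raw value of a single tensor.
    value = load_variable(checkpoint_dir, 'test/my_var')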
Change: 121859476
ilblackdragon authored and tensorflower-gardener committed May 9, 2016
1 parent 6958623 commit 1c76475
Showing 6 changed files with 413 additions and 3 deletions.
13 changes: 13 additions & 0 deletions tensorflow/contrib/learn/BUILD
@@ -229,6 +229,19 @@ py_test(
    ],
)

py_test(
    name = "checkpoints_test",
    size = "small",
    srcs = ["python/learn/utils/checkpoints_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":learn",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_test_lib",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
1 change: 1 addition & 0 deletions tensorflow/contrib/learn/python/learn/__init__.py
@@ -26,6 +26,7 @@
from tensorflow.contrib.learn.python.learn import preprocessing
from tensorflow.contrib.learn.python.learn import models
from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn import utils

from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
from tensorflow.contrib.learn.python.learn.graph_actions import infer
21 changes: 21 additions & 0 deletions tensorflow/contrib/learn/python/learn/utils/__init__.py
@@ -0,0 +1,21 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Learn Utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.utils import checkpoints
211 changes: 211 additions & 0 deletions tensorflow/contrib/learn/python/learn/utils/checkpoints.py
@@ -0,0 +1,211 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to work with checkpoints."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver
from tensorflow.python.training import training as train


def load_checkpoint(filepattern):
  """Returns CheckpointReader for latest checkpoint.

  Args:
    filepattern: Directory with checkpoints file or path to checkpoint.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: if `filepattern` doesn't point to a directory with a
      'checkpoint' file or to a checkpoint itself.
  """
  if gfile.IsDirectory(filepattern):
    filename = saver.latest_checkpoint(filepattern)
    if filename is None:
      raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                       "given directory %s" % filepattern)
    return train.NewCheckpointReader(filename)
  return train.NewCheckpointReader(filepattern)
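# A minimal usage sketch (the paths below are hypothetical): load_checkpoint
# accepts either a directory containing a 'checkpoint' file, in which case the
# latest checkpoint is used, or a direct path to one checkpoint:
#
#   reader = load_checkpoint('/tmp/my_model')             # latest checkpoint
#   reader = load_checkpoint('/tmp/my_model/model-1000')  # specific checkpoint
#   print(reader.get_variable_to_shape_map())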


def load_variable(checkpoint_dir, name):
  """Returns the contents of the given variable in the checkpoint.

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.
    name: Name of the tensor to return.

  Returns:
    Value of the tensor (a numpy `ndarray`).
  """
  reader = load_checkpoint(checkpoint_dir)
  return reader.get_tensor(name)
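# For example (path and variable name are hypothetical):
#
#   value = load_variable('/tmp/my_model', 'test/my_var')  # numpy ndarray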


def list_variables(checkpoint_dir):
  """Returns list of all variables in the latest checkpoint.

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(name, shape)`.
  """
  reader = load_checkpoint(checkpoint_dir)
  variable_map = reader.get_variable_to_shape_map()
  names = sorted(variable_map.keys())
  result = []
  for name in names:
    result.append((name, variable_map[name]))
  return result
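# For a checkpoint holding two variables, this might return (names and shapes
# are illustrative):
#
#   [('test/my_var', [100, 10]), ('test2/my_var', [10])]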


def _checkpoint_initializer(variable, checkpoint_reader, tensor_name):
  """Assigns the variable a value that will be loaded from the checkpoint.

  Args:
    variable: `Variable` object.
    checkpoint_reader: `CheckpointReader` object.
    tensor_name: Name of the `Tensor` to load from the checkpoint reader.

  Returns:
    `Tensor` produced by the assign op that sets `variable` to the value of
    `tensor_name` in the checkpoint.

  Raises:
    ValueError: if the shape or dtype of `variable` doesn't match the `Tensor`
      in the checkpoint.
  """
  # Currently, to avoid putting the whole tensor into the graph as a constant
  # (which would bloat the GraphDef), this adds a py_func to the graph that
  # returns the actual value.
  # TODO(ipolosukhin): Rewrite this as a C++ op that loads the checkpoint at
  # run time.
  tensor = checkpoint_reader.get_tensor(tensor_name)
  def _tensor():
    return tensor
  if not variable.get_shape().is_compatible_with(tensor.shape):
    raise ValueError(
        "Shape of variable %s (%s) doesn't match the shape of "
        "tensor %s (%s) from the checkpoint reader." % (
            variable.name, str(variable.get_shape()),
            tensor_name, str(tensor.shape)))
  if not dtypes.as_dtype(tensor.dtype).is_compatible_with(variable.dtype):
    raise ValueError(
        "DType of variable %s (%s) doesn't match the dtype of "
        "tensor %s (%s) from the checkpoint reader." % (
            variable.name, str(variable.dtype),
            tensor_name, str(dtypes.as_dtype(tensor.dtype))))
  return state_ops.assign(
      variable, script_ops.py_func(_tensor, [], [tensor.dtype])[0])


def init_from_checkpoint(checkpoint_dir, assignment_map):
  """Initializes current variables with tensors loaded from the checkpoint.

  Note: This overrides the default initialization ops of the specified
  variables and redefines their dtype.

  The assignment map supports the following syntax:

  * `'scope_name/': 'checkpoint_scope_name/'` - will load all variables in
    the current `scope_name` from `checkpoint_scope_name` with matching
    variable names.
  * `'scope_name/variable_name': 'checkpoint_scope_name/some_other_variable'` -
    will initialize the `scope_name/variable_name` variable from
    `checkpoint_scope_name/some_other_variable`.

  Example:

  ```python
  # Create variables.
  with tf.variable_scope('test'):
    m = tf.get_variable('my_var')
  with tf.variable_scope('test2'):
    m = tf.get_variable('my_var')
  ...
  # Specify which variables to initialize from checkpoint.
  init_from_checkpoint(checkpoint_dir, {
      'test/my_var': 'some_var',
      'test2/': 'some_scope/'})
  ...
  # Initialize variables as usual.
  session.run(tf.initialize_all_variables())
  ```

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of current variables
      (in the default graph) and values are names of the variables
      in the checkpoint.

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
  reader = load_checkpoint(checkpoint_dir)
  variable_map = reader.get_variable_to_shape_map()
  for current_name, tensor_name in six.iteritems(assignment_map):
    scopes = ""
    if "/" in current_name:
      scopes = current_name[:current_name.rindex("/")]
      current_name = current_name[current_name.rindex("/") + 1:]
    if current_name:
      # If a 1-to-1 mapping was provided, find the variable in the scope.
      if tensor_name not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint" % (
            tensor_name, checkpoint_dir))
      with vs.variable_scope(scopes, reuse=True):
        var = vs.get_variable(current_name)
        var._initializer_op = _checkpoint_initializer(var, reader, tensor_name)  # pylint: disable=protected-access
        logging.info("Initialize variable %s from checkpoint %s with %s" % (
            var.name, checkpoint_dir, tensor_name))
    else:
      if not tensor_name.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name (%s) "
            "should map to scope only (%s). "
            "Should be 'scope/': 'other_scope/'." % (scopes, tensor_name))
      # If a scope-to-scope mapping was provided, find all variables in the
      # scope.
      # TODO(ipolosukhin): Refactor variable_scope module to provide nicer APIs.
      var_scope = vs._get_default_variable_store()  # pylint: disable=protected-access
      for var_name in var_scope._vars:  # pylint: disable=protected-access
        if var_name.startswith(scopes):
          # Build the checkpoint tensor name by replacing the current scope
          # prefix with the checkpoint scope prefix.
          full_tensor_name = tensor_name + var_name[len(scopes) + 1:]
          if full_tensor_name not in variable_map:
            raise ValueError(
                "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                    full_tensor_name, var_name[len(scopes) + 1:], tensor_name,
                    checkpoint_dir))
          var = var_scope._vars[var_name]  # pylint: disable=protected-access
          var._initializer_op = _checkpoint_initializer(  # pylint: disable=protected-access
              var, reader, full_tensor_name)
          logging.info("Initialize variable %s from checkpoint %s with %s" % (
              var_name, checkpoint_dir, full_tensor_name))
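# A minimal end-to-end sketch of the scope-to-scope mapping (names and paths
# are hypothetical): with variables 'test2/my_var' and 'test2/my_other_var' in
# the default graph, the call
#
#   init_from_checkpoint('/tmp/my_model', {'test2/': 'some_scope/'})
#
# replaces the initializers of both variables, so they are loaded from the
# checkpoint tensors 'some_scope/my_var' and 'some_scope/my_other_var' when
# the usual initialization op runs.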
