Skip to content

Commit

Permalink
Merge pull request tensorflow#4835 from yifeif/branch_135521241
Browse files Browse the repository at this point in the history
Branch 135521241
  • Loading branch information
yifeif authored Oct 8, 2016
2 parents b90ddd2 + 1365cbc commit 0a4f5b6
Show file tree
Hide file tree
Showing 154 changed files with 8,095 additions and 1,901 deletions.
1 change: 1 addition & 0 deletions tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ filegroup(
"//tensorflow/contrib/slim:all_files",
"//tensorflow/contrib/slim/python/slim/data:all_files",
"//tensorflow/contrib/slim/python/slim/nets:all_files",
"//tensorflow/contrib/specs:all_files",
"//tensorflow/contrib/tensor_forest:all_files",
"//tensorflow/contrib/tensor_forest/hybrid:all_files",
"//tensorflow/contrib/tensorboard:all_files",
Expand Down
11 changes: 8 additions & 3 deletions tensorflow/c/checkpoint_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ void CheckpointReader::GetTensor(
if (reader_ != nullptr) {
status = reader_->GetTensor(name, out_tensor);
} else {
std::unique_ptr<Tensor> tensor(new Tensor);
status = v2_reader_->Lookup(name, tensor.get());
if (status.ok()) std::swap(*out_tensor, tensor);
tensorflow::DataType dtype;
tensorflow::TensorShape shape;
status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape);
if (status.ok()) {
out_tensor->reset(new Tensor(dtype, shape));
status = v2_reader_->Lookup(name, out_tensor->get());
if (!status.ok()) out_tensor->reset();
}
}
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/cc/saved_model/loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ Status Restore(const RunOptions& run_options, const string& export_dir,

} // namespace

Status LoadSavedModel(const string& export_dir,
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
const SessionOptions& session_options,
const RunOptions& run_options,
SavedModelBundle* const bundle) {
if (!MaybeSavedModelDirectory(export_dir)) {
return Status(error::Code::NOT_FOUND,
Expand Down
15 changes: 12 additions & 3 deletions tensorflow/cc/saved_model/loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,25 @@ namespace tensorflow {
struct SavedModelBundle {
std::unique_ptr<Session> session;
MetaGraphDef meta_graph_def;

// A TensorFlow Session does not Close itself on destruction. To avoid
// resource leaks, we explicitly call Close on Sessions that we create.
~SavedModelBundle() {
if (session) {
session->Close();
}
}

SavedModelBundle() = default;
};

// Loads a SavedModel from the specified export directory. The meta graph def to
// be loaded is identified by the supplied tags, corresponding exactly to the
// set of tags used at SavedModel build time. Returns a SavedModel bundle with a
// session and the requested meta graph def, if found.
Status LoadSavedModel(const string& export_dir,
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
const SessionOptions& session_options,
const RunOptions& run_options,
SavedModelBundle* const bundle);

// Checks whether the provided directory could contain a SavedModel. Note that
Expand Down
42 changes: 30 additions & 12 deletions tensorflow/cc/saved_model/loader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,33 @@ class LoaderTest : public ::testing::Test {
}
};

// Test for resource leaks related to TensorFlow session closing requirements
// when loading and unloading large numbers of SavedModelBundles.
// TODO(sukritiramesh): Increase run iterations and move outside of the test
// suite.
TEST_F(LoaderTest, ResourceLeakTest) {
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPb);
for (int i = 0; i < 100; ++i) {
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(bundle);
}
}

TEST_F(LoaderTest, TagMatch) {
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPb);
TF_ASSERT_OK(LoadSavedModel(export_dir, {kSavedModelTagServe},
session_options, run_options, &bundle));
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(bundle);
}

Expand All @@ -74,8 +92,8 @@ TEST_F(LoaderTest, NoTagMatch) {

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPb);
Status st = LoadSavedModel(export_dir, {"missing-tag"}, session_options,
run_options, &bundle);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{"missing-tag"}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(
StringPiece(st.error_message())
Expand All @@ -90,8 +108,8 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPb);
Status st = LoadSavedModel(export_dir, {kSavedModelTagServe, "missing-tag"},
session_options, run_options, &bundle);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe, "missing-tag"}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(
StringPiece(st.error_message())
Expand All @@ -106,8 +124,8 @@ TEST_F(LoaderTest, PbtxtFormat) {

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
TF_ASSERT_OK(LoadSavedModel(export_dir, {kSavedModelTagServe},
session_options, run_options, &bundle));
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(bundle);
}

Expand All @@ -118,8 +136,8 @@ TEST_F(LoaderTest, ShardedVariables) {

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
TF_ASSERT_OK(LoadSavedModel(export_dir, {kSavedModelTagServe},
session_options, run_options, &bundle));
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(bundle);
}

Expand All @@ -130,8 +148,8 @@ TEST_F(LoaderTest, InvalidExportPath) {

const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
Status st = LoadSavedModel(export_dir, {kSavedModelTagServe}, session_options,
run_options, &bundle);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok());
}

Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/framework/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ py_test(
size = "small",
srcs = ["python/ops/variables_test.py"],
srcs_version = "PY2AND3",
tags = ["manual"],
deps = ["//tensorflow:tensorflow_py"],
)

Expand Down
35 changes: 35 additions & 0 deletions tensorflow/contrib/layers/python/layers/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,6 +1761,41 @@ def create_feature_spec_for_parsing(feature_columns):
return features_config


def _create_sequence_feature_spec_for_parsing(sequence_feature_columns,
allow_missing_by_default=False):
"""Prepares a feature spec for parsing `tf.SequenceExample`s.
Args:
sequence_feature_columns: an iterable containing all the feature columns.
All items should be instances of classes derived from `_FeatureColumn`.
allow_missing_by_default: whether to set `allow_missing=True` by default for
`FixedLenSequenceFeature`s.
Returns:
A dict mapping feature keys to `FixedLenSequenceFeature` or `VarLenFeature`.
"""
feature_spec = create_feature_spec_for_parsing(sequence_feature_columns)
sequence_feature_spec = {}
for key, feature in feature_spec.items():
if isinstance(feature, parsing_ops.VarLenFeature):
sequence_feature = feature
elif isinstance(feature, parsing_ops.FixedLenFeature):
default_is_set = feature.default_value is not None
if default_is_set:
logging.warning(
'Found default value {} for feature "{}". Ignoring this value and '
'setting `allow_missing=True` instead.'.
format(feature.default_value, key))
sequence_feature = parsing_ops.FixedLenSequenceFeature(
shape=feature.shape,
dtype=feature.dtype,
allow_missing=(allow_missing_by_default or default_is_set))
else:
raise TypeError(
"Unsupported feature type: {}".format(type(feature).__name__))
sequence_feature_spec[key] = sequence_feature
return sequence_feature_spec


def make_place_holder_tensors_for_base_features(feature_columns):
"""Returns placeholder tensors for inference.
Expand Down
35 changes: 35 additions & 0 deletions tensorflow/contrib/layers/python/layers/feature_column_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os

import tensorflow as tf
import tensorflow.contrib.layers.python.layers.feature_column as fc


class FeatureColumnTest(tf.test.TestCase):
Expand Down Expand Up @@ -410,6 +411,40 @@ def testCreateFeatureSpec_RealValuedColumnWithDefaultValue(self):
tf.FixedLenFeature([3], dtype=tf.float32,
default_value=[1., 0., 6.])}, config)

def testCreateSequenceFeatureSpec(self):
sparse_col = tf.contrib.layers.sparse_column_with_hash_bucket(
"sparse_column", hash_bucket_size=100)
embedding_col = tf.contrib.layers.embedding_column(
tf.contrib.layers.sparse_column_with_hash_bucket(
"sparse_column_for_embedding",
hash_bucket_size=10),
dimension=4)
sparse_id_col = tf.contrib.layers.sparse_column_with_keys(
"id_column", ["marlo", "omar", "stringer"])
weighted_id_col = tf.contrib.layers.weighted_sparse_column(
sparse_id_col, "id_weights_column")
real_valued_col1 = tf.contrib.layers.real_valued_column(
"real_valued_column", dimension=2)
real_valued_col2 = tf.contrib.layers.real_valued_column(
"real_valued_default_column", dimension=5, default_value=3.0)

feature_columns = set([sparse_col, embedding_col, weighted_id_col,
real_valued_col1, real_valued_col2])

feature_spec = fc._create_sequence_feature_spec_for_parsing(feature_columns)

expected_feature_spec = {
"sparse_column": tf.VarLenFeature(tf.string),
"sparse_column_for_embedding": tf.VarLenFeature(tf.string),
"id_column": tf.VarLenFeature(tf.string),
"id_weights_column": tf.VarLenFeature(tf.float32),
"real_valued_column": tf.FixedLenSequenceFeature(
shape=[2], dtype=tf.float32, allow_missing=False),
"real_valued_default_column": tf.FixedLenSequenceFeature(
shape=[5], dtype=tf.float32, allow_missing=True)}

self.assertDictEqual(expected_feature_spec, feature_spec)

def testMakePlaceHolderTensorsForBaseFeatures(self):
sparse_col = tf.contrib.layers.sparse_column_with_hash_bucket(
"sparse_column", hash_bucket_size=100)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging


def _assert_float32(tensors):
Expand Down Expand Up @@ -209,12 +210,17 @@ def _get_train_ops(self, features, targets):
Returns:
Tuple of train `Operation` and loss `Tensor`.
"""
features, _, spec = data_ops.ParseDataTensorOrDict(features)
features, _, weights, spec = data_ops.ParseDataTensorOrDict(features)
labels = data_ops.ParseLabelTensorOrDict(targets)
features, labels = self._feature_engineering_fn(features, labels)
_assert_float32(features)
_assert_float32(labels)

if weights is not None:
if 'input_weights' in self.training_args:
logging.warning('Replacing input_weights in training_args.')
self.training_args['input_weights'] = weights

graph_builder = self.graph_builder_class(
self.params, device_assigner=self.device_assigner,
**self.construction_args)
Expand All @@ -237,7 +243,7 @@ def _get_predict_ops(self, features):
graph_builder = self.graph_builder_class(
self.params, device_assigner=self.device_assigner, training=False,
**self.construction_args)
features, keys, spec = data_ops.ParseDataTensorOrDict(features)
features, keys, _, spec = data_ops.ParseDataTensorOrDict(features)
features, _ = self._feature_engineering_fn(features, None)
_assert_float32(features)
output_dict = {
Expand All @@ -248,7 +254,7 @@ def _get_predict_ops(self, features):
return output_dict

def _get_eval_ops(self, features, targets, metrics):
features, _, spec = data_ops.ParseDataTensorOrDict(features)
features, _, _, spec = data_ops.ParseDataTensorOrDict(features)
labels = data_ops.ParseLabelTensorOrDict(targets)
features, labels = self._feature_engineering_fn(features, labels)
_assert_float32(features)
Expand Down
18 changes: 8 additions & 10 deletions tensorflow/contrib/learn/python/learn/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,15 @@ def evaluate(
steps: Number of steps for which to evaluate model. If `None`, evaluate
until `x` is consumed or `input_fn` raises an end-of-input exception.
See "Stop conditions" above for specifics.
metrics: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
metrics: Dict of metrics to run. If None, the default metric functions
are used; if {}, no metrics are used. Otherwise, `metrics` should map
friendly names for the metric to a `MetricSpec` object defining which
model outputs to evaluate against which targets with which metric
function.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../metrics/python/metrics/ops/streaming_metrics.py.
Metric ops should support streaming, e.g., returning `update_op` and
`value` tensors. For example, see the options defined in
`../../../metrics/python/ops/metrics_ops.py`.
name: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
Expand Down
2 changes: 0 additions & 2 deletions tensorflow/contrib/learn/python/learn/graph_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
Expand Down Expand Up @@ -88,7 +87,6 @@ def _make_saver(graph, keep_checkpoint_max=5):

def _restore_from_checkpoint(session, graph, checkpoint_path, saver=None):
logging.info('Loading model from checkpoint: %s.', checkpoint_path)
assert gfile.Glob(checkpoint_path)
saver = saver or _make_saver(graph)
if saver:
saver.restore(session, checkpoint_path)
Expand Down
Loading

0 comments on commit 0a4f5b6

Please sign in to comment.