Skip to content

Commit

Permalink
merge fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewharp committed Nov 4, 2017
2 parents 4e75ae1 + a0671c4 commit ab1ca70
Show file tree
Hide file tree
Showing 364 changed files with 15,953 additions and 8,290 deletions.
1 change: 1 addition & 0 deletions CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# NEED OWNER: tensorflow/contrib/avro/*
#tensorflow/contrib/batching/* @alextp @chrisolston
#tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon
#tensorflow/contrib/boosted_trees/* @sshrdp @yk5 @nataliaponomareva
#tensorflow/contrib/cmake/* @mrry @benoitsteiner
#tensorflow/contrib/copy_graph/* @tucker @poxvoculi
#tensorflow/contrib/crf/* @kentonl
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,14 @@ config_setting(
visibility = ["//visibility:public"],
)

# Make a dummy rule that we can chaqnge "default" in select statements to.
# to disable dependencies in copybara.
config_setting(
name = "dummy_disabled_internal",
values = {"define": "with_dummy_disabled_internal=true"},
visibility = ["//visibility:public"],
)

package_group(
name = "internal",
packages = [
Expand Down Expand Up @@ -413,6 +421,7 @@ filegroup(
"//tensorflow/contrib/makefile:all_files",
"//tensorflow/contrib/meta_graph_transform:all_files",
"//tensorflow/contrib/metrics:all_files",
"//tensorflow/contrib/model_pruning:all_files",
"//tensorflow/contrib/mpi_collectives:all_files",
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/nearest_neighbor:all_files",
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/c/python_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
graph->graph.AddControlEdge(&input->node, &op->node);
}

void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Buffer* attr_value_proto, TF_Status* status) {
AttrValue attr_val;
if (!attr_val.ParseFromArray(attr_value_proto->data,
attr_value_proto->length)) {
status->status =
tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
return;
}

mutex_lock l(graph->mu);
op->node.AddAttr(attr_name, attr_val);
}

void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
mutex_lock l(graph->mu);
op->node.set_requested_device(device);
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/c/python_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ namespace tensorflow {

void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);

// Changes an attr value in the node_def Protocol Buffer and sets a status upon
// completion.
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Buffer* attr_value_proto, TF_Status* status);

void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);

void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/cc/ops/op_gen_overrides.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ op { name: "Reverse" skip: true }
op { name: "ReverseV2" rename_to: "Reverse" }
op { name: "Split" input_rename: { from: "split_dim" to: "axis" } }
op { name: "SplitV" input_rename: { from: "split_dim" to: "axis" } }
op { name: "Squeeze" input_rename: { from: "squeeze_dims" to: "axis" } }
op { name: "Squeeze" attr_rename: { from: "squeeze_dims" to: "axis" } }
op { name: "Pack" rename_to: "Stack" }
op { name: "Unpack" rename_to: "Unstack" }
op { name: "Select" rename_to: "Where3" input_rename: { from: "t" to: "x" } input_rename: { from: "e" to: "y" } }
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/compiler/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "categorical_op_test",
size = "small",
srcs = ["categorical_op_test.py"],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
)

tf_xla_py_test(
name = "clustering_test",
size = "small",
Expand Down
135 changes: 135 additions & 0 deletions tensorflow/compiler/tests/categorical_op_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for multinomial generation ops in the XLA JIT compiler."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import googletest


# TODO(srvasude): Merge this with
# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py.
class CategoricalTest(XLATestCase):
"""Test cases for random-number generating operators."""

def _chi2(self, expected, actual):
"""Returns Chi2 GOF statistic."""
actual = np.asarray(actual)
expected = np.asarray(expected)
diff = actual - expected
chi2 = np.sum(diff * diff / expected)
return chi2

def _do_sampling(self, logits, num_samples):
"""Categorical samples from given input.
Args:
logits: Numpy ndarray of shape [batch_size, num_classes].
num_samples: Int; number of samples to draw.
Returns:
Frequencies from sampled classes; shape [batch_size, num_classes].
"""
with self.test_session() as sess, self.test_scope():
random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples)
d = sess.run(op)

batch_size, num_classes = logits.shape
freqs_mat = []
for i in range(batch_size):
cnts = dict(collections.Counter(d[i, :]))

# Requires drawn class labels be in range.
self.assertLess(max(cnts.keys()), num_classes)
self.assertGreaterEqual(min(cnts.keys()), 0)

freqs = [(cnts[k] * 1. / num_samples if k in cnts else 0)
for k in range(num_classes)]
freqs_mat.append(freqs)

return freqs_mat

def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
with self.test_session() as sess:
with self.test_scope():
x = rng(dtype)

# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
y = sess.run(x)
z = sess.run(x)
w = sess.run(x)

# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
self.assertTrue((not np.array_equal(y, z)) or
(not np.array_equal(z, w)) or
(not np.array_equal(y, w)))

def testCategoricalIsNotConstant(self):
def rng(unused_dtype):
return random_ops.multinomial([[1., 1., 1.]], 10)

dtype = dtypes.float32
self._testRngIsNotConstant(rng, dtype)

def testCategoricalIsInRange(self):
for dtype in [dtypes.float32, dtypes.float64]:
with self.test_session() as sess:
with self.test_scope():
x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000)
y = sess.run(x)
self.assertTrue((y >= 0).sum() == 1000)
self.assertTrue((y < 20).sum() == 1000)

def testSamplingCorrectness(self):
np.random.seed(1618) # Make it reproducible.
num_samples = 21000

rand_probs = np.random.dirichlet([1., 1., 2., 3.])
rand_probs2 = np.random.dirichlet([1., 4., 5.], size=3) # batched
for probs in [[.5, .5], [.85, .05, .1], rand_probs, rand_probs2]:
probs = np.asarray(probs)
if len(probs.shape) == 1:
probs = probs.reshape(1, probs.size) # singleton batch

logits = np.log(probs).astype(np.float32)
freqs = self._do_sampling(logits, num_samples)

# the test here is similar to
# python/kernel_tests/random/multinomial_op_test.py
# Note that df >= 1 in all these cases. Choosing a cutoff of 1e-3
# corresponds to an alpha value of 2.5% for df = 1, and smaller for larger
# df.
chi2 = self._chi2(probs, freqs)
self.assertLess(chi2, 1e-3)


if __name__ == '__main__':
googletest.main()
2 changes: 1 addition & 1 deletion tensorflow/compiler/tests/gather_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def testScalar1D(self):
with self.test_session() as session, self.test_scope():
data = np.array([0, 1, 2, 3, 7, 5])
for dtype in self.all_tf_types:
for indices in 4, [1, 2, 2, 4, 5]:
for indices in 4, [4], [1, 2, 2, 4, 5]:
params_np = self._buildParams(data, dtype)
params = array_ops.placeholder(dtype=dtype)
indices_tf = constant_op.constant(indices)
Expand Down
41 changes: 41 additions & 0 deletions tensorflow/compiler/tf2xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ cc_library(
":const_analysis",
":dump_graph",
":functionalize_control_flow",
":sharding_util",
":tf2xla_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
Expand Down Expand Up @@ -169,14 +171,46 @@ cc_library(
],
)

cc_library(
name = "sharding_util",
srcs = ["sharding_util.cc"],
hdrs = ["sharding_util.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)

tf_cc_test(
name = "sharding_util_test",
srcs = ["sharding_util_test.cc"],
deps = [
":sharding_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

# Internal targets below this point.

cc_library(
name = "tf2xla_util",
srcs = ["tf2xla_util.cc"],
hdrs = ["tf2xla_util.h"],
deps = [
":sharding_util",
":tf2xla_proto",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
Expand All @@ -190,8 +224,14 @@ tf_cc_test(
name = "tf2xla_util_test",
srcs = ["tf2xla_util_test.cc"],
deps = [
":sharding_util",
":tf2xla_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
Expand Down Expand Up @@ -350,6 +390,7 @@ cc_library(
srcs = ["functionalize_control_flow.cc"],
hdrs = ["functionalize_control_flow.h"],
deps = [
":tf2xla_util",
"//tensorflow/compiler/jit:graph_to_functiondef",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla:dump_graph",
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/tf2xla/const_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Pad", "paddings"},
{"PadV2", "paddings"},
{"MirrorPad", "paddings"},
{"Multinomial", "num_samples"},
{"Prod", "reduction_indices"},
{"RandomStandardNormal", "shape"},
{"RandomUniform", "shape"},
Expand Down
14 changes: 13 additions & 1 deletion tensorflow/compiler/tf2xla/functionalize_control_flow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
Expand Down Expand Up @@ -405,14 +406,25 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
arg.merge->name());
}

// Find the Exit successor of the Switch.
// Update the device on the Identity outputs of the switch to match their
// target. These Identity outputs do not

// Loop over the switch node's output to:
// - Find the Exit successor.
// - Set the sharding on all Identity outputs of the switch. These
// identity nodes are values used by the loop body or condition.
// The Identity node may have the wrong device so copy the device from
// one of its outputs instead.
for (const Edge* edge : arg.switch_node->out_edges()) {
if (edge->src_output() == 0 && IsExit(edge->dst())) {
if (arg.exit != nullptr) {
return errors::InvalidArgument("Duplicate Exit successors to ",
arg.switch_node->name());
}
arg.exit = edge->dst();
} else if (StringPiece(edge->dst()->type_string()) == "Identity") {
TF_RETURN_IF_ERROR(
SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
}
}
}
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/tf2xla/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ tf_kernel_library(
"bias_ops.cc",
"binary_ops.cc",
"cast_op.cc",
"categorical_op.cc",
"concat_op.cc",
"const_op.cc",
"conv_ops.cc",
Expand Down
Loading

0 comments on commit ab1ca70

Please sign in to comment.