Merge commit for internal changes
Vijay Vasudevan committed Jan 28, 2016
2 parents c68f145 + 72bf502 commit 44f318d
Showing 36 changed files with 1,083 additions and 588 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
@@ -31,6 +31,9 @@
* ASSERT_OK / EXPECT_OK macros conflicted with external projects, so they were
renamed TF_ASSERT_OK, TF_EXPECT_OK. The existing macros are currently
maintained for short-term compatibility but will be removed.
* The non-public `nn.rnn` and the various `nn.seq2seq` methods now return
just the final state instead of the list of all states.
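A minimal sketch of the `nn.rnn` return-value change described in the entry above, written against the 2016-era Python API (the cell type, shapes, and variable names are illustrative assumptions, not taken from this commit; RNN module paths moved around in this era):

import tensorflow as tf

num_steps, batch_size, input_size, num_units = 5, 32, 3, 4
cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
inputs = [tf.placeholder(tf.float32, [batch_size, input_size])
          for _ in range(num_steps)]

# `outputs` remains a list with one tensor per timestep.
# `state` used to be the list of states at every timestep; after this
# change it is only the final state.
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)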


## Bug fixes

117 changes: 62 additions & 55 deletions tensorflow/core/common_runtime/constant_folding.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <algorithm>
#include <atomic>
#include <set>
#include <unordered_map>
#include <vector>
@@ -22,11 +23,13 @@ limitations under the License.

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
@@ -57,15 +60,18 @@ void FindConstantFoldableNodes(const Graph* graph, ConstantFoldingOptions opts,
ReverseDFS(*graph, nullptr,
[&nodes, &node_set, &internal_node_inserted, opts](Node* n) {
if (n->IsConstant()) {
// Constants are definitely constant foldable
node_set.insert(n);
nodes.push_back(n);
// Constants with no control inputs (except from _SOURCE node)
// are definitely constant foldable.
if (n->in_edges().size() == 0 ||
(n->in_edges().size() == 1 &&
(*n->in_edges().begin())->src()->IsSource())) {
node_set.insert(n);
nodes.push_back(n);
}
} else if (IsConstantFoldable(n, opts.consider)) {
// Check whether the set of this node's in_nodes is completely
// included in
// the set of constant foldable nodes. If true, then this nodes
// is also
// constant foldable.
// included in the set of constant foldable nodes. If true,
// then this node is also constant foldable.
bool all_parents_constant = n->num_inputs() > 0;
for (const Node* parent : n->in_nodes()) {
if (node_set.count(parent) == 0) {
@@ -86,14 +92,16 @@ }
}
}
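The marking pass above applies two rules: a Const node is foldable only when its sole incoming edge, if any, is a control edge from _SOURCE, and a non-constant node is foldable only when it has at least one input and every parent has already been marked. A toy Python model of that fixpoint, visiting a DAG in postorder (the node names and graph encoding are invented for illustration; this is not the C++ implementation):

# Toy DAG: node name -> (is_constant, list of input nodes).
graph = {
    "a":  (True,  []),
    "b":  (True,  []),
    "x":  (False, []),          # e.g. a placeholder: never foldable
    "m1": (False, ["a", "b"]),  # all parents constant: foldable
    "m2": (False, ["b", "x"]),  # depends on x: not foldable
}
postorder = ["a", "b", "x", "m1", "m2"]  # parents before consumers

def find_foldable(graph, order):
    foldable = set()
    for name in order:
        is_const, parents = graph[name]
        if is_const:
            foldable.add(name)  # no control inputs in this toy model
        elif parents and all(p in foldable for p in parents):
            foldable.add(name)
    return foldable

print(sorted(find_foldable(graph, postorder)))  # ['a', 'b', 'm1']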

typedef std::pair<Node*, int> NodeAndOutput;

// Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
// will contain copies of the nodes in 'nodes'. In addition, if there is an edge
// going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
// 'nodes', then 'nodes_to_fetch' will contain the mapping from the
// corresponding copy of 'n' in 'g' to 'n'.
// 'nodes', then 'tensors_to_fetch' will contain the mapping from the
// corresponding copy of 'n' and the edge number in 'g' to 'n'.
Graph* GetConstantGraph(const Graph* orig_graph,
const std::vector<Node*>& nodes,
std::unordered_map<Node*, Node*>* nodes_to_fetch) {
std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
Graph* constant_graph = new Graph(orig_graph->op_registry());
std::unordered_map<Node*, Node*> node_map;
std::set<Node*> already_added;
@@ -107,45 +115,51 @@ Graph* GetConstantGraph(const Graph* orig_graph,
already_added.insert(added);
for (const Edge* in_edge : n->in_edges()) {
Node* in = in_edge->src();
CHECK_GT(node_map.count(in), 0);
CHECK_GT(already_added.count(node_map[in]), 0);
CHECK_GT(node_map.count(in), 0) << n->DebugString() << " <-"
<< in->DebugString();
CHECK_GT(already_added.count(node_map[in]), 0) << in->DebugString();
constant_graph->AddEdge(node_map[in], in_edge->src_output(), added,
in_edge->dst_input());
}
}

for (auto const& added_nodes : node_map) {
bool should_fetch = false;
for (const Edge* out_edge : added_nodes.first->out_edges()) {
if (node_map.count(out_edge->dst()) == 0) {
should_fetch = true;
break;
}
tensors_to_fetch->insert(
{{added_nodes.second, out_edge->src_output()}, added_nodes.first});
}
if (should_fetch) {
nodes_to_fetch->insert({added_nodes.second, added_nodes.first});
}
}

return constant_graph;
}

void ReplaceNodeWithConstant(Graph* graph, Node* n, const Tensor& constant) {
std::vector<std::tuple<int, Node*, int>> old_edges;
int64 UniqueConstantId() {
static std::atomic_int_fast64_t id;
return id.fetch_add(1);
}
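Folding now happens per output tensor, so a single multi-output node can produce several replacement Const nodes; the process-wide atomic counter above gives each generated name a distinct `__cf__` suffix. A toy Python model of the naming scheme (it omits the `graph->NewName` uniquification step, and the helper name is invented):

import itertools

_unique_constant_id = itertools.count()  # stands in for the atomic counter

def constant_fold_name(node_name):
    return "%s__cf__%d" % (node_name, next(_unique_constant_id))

print(constant_fold_name("b"))  # b__cf__0
print(constant_fold_name("b"))  # b__cf__1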

void ReplaceTensorWithConstant(Graph* graph, NodeAndOutput tensor,
const Tensor& constant) {
Node* n = tensor.first;
std::vector<const Edge*> edges_to_remove;
for (const Edge* out_edge : n->out_edges()) {
old_edges.push_back(std::make_tuple(out_edge->src_output(), out_edge->dst(),
out_edge->dst_input()));
if (out_edge->src_output() == tensor.second) {
edges_to_remove.push_back(out_edge);
}
}
string node_name = n->name();
graph->RemoveNode(n);
Node* constant_node;
TF_CHECK_OK(NodeBuilder(graph->NewName(node_name), "Const")
TF_CHECK_OK(NodeBuilder(strings::StrCat(graph->NewName(node_name), "__cf__",
UniqueConstantId()),
"Const")
.Attr("dtype", constant.dtype())
.Attr("value", constant)
.Finalize(graph, &constant_node));
for (auto edge : old_edges) {
graph->AddEdge(constant_node, std::get<0>(edge), std::get<1>(edge),
std::get<2>(edge));
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
}
}
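ReplaceTensorWithConstant rewires only the edges leaving the folded output index and leaves the node's other outputs alone, which is what allows one output of a multi-output node to be folded while another is not (see the TwoOutputsFoldOneOutput test further down). A toy Python model of the rewiring (the edge tuples are an invented encoding):

# Each edge: (src, src_output, dst, dst_input).
edges = [
    ("b", 0, "i0", 0),  # from output 0 of b
    ("b", 1, "i1", 0),  # from output 1 of b
]

def replace_tensor_with_constant(edges, node, output_index, const_name):
    rewired = []
    for src, src_out, dst, dst_in in edges:
        if src == node and src_out == output_index:
            rewired.append((const_name, 0, dst, dst_in))  # Const has one output
        else:
            rewired.append((src, src_out, dst, dst_in))
    return rewired

print(replace_tensor_with_constant(edges, "b", 0, "b__cf__0"))
# [('b__cf__0', 0, 'i0', 0), ('b', 1, 'i1', 0)]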

@@ -218,6 +232,7 @@ class SimpleRendezvous : public Rendezvous {
} // namespace

bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
DumpGraph("Before", graph);
Device* device = GetCPUDevice();
thread::ThreadPool* thread_pool = GetThreadPool();
if (!device || !thread_pool) {
@@ -233,11 +248,12 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
return false;
}

std::unordered_map<Node*, Node*> nodes_to_fetch;
std::map<NodeAndOutput, Node*> tensors_to_fetch;
Graph* constant_graph =
GetConstantGraph(graph, constant_foldable_nodes, &nodes_to_fetch);
GetConstantGraph(graph, constant_foldable_nodes, &tensors_to_fetch);
DumpGraph("Constant graph", constant_graph);

if (nodes_to_fetch.empty()) {
if (tensors_to_fetch.empty()) {
VLOG(1) << "No constant nodes found that feed into the original graph.";
delete constant_graph;
return false;
@@ -252,21 +268,23 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
}

std::vector<Node*> fetch_nodes;
std::vector<string> nodes_to_fetch_names;
std::vector<Node*> nodes_to_replace;
for (auto n : nodes_to_fetch) {
nodes_to_fetch_names.push_back(n.first->name());
nodes_to_replace.push_back(n.second);
std::vector<string> tensors_to_fetch_names;
std::vector<NodeAndOutput> tensors_to_replace;
for (auto n : tensors_to_fetch) {
tensors_to_fetch_names.push_back(
strings::StrCat(n.first.first->name(), ":", n.first.second));
tensors_to_replace.push_back({n.second, n.first.second});
}
// For nodes that need to be fetched back from the constant_graph, attach Send
// nodes.
if (!subgraph::FetchOutputs(constant_graph, device->attributes(),
nodes_to_fetch_names, &name_index, &fetch_nodes)
tensors_to_fetch_names, &name_index, &fetch_nodes)
.ok()) {
VLOG(1) << "Could not fetch constants";
return false;
}

CHECK_EQ(fetch_nodes.size(), nodes_to_fetch.size());
CHECK_EQ(fetch_nodes.size(), tensors_to_fetch.size());

// Create the local executor and the Rendezvous for fetching back the
// constants.
@@ -311,17 +329,7 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
}
executor_done.WaitForNotification();

// Keep track of the nodes that will be orphaned once the internal nodes have
// been constant folded and replaced, so we can delete them later.
std::set<Node*> replaced_nodes_set(nodes_to_replace.begin(),
nodes_to_replace.end());
std::vector<Node*> to_delete;
for (Node* n : constant_foldable_nodes) {
if (replaced_nodes_set.count(n) == 0) {
to_delete.push_back(n);
}
}
// Fetch the constant nodes and replace the corresponding nodes in the
// Fetch the constant tensors and replace the corresponding tensors in the
// original graph with those constants.
for (size_t c = 0; c < fetch_nodes.size(); ++c) {
Tensor output;
@@ -336,15 +344,14 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
if (!s.ok() || is_dead) {
return c > 0;
}
VLOG(1) << "Replacing " << nodes_to_replace[c]->DebugString()
<< " with constant " << output.DebugString();
ReplaceNodeWithConstant(graph, nodes_to_replace[c], output);
VLOG(1) << "Replacing " << tensors_to_replace[c].first->DebugString()
<< " :: " << tensors_to_replace[c].second << " with constant "
<< output.DebugString();
ReplaceTensorWithConstant(graph, tensors_to_replace[c], output);
}

// Delete the orphaned nodes in the original graph.
for (Node* n : to_delete) {
graph->RemoveNode(n);
}
DumpGraph("After", graph);

return true;
}

99 changes: 85 additions & 14 deletions tensorflow/core/common_runtime/constant_folding_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/tensor.h"
@@ -60,26 +61,39 @@ class ConstantFoldingTest : public ::testing::Test {
test::ExpectClose(t, test::AsTensor(values, shape));
}

template <typename T>
void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values,
TensorShape shape) {
EXPECT_TRUE(n->IsConstant());
const TensorProto* tensor_proto;
EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor_proto));
DataType dtype;
EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
Tensor t(dtype);
EXPECT_TRUE(t.FromProto(*tensor_proto));
test::ExpectTensorEqual<T>(t, test::AsTensor(values, shape));
}

// Construct the following graph
// s1 s2
// | |
// m1 m2
// / \ / \
// a b c
#define SIMPLE_GRAPH \
Reset(); \
Graph* g = g_.get(); \
Node* a = Constant<float>({1.0, 0.0, 0.0, 1.0}, {2, 2}); \
Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); \
Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2}); \
g->AddControlEdge(g->source_node(), a); \
g->AddControlEdge(g->source_node(), b); \
g->AddControlEdge(g->source_node(), c); \
Node* m1 = test::graph::Matmul(g, a, b, false, false); \
Node* s1 = test::graph::Send(g_.get(), m1, "m1", "sender", 0, "receiver"); \
Node* m2 = test::graph::Matmul(g, b, c, false, false); \
Node* s2 = test::graph::Send(g_.get(), m2, "m2", "sender", 0, "receiver"); \
g->AddControlEdge(s1, g->sink_node()); \
#define SIMPLE_GRAPH \
Reset(); \
Graph* g = g_.get(); \
Node* a = Constant<float>({1.0, 0.0, 0.0, 1.0}, {2, 2}); \
Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); \
Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2}); \
g->AddControlEdge(g->source_node(), a); \
g->AddControlEdge(g->source_node(), b); \
g->AddControlEdge(g->source_node(), c); \
Node* m1 = test::graph::Matmul(g, a, b, false, false); \
Node* s1 = test::graph::Send(g, m1, "m1", "sender", 0, "receiver"); \
Node* m2 = test::graph::Matmul(g, b, c, false, false); \
Node* s2 = test::graph::Send(g, m2, "m2", "sender", 0, "receiver"); \
g->AddControlEdge(s1, g->sink_node()); \
g->AddControlEdge(s2, g->sink_node());

std::unique_ptr<Graph> g_;
Expand Down Expand Up @@ -113,6 +127,63 @@ TEST_F(ConstantFoldingTest, ConsiderFunction) {
EXPECT_EQ(1, s2->num_inputs());
EXPECT_EQ(*(s2->in_nodes().begin()), m2);
}
#undef SIMPLE_GRAPH

TEST_F(ConstantFoldingTest, TwoOutputs) {
Reset();
Graph* g = g_.get();
Node* s0 = Constant<int>({1}, {1});
Node* s1 = Constant<int>({2, 2}, {2});
g->AddControlEdge(g->source_node(), s0);
g->AddControlEdge(g->source_node(), s1);
Node* b = test::graph::BroadcastGradientArgs(g, s0, s1);
Node* b0 = test::graph::Send(g, test::graph::Identity(g, b, 0),
strings::StrCat(b->name(), "0"), "sender", 0,
"receiver");
Node* b1 = test::graph::Send(g, test::graph::Identity(g, b, 1),
strings::StrCat(b->name(), "1"), "sender", 0,
"receiver");
g->AddControlEdge(b0, g->sink_node());
g->AddControlEdge(b1, g->sink_node());

EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, g));
EXPECT_EQ(1, b0->num_inputs());
ExpectNodeEqual<int>(*(b0->in_nodes().begin()), {0, 1}, {2});
EXPECT_EQ(1, b1->num_inputs());
ExpectNodeEqual<int>(*(b1->in_nodes().begin()), {}, {0});
}
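The expected values in this test follow from broadcasting: a gradient flowing back through a broadcast of shape [1] against [2, 2] must be summed over axes 0 and 1 for the first input, while the second input needs no reduction. A simplified Python model of what BroadcastGradientArgs computes (well-formed shapes assumed; the real op handles more cases):

def broadcast_gradient_args(shape_x, shape_y):
    # Right-align the shapes, padding the shorter with 1s,
    # as in NumPy-style broadcasting.
    n = max(len(shape_x), len(shape_y))
    x = [1] * (n - len(shape_x)) + list(shape_x)
    y = [1] * (n - len(shape_y)) + list(shape_y)
    rx, ry = [], []
    for axis in range(n):
        if x[axis] == 1 and y[axis] != 1:
            rx.append(axis)  # x was broadcast on this axis: sum its gradient
        if y[axis] == 1 and x[axis] != 1:
            ry.append(axis)
    return rx, ry

# Matches the expectations above: {0, 1} and an empty shape-{0} vector.
print(broadcast_gradient_args([1], [2, 2]))  # ([0, 1], [])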

TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) {
Reset();
Graph* g = g_.get();
Node* s0 = Constant<int>({1}, {1});
Node* s1 = Constant<int>({2, 2}, {2});
g->AddControlEdge(g->source_node(), s0);
g->AddControlEdge(g->source_node(), s1);
Node* b = test::graph::BroadcastGradientArgs(g, s0, s1);
Node* b0 = test::graph::Send(g, test::graph::Identity(g, b, 0),
strings::StrCat(b->name(), "0"), "sender", 0,
"receiver");
Node* b1_ident = test::graph::Identity(g, b, 1);
Node* b1 = test::graph::Send(g, b1_ident, strings::StrCat(b->name(), "1"),
"sender", 0, "receiver");
g->AddControlEdge(b0, g->sink_node());
g->AddControlEdge(b1, g->sink_node());

ConstantFoldingOptions opts;
opts.consider = [b1_ident](const Node* n) { return b1_ident != n; };
EXPECT_TRUE(DoConstantFolding(opts, g));
// 0th output of b should have been folded.
EXPECT_EQ(1, b0->num_inputs());
ExpectNodeEqual<int>(*(b0->in_nodes().begin()), {0, 1}, {2});
// 1st output of b should still be b1_ident. However, b1_ident's input must
// have been replaced with a constant.
EXPECT_EQ(1, b1->num_inputs());
EXPECT_EQ(*(b1->in_nodes().begin()), b1_ident);

EXPECT_EQ(1, b1_ident->num_inputs());
ExpectNodeEqual<int>(*(b1_ident->in_nodes().begin()), {}, {0});
}

} // namespace
} // namespace tensorflow
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/function.cc
@@ -539,7 +539,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
return Status::OK();
}

static void DumpGraph(StringPiece label, const Graph* g) {
void DumpGraph(StringPiece label, const Graph* g) {
// TODO(zhifengc): Change Graph to record #nodes.
VLOG(1) << "Graph " << label << " #edges " << g->edges().size();
if (VLOG_IS_ON(2)) {
4 changes: 4 additions & 0 deletions tensorflow/core/common_runtime/function.h
@@ -90,6 +90,10 @@ bool RemoveListArrayConverter(Graph* g);
// multiple times by calling ExpandInlineFunctions a few times.
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph);

// Dump the contents of the "graph" to log files if the logging level is
// sufficiently high.
void DumpGraph(StringPiece label, const Graph* g);

// Applies graph rewrite optimization such as inlining, dead code
// removal, etc.
//