Skip to content

Commit

Permalink
Prune unecessary Identity nodes
Browse files Browse the repository at this point in the history
Change: 151331705
  • Loading branch information
benoitsteiner authored and tensorflower-gardener committed Mar 27, 2017
1 parent 06365fb commit ab0e81f
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 7 deletions.
18 changes: 18 additions & 0 deletions tensorflow/core/grappler/optimizers/graph_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ GraphRewriter::GraphRewriter(const GrapplerItem& item) {
for (auto& node : item.graph.node()) {
nodes_[node.name()] = &node;
}

for (auto& node : item.graph.node()) {
for (const auto& input : node.input()) {
int position = 0;
string input_node_name = ParseNodeName(input, &position);
if (position < 0) {
// This is a control edge
auto itr = nodes_.find(input_node_name);
CHECK(itr != nodes_.end());
control_dependency_drivers_.insert(itr->second);
}
}
}
}

void GraphRewriter::ForwardInputs(
Expand All @@ -46,5 +59,10 @@ void GraphRewriter::ForwardInputs(
}
}

bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const {
return control_dependency_drivers_.find(&node) !=
control_dependency_drivers_.end();
}

} // end namespace grappler
} // end namespace tensorflow
5 changes: 5 additions & 0 deletions tensorflow/core/grappler/optimizers/graph_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ class GraphRewriter {
const std::unordered_set<const NodeDef*>& nodes_to_delete,
NodeDef* new_node);

// Returns true if at least one of the edges in the direct fanout of 'node' is
// a control dependency edge.
bool DrivesControlDependency(const NodeDef& node) const;

private:
std::unordered_map<string, const NodeDef*> nodes_;
std::unordered_set<const NodeDef*> control_dependency_drivers_;
};

} // end namespace grappler
Expand Down
19 changes: 15 additions & 4 deletions tensorflow/core/grappler/optimizers/model_pruner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,18 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
std::unordered_set<const NodeDef*> nodes_to_delete;
for (auto& node : item.graph.node()) {
// Remove the stop gradient nodes since they serve no purpose once the graph
// is built.
if (node.op() != "StopGradient") {
// is built. Also remove Identity ops.
if (node.op() != "StopGradient" && node.op() != "Identity") {
continue;
}
nodes_to_delete.insert(&node);
// Don't prune nodes that are explicitely placed.
if (!node.device().empty()) {
continue;
}
// Don't remove nodes that drive control dependencies.
if (!rewriter.DrivesControlDependency(node)) {
nodes_to_delete.insert(&node);
}
}

for (auto& node : item.graph.node()) {
Expand All @@ -46,11 +53,15 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
rewriter.ForwardInputs(node, nodes_to_delete, new_node);
}

LOG(INFO) << "Pruned " << nodes_to_delete.size()
<< " nodes from the graph. The graph now contains "
<< pruned_graph->node_size() " nodes.";

return Status::OK();
}

void ModelPruner::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) {
const GraphDef& pruned_graph, double result) {
// Nothing to do for ModelPruner.
}

Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/grappler/optimizers/model_pruner.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ class ModelPruner : public GraphOptimizer {
string name() const override { return "model_pruner"; };

Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) override;
GraphDef* pruned_graph) override;

void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override;
const GraphDef& pruned_graph, double result) override;
};

} // end namespace grappler
Expand Down
77 changes: 76 additions & 1 deletion tensorflow/core/grappler/optimizers/model_pruner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ TEST_F(ModelPrunerTest, NoPruning) {
}
}

TEST_F(ModelPrunerTest, SimplePruning) {
TEST_F(ModelPrunerTest, StopGradientPruning) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();

Expand Down Expand Up @@ -82,6 +82,81 @@ TEST_F(ModelPrunerTest, SimplePruning) {
EXPECT_EQ(NodeName(b.name()), new_e.input(0));
}

TEST_F(ModelPrunerTest, IdentityPruning) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();

Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::AddN(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::Identity(s.WithOpName("d"), c);
Output e = ops::AddN(s.WithOpName("e"), {d});

GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));

// Force the placement of c. This should ensure it is preserved.
EXPECT_EQ("c", item.graph.node(2).name());
item.graph.mutable_node(2)->set_device("CPU");

ModelPruner pruner;
GraphDef output;
Status status = pruner.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);

EXPECT_EQ(4, output.node_size());
const NodeDef& new_a = output.node(0);
EXPECT_EQ(NodeName(a.name()), new_a.name());
const NodeDef& new_b = output.node(1);
EXPECT_EQ(NodeName(b.name()), new_b.name());
const NodeDef& new_c = output.node(2);
EXPECT_EQ(NodeName(c.name()), new_c.name());
const NodeDef& new_e = output.node(3);
EXPECT_EQ(NodeName(e.name()), new_e.name());

EXPECT_EQ(1, new_e.input_size());
EXPECT_EQ(NodeName(c.name()), new_e.input(0));
}

TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();

Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::AddN(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::Identity(s.WithOpName("d"), c);
Output e = ops::AddN(s.WithOpName("e"), {d});

GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));

// Add a control dependency between c and e. This should ensure c is
// preserved.
EXPECT_EQ("c", item.graph.node(2).name());
EXPECT_EQ("e", item.graph.node(4).name());
*item.graph.mutable_node(4)->add_input() = "^c";

ModelPruner pruner;
GraphDef output;
Status status = pruner.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);

EXPECT_EQ(4, output.node_size());
const NodeDef& new_a = output.node(0);
EXPECT_EQ(NodeName(a.name()), new_a.name());
const NodeDef& new_b = output.node(1);
EXPECT_EQ(NodeName(b.name()), new_b.name());
const NodeDef& new_c = output.node(2);
EXPECT_EQ(NodeName(c.name()), new_c.name());
const NodeDef& new_e = output.node(3);
EXPECT_EQ(NodeName(e.name()), new_e.name());

EXPECT_EQ(2, new_e.input_size());
EXPECT_EQ(NodeName(c.name()), new_e.input(0));
EXPECT_EQ("^c", new_e.input(1));
}

} // namespace
} // namespace grappler
} // namespace tensorflow
3 changes: 3 additions & 0 deletions tensorflow/core/grappler/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ string NodeName(const string& name);
// Get the trailing position number ":{digits}" (if any) of a node name.
int NodePosition(const string& name);

// Returns the node name and position in a single call.
string ParseNodeName(const string& name, int* position);

// Add a prefix to a node name
string AddPrefixToNodeName(const string& name, const string& prefix);

Expand Down

0 comments on commit ab0e81f

Please sign in to comment.