Reuse common_runtime function instantiation in Grappler function optimizer.

PiperOrigin-RevId: 240864990
ezhulenev authored and tensorflower-gardener committed Mar 28, 2019
1 parent 249d892 commit 92f736f
Showing 14 changed files with 487 additions and 1,075 deletions.
6 changes: 4 additions & 2 deletions tensorflow/core/framework/function.cc
@@ -177,7 +177,8 @@ class FunctionInstantiationHelper {
     } else {
       gnode->set_op(FunctionLibraryDefinition::kArgOp);
     }
-    AddAttr("T", dtypes[i], gnode);
+    DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
+    AddAttr("T", dtype, gnode);
     AddAttr("index", arg_index, gnode);
     result_.arg_types.push_back(dtypes[i]);
     ++arg_index;
@@ -343,7 +344,8 @@ class FunctionInstantiationHelper {
       gnode->set_op(FunctionLibraryDefinition::kRetOp);
     }
     AddInput(nodes_.size() - 1, item->nid, item->idx + i);
-    AddAttr("T", dtypes[i], gnode);
+    DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
+    AddAttr("T", dtype, gnode);
     AddAttr("index", (*ret_index)++, gnode);
     result_.ret_types.push_back(dtypes[i]);
   }
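These two hunks make instantiation honor reference-typed arguments and return values: the "T" attr on the generated _Arg/_Retval node must be the ref variant of the base dtype, while the base dtype still lands in arg_types/ret_types. A minimal sketch of the distinction, assuming the MakeRefType/RemoveRefType helpers from tensorflow/core/framework/types.h (the wrapper function itself is hypothetical):

    #include "tensorflow/core/framework/types.h"

    // Hypothetical helper mirroring the change above: choose the "T" attr
    // value for a generated _Arg/_Retval node. For a ref-typed argument,
    // DT_FLOAT becomes DT_FLOAT_REF; non-ref arguments keep the base dtype.
    tensorflow::DataType NodeAttrType(bool is_ref, tensorflow::DataType base) {
      return is_ref ? tensorflow::MakeRefType(base) : base;
    }

    // NodeAttrType(true, tensorflow::DT_FLOAT) == tensorflow::DT_FLOAT_REF,
    // and tensorflow::RemoveRefType(tensorflow::DT_FLOAT_REF) == tensorflow::DT_FLOAT.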
1 change: 1 addition & 0 deletions tensorflow/core/grappler/costs/BUILD
@@ -41,6 +41,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":utils",
+        "@com_google_absl//absl/types:optional",
         "//tensorflow/core/grappler/utils:functions",
         "//tensorflow/core/grappler/utils:topological_sort",
         "//tensorflow/core/grappler:mutable_graph_view",
72 changes: 51 additions & 21 deletions tensorflow/core/grappler/costs/graph_properties.cc
@@ -15,6 +15,7 @@ limitations under the License.

 #include "tensorflow/core/grappler/costs/graph_properties.h"

+#include "absl/types/optional.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -603,22 +604,31 @@ class SymbolicShapeRefiner {
           " was not previously added to SymbolicShapeRefiner.");
     }

+    const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
+        it->second;
+    if (!maybe_grappler_function_item.has_value()) {
+      VLOG(3) << "Skip failed to instantiate function call: function_name="
+              << function_node->op();
+
+      auto* ctx = GetNodeContext(function_node);
+      auto* ic = ctx->inference_context.get();
+      for (int i = 0; i < ic->num_outputs(); ++i) {
+        TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
+      }
+
+      return Status::OK();
+    }
+
     // Copy (not reference) so that changes we make here (e.g., replacing
-    // Placeholder with Const) don't affect one in
+    // _Arg with Const and _Retval with Identity) don't affect one in
     // fun_to_grappler_function_item_.
-    GrapplerFunctionItem grappler_function_item = it->second;
+    GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
     MutableGraphView gv(&grappler_function_item.graph);

     // Forward shapes from function input nodes to argument nodes.
     for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
       auto& fun_input = grappler_function_item.input(i);
-      if (fun_input.placeholders.size() > 1) {
-        // TODO(jmdecker): Handle case with multiple input placeholders
-        return errors::Unimplemented(
-            "Input arguments with multiple placeholders are not yet "
-            "supported.");
-      }
-      NodeDef* fun_node = gv.GetNode(fun_input.input_name);
+      NodeDef* fun_node = gv.GetNode(fun_input.node_name);
       const TensorId input_tensor = ParseTensorName(function_node->input(i));

       if (IsControlInput(input_tensor)) {
@@ -649,11 +659,18 @@ class SymbolicShapeRefiner {
           proto.mutable_dim(i)->set_size(-1);
         }
       }
+
+      // Turn _Arg node into a Placeholder. _Arg node is a system op without a
+      // valid shape function.
       *attr_output_shape.mutable_shape() = proto;
+      fun_node->set_op("Placeholder");
+      (*fun_node->mutable_attr())["dtype"] = (*fun_node->mutable_attr())["T"];
+      (*fun_node->mutable_attr()).erase("index");
+      (*fun_node->mutable_attr()).erase("T");
       (*fun_node->mutable_attr())["shape"] = attr_output_shape;
     }

-    // Replace input Placeholders with Consts, if values are known. Note that
+    // Replace input nodes with Consts, if values are known. Note that
     // we don't check exceptions here as it's done in the above loop.
     auto* ctx = GetNodeContext(function_node);
     auto* ic = ctx->inference_context.get();
@@ -684,6 +701,15 @@ class SymbolicShapeRefiner {
       }
     }

+    // Replace output _Retval nodes with Identity nodes. _Retval is a system op
+    // without outputs and registered shape function.
+    for (const auto& output_arg : grappler_function_item.outputs()) {
+      NodeDef* output_node = gv.GetNode(output_arg.node_name);
+      DCHECK_EQ(output_node->op(), "_Retval");
+      output_node->set_op("Identity");
+      output_node->mutable_attr()->erase("index");
+    }
+
     // Perform inference on function body.
     GraphProperties gp(grappler_function_item);
     TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
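The two rewrites above exist because _Arg and _Retval are executor-internal ops with no usable shape functions, so the copied function body is made inferable by swapping them for ordinary Placeholder and Identity ops. A hedged standalone sketch of the _Retval half (the attr values in the comments are illustrative):

    #include "tensorflow/core/framework/node_def.pb.h"

    // Before: op:"_Retval"  attr{"T": DT_FLOAT} attr{"index": 0}
    // After:  op:"Identity" attr{"T": DT_FLOAT}
    // Identity keeps the same single data input and "T" attr, so inferred
    // shapes flow through it, while the executor-specific "index" attr is
    // dropped.
    void RetvalToIdentity(tensorflow::NodeDef* node) {
      node->set_op("Identity");
      node->mutable_attr()->erase("index");
    }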
@@ -694,16 +720,9 @@ class SymbolicShapeRefiner {
     ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
                                      nullptr);
     for (auto const& out_arg : grappler_function_item.outputs()) {
-      if (out_arg.output_nodes.size() > 1) {
-        // TODO(jmdecker): Handle case of multiple output tensors
-        return errors::Unimplemented(
-            "Output arguments with multiple output tensors are not yet "
-            "supported.");
-      }
-
       // It is guaranteed that output_tensors does not contain any control
       // inputs, so port_id >= 0.
-      TensorId out_tensor = ParseTensorName(out_arg.output_nodes[0]);
+      TensorId out_tensor = ParseTensorName(out_arg.node_name);

       const NodeDef* retnode = gv.GetNode(out_tensor.node());
       if (retnode == nullptr) {
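The loop above leans on ParseTensorName from tensorflow/core/graph/tensor_id.h to split an output string into node name and port. A small sketch of its behavior (the tensor names here are made up):

    #include "tensorflow/core/graph/tensor_id.h"

    // "node:port" parses into a (node, index) pair; a bare name defaults to
    // port 0, and a "^node" control input parses to index -1. That is why
    // the no-control-inputs guarantee above implies port_id >= 0.
    void TensorIdExamples() {
      tensorflow::TensorId a = tensorflow::ParseTensorName("fn/out:1");
      // a.node() == "fn/out", a.index() == 1
      tensorflow::TensorId b = tensorflow::ParseTensorName("fn/out");
      // b.index() == 0
      tensorflow::TensorId c = tensorflow::ParseTensorName("^fn/out");
      // c.index() == -1 (control input)
    }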
@@ -1042,9 +1061,18 @@ class SymbolicShapeRefiner {
     CHECK_NOTNULL(function_library_.Find(function_node->op()));

     GrapplerFunctionItem grappler_function_item;
-    TF_RETURN_IF_ERROR(
+    Status function_instantiated =
         MakeGrapplerFunctionItem(*function_def, function_library_,
-                                 graph_def_version_, &grappler_function_item));
+                                 graph_def_version_, &grappler_function_item);
+
+    // If function instantiation failed we will skip it during shape inference.
+    if (!function_instantiated.ok()) {
+      VLOG(3) << "Failed to instantiate a function. Error: "
+              << function_instantiated.error_message();
+      fun_to_grappler_function_item_[function_def->signature().name()] =
+          absl::nullopt;
+      return Status::OK();
+    }

     if (grappler_function_item.inputs().size() > function_node->input_size()) {
       return errors::FailedPrecondition(
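Together with the lookup change earlier in this file, fun_to_grappler_function_item_ becomes a memo of fallible instantiations: a present value means instantiated once and reusable, absl::nullopt means it failed once and should not be retried. A self-contained sketch of that caching pattern (the Memo struct and its int payload are hypothetical stand-ins for GrapplerFunctionItem):

    #include <string>
    #include <unordered_map>

    #include "absl/types/optional.h"

    // Hypothetical memo of a fallible, expensive construction step. The
    // first call per key does the work; later calls reuse the result,
    // including a cached failure, so a bad input is never re-processed.
    struct Memo {
      std::unordered_map<std::string, absl::optional<int>> cache;

      const absl::optional<int>& GetOrBuild(const std::string& key) {
        auto it = cache.find(key);
        if (it != cache.end()) return it->second;  // value or cached failure
        absl::optional<int> built;  // nullopt == "instantiation failed"
        if (!key.empty()) built = static_cast<int>(key.size());
        return cache.emplace(key, built).first->second;
      }
    };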
@@ -1691,7 +1719,9 @@ class SymbolicShapeRefiner {
   std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
   std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
   std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
-  std::unordered_map<string, GrapplerFunctionItem>
+  // Store function instantiations only for valid function. If function
+  // instantiation failed it will have an `absl::nullopt`.
+  std::unordered_map<string, absl::optional<GrapplerFunctionItem>>
       fun_to_grappler_function_item_;
   FunctionLibraryDefinition function_library_;
   const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
8 changes: 8 additions & 0 deletions tensorflow/core/grappler/op_types.cc
@@ -67,6 +67,10 @@ bool IsApproximateEqual(const NodeDef& node) {
   return node.op() == "ApproximateEqual";
 }

+bool IsArg(const NodeDef& node) {
+  return node.op() == "_Arg" || node.op() == "_DeviceArg";
+}
+
 bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }

 bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }
@@ -419,6 +423,10 @@ bool IsRestore(const NodeDef& node) {
          node.op() == "RestoreSlice");
 }

+bool IsRetval(const NodeDef& node) {
+  return node.op() == "_Retval" || node.op() == "_DeviceRetval";
+}
+
 bool IsReverse(const NodeDef& node) {
   return node.op() == "Reverse" || node.op() == "ReverseV2";
 }
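Both predicates match the device variants too (_DeviceArg/_DeviceRetval). A quick usage sketch (the counting function is hypothetical):

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/grappler/op_types.h"

    // Count the parameter/return boundary nodes of an instantiated
    // function body graph.
    int CountBoundaryNodes(const tensorflow::GraphDef& graph) {
      int n = 0;
      for (const tensorflow::NodeDef& node : graph.node()) {
        if (tensorflow::grappler::IsArg(node) ||
            tensorflow::grappler::IsRetval(node)) {
          ++n;
        }
      }
      return n;
    }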
2 changes: 2 additions & 0 deletions tensorflow/core/grappler/op_types.h
@@ -33,6 +33,7 @@ bool IsAnyMaxPool(const NodeDef& node);
 bool IsAnyMin(const NodeDef& node);
 bool IsAnyMul(const NodeDef& node);
 bool IsApproximateEqual(const NodeDef& node);
+bool IsArg(const NodeDef& node);
 bool IsArgMax(const NodeDef& node);
 bool IsArgMin(const NodeDef& node);
 bool IsAssert(const NodeDef& node);
@@ -137,6 +138,7 @@ bool IsRelu6Grad(const NodeDef& node);
 bool IsReluGrad(const NodeDef& node);
 bool IsReshape(const NodeDef& node);
 bool IsRestore(const NodeDef& node);
+bool IsRetval(const NodeDef& node);
 bool IsReverse(const NodeDef& node);
 bool IsReverseV2(const NodeDef& node);
 bool IsRsqrt(const NodeDef& node);
8 changes: 8 additions & 0 deletions tensorflow/core/grappler/optimizers/data/rebatch.cc
@@ -226,11 +226,19 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,

       // Replace optimized function with a new FunctionDef.
       TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, optimized_func));
+    } else {
+      VLOG(2) << "Failed to optimize dataset function. Error: "
+              << s.error_message();
     }
   } else if (IsDatasetNodeOfType(node, kSourceDatasetOps)) {
     return errors::InvalidArgument(
         "Reached a source dataset: ", node.op(),
         " without encountering a batch transformation.");
+  } else if (IsRetval(node)) {
+    // _Retvals added to the function body graph in place of function outputs.
+    NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
+    TF_RETURN_IF_ERROR(
+        RecursivelyHandleOp(*input_node, num_workers, flib, graph));
   } else {
     return errors::InvalidArgument("Encountered an unsupported op: ",
                                    node.op());
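Function bodies produced by common_runtime instantiation now end in _Retval nodes rather than named output tensors, so the backward walk has to hop through a _Retval to its single data input, as the new branch above does. A hedged sketch of that hop without the graph_utils helper (name resolution only; the function is hypothetical):

    #include <string>

    #include "tensorflow/core/framework/node_def.pb.h"

    // Hypothetical: resolve a _Retval to the name of the node feeding it.
    // A _Retval has exactly one data input, "node" or "node:port"; stripping
    // the port suffix yields the producer's name, which can then be looked
    // up in the graph to continue the recursion.
    std::string ProducerBehindRetval(const tensorflow::NodeDef& retval) {
      const std::string& input = retval.input(0);
      std::string::size_type colon = input.rfind(':');
      return colon == std::string::npos ? input : input.substr(0, colon);
    }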
tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -76,7 +76,7 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
     return false;
   }
   for (const auto& consumer : node_map_->GetOutputs(node.name())) {
-    if (node.input_size() > 1 && IsMerge(*consumer)) {
+    if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) {
       return false;
     }
     if (IsSwitch(*input)) {
[Diffs for the remaining 7 changed files not shown.]
