Update tf.InfeedDequeueTuple -> xla_hlo.infeed legalization to insert a default sharding (device 0) to account for token result.

Compared to tf.InfeedDequeueTuple, xla_hlo.infeed has an additional token result. As the number of results must match the number of shardings, a sharding for the token is inserted at the end when legalizing to xla_hlo.infeed.

PiperOrigin-RevId: 307861971
Change-Id: I8f2828cc6036afc41a72dae88a2218771745e441
andyly authored and tensorflower-gardener committed Apr 22, 2020
1 parent 7eb1c83 commit 1c67957
Showing 3 changed files with 22 additions and 8 deletions.
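
As a quick orientation before the diffs: the change grows a TUPLE OpSharding by one MAXIMAL (device 0) entry so that the number of shardings matches the extra token result of xla_hlo.infeed. The standalone sketch below is illustrative only (it is not part of the commit and assumes an XLA/protobuf build environment); it exercises the same proto calls the pass now uses.

#include <cassert>
#include <string>

#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

int main() {
  // Serialized _XlaSharding attribute used in the updated test below: a TUPLE
  // sharding holding a single MAXIMAL (device 0) element for the data result.
  const std::string attr("\x08\x02\x2A\x08\x08\x01\x1A\x01\x01\x22\x01\x00", 12);

  ::xla::OpSharding sharding;
  const bool parsed = sharding.ParseFromString(attr);
  assert(parsed);
  assert(sharding.type() == ::xla::OpSharding::TUPLE);
  assert(sharding.tuple_shardings_size() == 1);  // One sharding per data result.

  // The legalization now appends one more sharding for the token result,
  // arbitrarily assigned to device 0 (see legalize_tf.cc below).
  *sharding.add_tuple_shardings() = ::xla::sharding_builder::AssignDevice(0);
  assert(sharding.tuple_shardings_size() == 2);  // Data result + token.
  return 0;
}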
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/xla/BUILD
@@ -138,6 +138,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:padding",
+        "//tensorflow/compiler/xla/client:sharding_builder",
         "//tensorflow/core:framework",
         "//tensorflow/core/kernels:conv_grad_shape_utils",
         "@llvm-project//llvm:support",
16 changes: 10 additions & 6 deletions tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1171,17 +1171,21 @@ func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
 
 // The following op sharding is used:
 // Proto debug string:
-//   type: MAXIMAL
-//   tile_assignment_dimensions: 1
-//   tile_assignment_devices: 0
+//   type: TUPLE
+//   tuple_shardings {
+//     type: MAXIMAL
+//     tile_assignment_dimensions: 1
+//     tile_assignment_devices: 0
+//   }
 // Serialized string:
-//   "\08\01\1A\01\01\22\01\00"
+//   "\08\02*\08\08\01\1A\01\01\22\01\00"
 
 // CHECK-LABEL: infeed_dequeue_tuple_sharding
 func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
   // CHECK: "xla_hlo.infeed"
-  // CHECK-SAME: xla_hlo.sharding = "type: MAXIMAL\0Atile_assignment_dimensions: 1\0Atile_assignment_devices: 0\0A"
-  %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
+  // An additional sharding is added at the end to account for token result.
+  // CHECK-SAME: xla_hlo.sharding = "type: TUPLE\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0A"
+  %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
   return %0 : tensor<8xi32>
 }
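
For readers decoding the escaped strings above: the _XlaSharding attribute carries the wire-format bytes of the OpSharding proto whose debug string is shown in the test comment, while xla_hlo.sharding carries the text form with the extra token entry appended by the pass. The sketch below is illustrative only (not part of the commit) and reproduces both encodings under an assumed standalone protobuf/XLA build.

#include <cassert>
#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"

int main() {
  // MAXIMAL sharding on device 0: the element inside the tuple.
  ::xla::OpSharding maximal;
  maximal.set_type(::xla::OpSharding::MAXIMAL);
  maximal.add_tile_assignment_dimensions(1);
  maximal.add_tile_assignment_devices(0);
  // Wire format "\08\01\1A\01\01\22\01\00" (the test's previous attribute).
  assert(maximal.SerializeAsString() ==
         std::string("\x08\x01\x1A\x01\x01\x22\x01\x00", 8));

  // TUPLE sharding wrapping that single element (the test's new attribute).
  // The token sharding is not in these bytes; the legalization appends it.
  ::xla::OpSharding tuple;
  tuple.set_type(::xla::OpSharding::TUPLE);
  *tuple.add_tuple_shardings() = maximal;
  // Wire format "\08\02*\08\08\01\1A\01\01\22\01\00" ('*' is byte 0x2A).
  assert(tuple.SerializeAsString() ==
         std::string("\x08\x02\x2A\x08\x08\x01\x1A\x01\x01\x22\x01\x00", 12));
  return 0;
}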

13 changes: 11 additions & 2 deletions tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -46,6 +46,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 #include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/client/sharding_builder.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/kernels/conv_grad_shape_utils.h"
@@ -3376,9 +3377,17 @@ class ConvertInfeedDequeueTupleOp
     // _XlaSharding attribute in TF is a serialized string of the OpSharding
     // proto, so convert to a text form here.
     ::xla::OpSharding sharding_proto;
+    if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()))
+      return failure();
+
+    // Token is a control signal and not a real data, so arbitrarily assign
+    // the token to device 0.
+    if (sharding_proto.type() == ::xla::OpSharding::TUPLE)
+      *sharding_proto.add_tuple_shardings() =
+          ::xla::sharding_builder::AssignDevice(0);
+
     std::string sharding_str;
-    if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()) ||
-        !::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
+    if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
                                                            &sharding_str))
       return failure();
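
For experimenting outside the pass, the new code path above condenses to roughly the following standalone sketch. It is illustrative only: it uses google::protobuf::TextFormat directly instead of the tensorflow::protobuf alias from the pass, the helper name ShardingAttrToText is invented for this example, and all MLIR plumbing is omitted.

#include <iostream>
#include <string>

#include "google/protobuf/text_format.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Mirrors the new logic: parse the serialized _XlaSharding bytes, append a
// device-0 sharding for the token if the sharding is a TUPLE, and print the
// proto in text form (the value stored on the xla_hlo.sharding attribute).
bool ShardingAttrToText(const std::string& serialized, std::string* text) {
  ::xla::OpSharding sharding_proto;
  if (!sharding_proto.ParseFromString(serialized)) return false;

  if (sharding_proto.type() == ::xla::OpSharding::TUPLE)
    *sharding_proto.add_tuple_shardings() =
        ::xla::sharding_builder::AssignDevice(0);

  return google::protobuf::TextFormat::PrintToString(sharding_proto, text);
}

int main() {
  // TUPLE-of-one-MAXIMAL sharding from the updated test.
  const std::string attr("\x08\x02\x2A\x08\x08\x01\x1A\x01\x01\x22\x01\x00", 12);
  std::string text;
  if (ShardingAttrToText(attr, &text)) std::cout << text;
  return 0;
}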

