[TF:XLA] Start using XLA pooling library in tf2xla
PiperOrigin-RevId: 207763624
tensorflower-gardener committed Aug 7, 2018
1 parent 201a27e commit acb87b0
Showing 2 changed files with 109 additions and 76 deletions.
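In short (a reading of the hunks below, not part of the commit message): the generic PoolingOp::Compile, which lowered pooling through a hand-rolled xla::ReduceWindow parameterized by virtual InitValue/Reduction/PostProcessOutput hooks, is removed; MaxPoolOp and AvgPoolOp now implement Compile directly on top of the XLA client pooling library:

// Before: one generic reduction, specialized through virtual hooks.
auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize,
                                stride, padding_);
// After: a direct call into the pooling client library (MaxPoolOp shown).
auto pooling =
    xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
                 XlaTensorFormat(data_format_, input_shape.dims() - 2));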
1 change: 1 addition & 0 deletions tensorflow/compiler/tf2xla/kernels/BUILD
@@ -129,6 +129,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/lib:pooling",
"//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/core:framework",
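The new :pooling dependency makes the client library header available to these kernels; the matching include, added in the next file, is:

#include "tensorflow/compiler/xla/client/lib/pooling.h"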
184 changes: 108 additions & 76 deletions tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -71,59 +72,53 @@ class PoolingOp : public XlaOpKernel {

int num_dims() const { return num_spatial_dims_ + 2; }

// Method that builds an initial value to use in reductions.
virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0;

// The reduction operation to apply to each window.
virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0;

// A post-processing operation to apply on the outputs of the ReduceWindow.
virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape) = 0;

void Compile(XlaOpKernelContext* ctx) override {
std::vector<int64> ksize = ksize_;
std::vector<int64> stride = stride_;
if (ctx->num_inputs() != 1) {
const TensorShape ksize_shape = ctx->InputShape(1);
// Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
errors::InvalidArgument("ksize must be a vector, not shape ",
ksize_shape.DebugString()));
OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
errors::InvalidArgument("Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
ksize.clear();
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));

const TensorShape stride_shape = ctx->InputShape(2);
// Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
errors::InvalidArgument("stride must be a vector, not shape ",
stride_shape.DebugString()));
OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
errors::InvalidArgument("Sliding window stride field must "
"specify ",
num_dims(), " dimensions"));
stride.clear();
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
protected:
xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
if (ctx->num_inputs() == 1) {
return ksize_;
}
const TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
errors::InvalidArgument("Input to ", type_string(),
" operator must have ", num_dims(),
" dimensions"));
const TensorShape ksize_shape = ctx->InputShape(1);
// Validate input sizes.
if (!TensorShapeUtils::IsVector(ksize_shape)) {
return errors::InvalidArgument("ksize must be a vector, not shape ",
ksize_shape.DebugString());
}
if (ksize_shape.num_elements() != num_dims()) {
return errors::InvalidArgument(
"Sliding window ksize field must "
"specify ",
num_dims(), " dimensions");
}
std::vector<int64> ksize;
auto status = ctx->ConstantInputAsIntVector(1, &ksize);
if (!status.ok()) {
return status;
}
return ksize;
}

xla::XlaBuilder* const b = ctx->builder();
auto input =
XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_);
auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize,
stride, padding_);
auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
ctx->SetOutput(0,
PostProcessOutput(ctx, pooled, input_type(0), input_shape));
xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
if (ctx->num_inputs() == 1) {
return stride_;
}
const TensorShape stride_shape = ctx->InputShape(2);
// Validate input sizes.
if (!TensorShapeUtils::IsVector(stride_shape)) {
return errors::InvalidArgument("stride must be a vector, not shape ",
stride_shape.DebugString());
}
if (stride_shape.num_elements() != num_dims()) {
return errors::InvalidArgument(
"Sliding window stride field must "
"specify ",
num_dims(), " dimensions");
}
std::vector<int64> stride;
auto status = ctx->ConstantInputAsIntVector(2, &stride);
if (!status.ok()) {
return status;
}
return stride;
}

protected:
@@ -136,24 +131,48 @@ class PoolingOp : public XlaOpKernel {
xla::PrimitiveType xla_reduction_type_;
};

// Converts the tensor data format to the one required by the XLA pooling
// library.
xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
int num_spatial_dims) {
int num_dims = num_spatial_dims + 2;
int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
gtl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
spatial_dimensions[spatial_dim] =
GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
}
return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
/*feature_dimension=*/feature_dimension,
/*spatial_dimensions=*/spatial_dimensions);
}
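As a concrete illustration (not part of the commit) of what this helper produces for the common 2-D case:

// For 4-D input (num_spatial_dims = 2):
//   FORMAT_NHWC -> batch_dimension = 0, feature_dimension = 3, spatial_dimensions = {1, 2}
//   FORMAT_NCHW -> batch_dimension = 0, feature_dimension = 1, spatial_dimensions = {2, 3}
xla::TensorFormat nhwc_format = XlaTensorFormat(FORMAT_NHWC, /*num_spatial_dims=*/2);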

class MaxPoolOp : public PoolingOp {
public:
MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
: PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
/*reduction_type=*/ctx->input_type(0)) {}

xla::XlaOp InitValue(xla::XlaBuilder* b) override {
return xla::MinValue(b, xla_reduction_type_);
}
void Compile(XlaOpKernelContext* ctx) override {
auto ksize_or_error = GetKernelSize(ctx);
OP_REQUIRES_OK(ctx, ksize_or_error.status());
std::vector<int64> ksize = ksize_or_error.ValueOrDie();

const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
return ctx->GetOrCreateMax(reduction_type_);
}
auto stride_or_error = GetStride(ctx);
OP_REQUIRES_OK(ctx, stride_or_error.status());
std::vector<int64> stride = stride_or_error.ValueOrDie();

const TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
errors::InvalidArgument("Input to ", type_string(),
" operator must have ", num_dims(),
" dimensions"));

xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape) override {
return output;
auto pooling =
xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
XlaTensorFormat(data_format_, input_shape.dims() - 2));
ctx->SetOutput(0, pooling);
}
};
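For reference, a minimal standalone use of the library call that MaxPoolOp now delegates to might look like the following sketch; the builder and parameter setup are illustrative and not taken from this commit:

xla::XlaBuilder b("maxpool_example");
auto in = xla::Parameter(
    &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {1, 4, 4, 1}), "in");
// kernel_size and stride span all num_dims() dimensions, not just the
// spatial ones, matching the ksize/stride vectors validated above.
auto out = xla::MaxPool(in, /*kernel_size=*/{1, 2, 2, 1},
                        /*stride=*/{1, 2, 2, 1}, xla::Padding::kValid,
                        XlaTensorFormat(FORMAT_NHWC, /*num_spatial_dims=*/2));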

@@ -180,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp {
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);

// Common computation shared between AvgPool and AvgPoolGrad. Divide each
// element of an image by the count of elements that contributed to that
// element during pooling.
// Divide each element of an image by the count of elements that contributed to
// that element during pooling.
static xla::XlaOp AvgPoolDivideByCount(
XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape, xla::Padding padding,
@@ -241,20 +259,34 @@ class AvgPoolOp : public PoolingOp {
/*reduction_type=*/
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}

xla::XlaOp InitValue(xla::XlaBuilder* b) override {
return xla::Zero(b, xla_reduction_type_);
}
void Compile(XlaOpKernelContext* ctx) override {
auto ksize_or_error = GetKernelSize(ctx);
OP_REQUIRES_OK(ctx, ksize_or_error.status());
std::vector<int64> ksize = ksize_or_error.ValueOrDie();

const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
return ctx->GetOrCreateAdd(reduction_type_);
}
auto stride_or_error = GetStride(ctx);
OP_REQUIRES_OK(ctx, stride_or_error.status());
std::vector<int64> stride = stride_or_error.ValueOrDie();

const TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
errors::InvalidArgument("Input to ", type_string(),
" operator must have ", num_dims(),
" dimensions"));

xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape) override {
return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
ksize_, stride_, num_spatial_dims_,
data_format_);
auto xla_data_format =
XlaTensorFormat(data_format_, input_shape.dims() - 2);
auto spatial_padding = MakeSpatialPadding(
input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);

// Convert the input to the reduction type.
auto converted_input =
ConvertElementType(ctx->Input(0), xla_reduction_type_);
auto pooling =
xla::AvgPool(converted_input, ksize, stride, spatial_padding,
xla_data_format, padding_ == xla::Padding::kValid);
// Convert the pooling result back to the input type before returning it.
ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
}
};
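The final argument to xla::AvgPool is its counts_include_padding flag. With kValid padding no window overlaps padding, so the value is irrelevant and true is safe; with kSame padding TensorFlow's AvgPool divides each window's sum only by the number of valid elements, so the flag must be false. That is exactly what padding_ == xla::Padding::kValid encodes. A small worked example (illustrative only):

// 1-D AvgPool over {1, 2, 3}, kernel 2, stride 1, kSame (one pad on the right).
// Windows: {1, 2}, {2, 3}, {3, pad}.
// counts_include_padding = false: {1.5, 2.5, 3.0}  (last window divides by 1)
// counts_include_padding = true:  {1.5, 2.5, 1.5}  (last window divides by 2)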

