Refactored Quantized Concat
mahmoud-abuzaina committed Mar 28, 2019
1 parent b5d67b7 commit 5a30ce4
Showing 2 changed files with 152 additions and 56 deletions.
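In brief: _MklQuantizedConcatV2 previously rejected quantized inputs whose
min/max ranges differed with an Unimplemented error; this change routes that
case to the Eigen concat path, which merges the input ranges into a single
output range and requantizes each input into it. It also fixes two inverted
DCHECKs on the allocated output tensor pointer.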
204 changes: 148 additions & 56 deletions tensorflow/core/kernels/mkl_concat_op.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <vector>

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -30,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

using mkldnn::concat;
using mkldnn::stream;
@@ -47,6 +47,45 @@ enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
// --------------------------------------------------------------------------
// Eigen Concat Op
// --------------------------------------------------------------------------
namespace {
template <typename T>
struct RequantizeCopier {
RequantizeCopier(
const std::vector<std::pair<float, float>>* input_min_and_max,
float output_min, float output_max)
: output_min(output_min), output_max(output_max) {
DCHECK(input_min_and_max);
this->input_min_and_max = input_min_and_max;
}

inline void Copy(T* dst, const T* src, int input_index, size_t n) {
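// If the input's range already matches the output range, the bytes can
// be copied verbatim; otherwise each element is dequantized using the
// input range and requantized into the output range. Example
// (hypothetical values): quint8 255 in range [0, 25] represents 25.0f,
// which requantizes to quint8 25 in a merged output range of [0, 255].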
const float input_min = (*input_min_and_max)[input_index].first;
const float input_max = (*input_min_and_max)[input_index].second;
if (input_min == output_min && input_max == output_max) {
DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));
memcpy(dst, src, n * sizeof(T));
} else {
Eigen::array<Eigen::DenseIndex, 1> dims;
dims[0] = n;
typename TTypes<T, 1>::UnalignedConstTensor input_array(src, dims);
typename TTypes<T, 1>::UnalignedTensor output_array(dst, dims);

QuantizedToFloatStruct<T> q2f(input_min, input_max);
auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
FloatToQuantizedStruct<T> f2q(output_min, output_max);
// RequantizeCopier::Copy is called from within a shard of computation,
// so don't use the threadpool device here; simply assign with the
// default CPU device.
output_array = QUANTIZE_WITH_EIGEN(input_float, f2q, T);
}
}

float output_min;
float output_max;
const std::vector<std::pair<float, float>>* input_min_and_max;
};
} // namespace

template <typename Device, typename T, AxisArgumentName AxisArgName>
class EigenConcatBaseOp : public OpKernel {
public:
@@ -55,12 +94,45 @@ class EigenConcatBaseOp : public OpKernel {

explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}

void CalculateInputAndOutputRange(
const OpInputList& input_mins, const OpInputList& input_maxes,
const size_t N,
std::vector<std::pair<float, float>>* input_mins_and_maxes,
float* output_min, float* output_max) {
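// Collect each input's (min, max) pair and merge them into a single
// output range. Example (hypothetical values): unsigned inputs with
// ranges [0, 25] and [0, 255] merge to [0, 255]; signed inputs with
// ranges [-1, 2] and [-3, 1] merge to the symmetric [-3, 3] so that
// zero stays exactly representable.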
input_mins_and_maxes->reserve(N);
float overall_min = std::numeric_limits<float>::max();
float overall_max = std::numeric_limits<float>::lowest();
for (int i = 0; i < N; ++i) {
const float input_min = input_mins[i].flat<float>()(0);
const float input_max = input_maxes[i].flat<float>()(0);
input_mins_and_maxes->emplace_back(input_min, input_max);
overall_min = std::min(overall_min, input_min);
overall_max = std::max(overall_max, input_max);
}
if (std::is_signed<T>::value) {
// For signed, we want a symmetrical distribution including zero for the
// output, so pick a range that meets that need.
const float largest_value =
std::max(std::abs(overall_min), std::abs(overall_max));
*output_min = -largest_value;
*output_max = largest_value;
} else {
// For MKL quantization, we only support scaled mode, so the range is
// [0, m] for unsigned data.
overall_min = std::min(0.0f, overall_min);
*output_min = overall_min;
*output_max = overall_max;
}
}

// Although we modify Compute for this call to accept extra params,
// we still need an empty Compute because Compute is a pure virtual
// function.
void Compute(OpKernelContext* c) {}

void Compute(OpKernelContext* c, const std::vector<Tensor>& values,
const TensorShapeList& input_shapes) {
const TensorShapeList& input_shapes,
const OpInputList& input_mins, const OpInputList& input_maxes,
bool quantized_input) {
const Tensor* concat_dim_tensor;
const char* axis_attribute_name =
AxisArgName == NAME_IS_AXIS
@@ -80,13 +152,21 @@ class EigenConcatBaseOp : public OpKernel {
const TensorShape& input_shape = input_shapes[0];

int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
OP_REQUIRES(c,
(0 <= axis && axis < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range "
"[",
-input_dims, ", ", input_dims, "), but got ", concat_dim));
OP_REQUIRES(
c, (0 <= axis && axis < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range [",
-input_dims, ", ", input_dims, "), but got ", concat_dim));

float output_min = std::numeric_limits<float>::max();
float output_max = std::numeric_limits<float>::lowest();
std::vector<std::pair<float, float>> input_mins_and_maxes;
if (quantized_input) {
CalculateInputAndOutputRange(input_mins, input_maxes, N,
&input_mins_and_maxes, &output_min,
&output_max);
}
// Note that we reduce the concat of n-dimensional tensors into a two
// dimensional concat. Assuming the dimensions of any input/output
// tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
@@ -104,13 +184,12 @@ class EigenConcatBaseOp : public OpKernel {
const auto in = values[i];
const bool in_is_scalar = IsLegacyScalar(input_shapes[i]);
OP_REQUIRES(
c,
(input_shapes[i].dims() == input_dims) ||
(input_is_scalar && in_is_scalar),
c, (input_shapes[i].dims() == input_dims) ||
(input_is_scalar && in_is_scalar),
errors::InvalidArgument(
"ConcatOp : Ranks of all input tensors should match: shape[0] = ",
input_shape.DebugString(), " vs. shape[", i,
"] = ", input_shapes[i].DebugString()));
input_shape.DebugString(), " vs. shape[", i, "] = ",
input_shapes[i].DebugString()));
if (in.NumElements() > 0) {
int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
@@ -131,7 +210,24 @@ class EigenConcatBaseOp : public OpKernel {
if (output->NumElements() > 0) {
int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
if (!quantized_input) {
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
} else {
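// Quantized path: RequantizeCopier requantizes each input into the
// merged output range as it copies (or memcpys when an input's range
// already matches the output range).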
ConcatCPUImpl<T>(
c->device(), inputs_flat, sizeof(T) /* cost_per_unit */,
RequantizeCopier<T>(&input_mins_and_maxes, output_min, output_max),
&output_flat);
}
}

if (quantized_input) {
Tensor* output_min_tensor = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(1, {}, &output_min_tensor));
output_min_tensor->flat<float>()(0) = output_min;

Tensor* output_max_tensor = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(2, {}, &output_max_tensor));
output_max_tensor->flat<float>()(0) = output_max;
}
}
};
@@ -229,7 +325,9 @@ class MklConcatOp : public OpKernel {
if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;

OpInputList input_mins, input_maxes;
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
bool quantized_input =
std::is_same<T, qint8>::value || std::is_same<T, quint8>::value;
if (quantized_input) {
// MKL-DNN concat does not support input tensors that have different
// ranges. Check if the ranges of all the input tensors are the same.
// If not, forward to the Eigen implementation.
Expand Down Expand Up @@ -262,17 +360,8 @@ class MklConcatOp : public OpKernel {

// Call Eigen library
if (invoke_eigen) {
// MKL-DNN quantized concat does not support input tensors with
// different ranges.
// TODO (mabuzain): Add quantized version of CallEigen() to support
// this case.
OP_REQUIRES(
context,
(!std::is_same<T, qint8>::value && !std::is_same<T, quint8>::value),
errors::Unimplemented("MKL DNN quantized concat does not "
"support input tensors that have "
"different ranges"));
CallEigenVersion(context, input_tensors, mkl_input_shapes);
CallEigenVersion(context, input_tensors, input_mins, input_maxes,
mkl_input_shapes, quantized_input);
return;
}

Expand Down Expand Up @@ -421,7 +510,7 @@ class MklConcatOp : public OpKernel {
}
AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
dnn_shape_dst);
DCHECK(dst_tensor == nullptr) << "Output tensor pointer is NULL";
DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";

if (dnn_shape_dst.IsMklTensor()) dst_md = dnn_shape_dst.GetMklLayout();
dst.SetUsrMem(dst_md, dst_tensor);
@@ -432,7 +521,7 @@ class MklConcatOp : public OpKernel {
stream(stream::kind::eager).submit(net).wait();

// For quantized concat, min and max outputs are also computed.
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
if (quantized_input) {
Tensor* output_min = nullptr;
Tensor* output_max = nullptr;
MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
@@ -456,56 +545,59 @@

AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
dnn_shape_dst);
DCHECK(dst_tensor == nullptr) << "Output tensor pointer is NULL";
DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
}

} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
}
}

void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
const MklDnnShapeList& mkl_input_shapes) {
CHECK_EQ(values.size(), mkl_input_shapes.size());

std::vector<Tensor> converted_values;
const OpInputList& input_mins,
const OpInputList& input_maxes,
const MklDnnShapeList& mkl_input_shapes,
bool quantized_input) {
size_t num_mkl_input_shapes = mkl_input_shapes.size();
CHECK_EQ(values.size(), num_mkl_input_shapes);
std::vector<Tensor> converted_values(num_mkl_input_shapes);
TensorShapeList tf_input_shapes;
for (int i = 0; i < mkl_input_shapes.size(); i++) {
for (size_t i = 0; i < num_mkl_input_shapes; ++i) {
if (mkl_input_shapes[i].IsMklTensor()) {
// Do the conversion from MKL to TF format.
Tensor tmp_tensor =
ConvertMklToTF<T>(context, values[i], mkl_input_shapes[i]);
converted_values.push_back(tmp_tensor);
converted_values[i] = tmp_tensor;
tf_input_shapes.push_back(mkl_input_shapes[i].GetTfShape());
} else {
// No conversion needed since it is already a TF tensor.
converted_values.push_back(values[i]);
converted_values[i] = values[i];
tf_input_shapes.push_back(values[i].shape());
}
}

// Call Eigen concat.
eigen_concat_op_.Compute(context, converted_values, tf_input_shapes);

// Set output Mkl tensor for this op.
MklDnnShape dnn_shape_output;
dnn_shape_output.SetMklTensor(false);
dnn_shape_output.SetDimensions(4);
Tensor* output_tensor = nullptr;
TensorShape tf_shape_output;
tf_shape_output.AddDim(dnn_shape_output.GetSerializeBufferSize());
OP_REQUIRES_OK(context,
context->allocate_output(
GetTensorMetaDataIndex(0, context->num_outputs()),
tf_shape_output, &output_tensor));
dnn_shape_output.SerializeMklDnnShape(
output_tensor->flat<uint8>().data(),
output_tensor->flat<uint8>().size() * sizeof(uint8));
eigen_concat_op_.Compute(context, converted_values, tf_input_shapes,
input_mins, input_maxes, quantized_input);

// Get the number of dims from the first input since all input tensors
// should have the same rank.
size_t dims = values[0].shape().dims();
MklDnnShape output_data_mkl_shape;
output_data_mkl_shape.SetMklTensor(false);
output_data_mkl_shape.SetDimensions(dims);
AllocateOutputSetMklShape(context, 0, output_data_mkl_shape);
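// For quantized concat, outputs 1 and 2 carry the merged output min/max
// scalars. Their values were already written by the Eigen Compute above,
// so only their (non-MKL) shape metadata is set here.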
if (quantized_input) {
MklDnnShape output_min_max_mkl_shape;
output_min_max_mkl_shape.SetMklTensor(false);
AllocateOutputSetMklShape(context, 1, output_min_max_mkl_shape);
AllocateOutputSetMklShape(context, 2, output_min_max_mkl_shape);
}
}

// This method finds the most common format across all MKL inputs
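For reference, below is a minimal standalone sketch of the per-element
requantization arithmetic the new Eigen path performs, assuming TF's scaled
8-bit quantization (the representable range mapped linearly over 255 steps).
The helper names are illustrative only, not the actual QuantizedToFloatStruct /
FloatToQuantizedStruct helpers used in the diff above.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Map a quint8 value quantized over [min, max] back to a real value.
float DequantizeU8(uint8_t q, float min, float max) {
  const float scale = (max - min) / 255.0f;
  return min + static_cast<float>(q) * scale;
}

// Map a real value into a quint8 quantized over [min, max].
uint8_t QuantizeU8(float v, float min, float max) {
  const float scale = (max - min) / 255.0f;
  const float q = (v - min) / scale + 0.5f;  // round half up
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, q)));
}

int main() {
  // One input was quantized over [0, 25]; the concat output uses the
  // merged range [0, 255], so each of its elements must be requantized.
  const uint8_t in = 255;                              // represents 25.0f
  const float real = DequantizeU8(in, 0.0f, 25.0f);    // -> 25.0f
  const uint8_t out = QuantizeU8(real, 0.0f, 255.0f);  // -> 25
  std::printf("in=%d real=%.2f out=%d\n", in, real, out);
  return 0;
}

When the input and output ranges coincide, RequantizeCopier skips this math
entirely and falls back to memcpy, as in the struct shown in the diff.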
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/mkl_quantized_concat_op_test.cc
@@ -86,6 +86,10 @@ TEST_F(QuantizedConcatTest, Small8BitSameRange) {
TestSmall8Bit(0.0f, 255.0f, 0.0f, 255.0f);
}

TEST_F(QuantizedConcatTest, Small8BitDifferentRange) {
TestSmall8Bit(0.0f, 255.0f, 0.0f, 25.0f);
}

void QuantizedConcatTest::TestSmall8Bit(float first_min, float first_max,
float second_min, float second_max) {
TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
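The new Small8BitDifferentRange case mirrors the existing same-range test but
quantizes the second input over [0, 25] instead of [0, 255], exercising the
requantizing Eigen fallback that the removed Unimplemented check used to
block.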
