From 4b88c19912f97304d137b7b45619949bc5e7c792 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Fri, 17 Jun 2016 09:37:34 +0800 Subject: [PATCH] Support ndim up to 7 for binary broadcasting operators + Accelerate reducing OPs by calling reduce_except_dim if possible. + Add `/bigobj` to CMakeList (#2418) Reshape the lhs and rhs to ndim=3 if possible otherwise reshape them into ndim=7. --- CMakeLists.txt | 2 +- include/mxnet/operator_util.h | 46 +- src/operator/broadcast_reduce_op-inl.h | 116 ++- src/operator/broadcast_reduce_op_common.h | 165 +++++ .../elementwise_binary_broadcast_op-inl.h | 678 +++++++----------- tests/python/unittest/test_ndarray.py | 2 +- tests/python/unittest/test_operator.py | 47 +- 7 files changed, 532 insertions(+), 524 deletions(-) create mode 100644 src/operator/broadcast_reduce_op_common.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 7613fe00375b..de8d1e85360d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,7 +29,7 @@ if(MSVC) add_definitions(-D_CRT_SECURE_NO_WARNINGS) add_definitions(-DMXNET_EXPORTS) set(CMAKE_C_FLAGS "/MP") - set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} /bigobj") else(MSVC) include(CheckCXXCompilerFlag) check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) diff --git a/include/mxnet/operator_util.h b/include/mxnet/operator_util.h index f96b85108b47..94eb994d07e1 100644 --- a/include/mxnet/operator_util.h +++ b/include/mxnet/operator_util.h @@ -11,6 +11,10 @@ #ifndef MXNET_OPERATOR_UTIL_H_ #define MXNET_OPERATOR_UTIL_H_ +#ifdef _MSC_VER +#pragma warning(disable:4503) // disable warning: decorated name length exceeded. +#endif + #include #include #include @@ -412,47 +416,9 @@ class SimpleOpRegistry { } /*! -* \brief cast dynamic range variable into static variable -* \param var the source value, constrained to be between 1 and 5 -* \param NDIM the const NDIM that can be used in the template +* \brief Maximum ndim supported for special operators like broadcasting with non contiguous lhs/rhs */ -#define MXNET_RANGE_SWITCH(var, NDIM, ...) 
\ - { \ - switch (var) { \ - case 1: \ - { \ - static const int NDIM = 1; \ - {__VA_ARGS__} \ - } \ - break; \ - case 2: \ - { \ - static const int NDIM = 2; \ - {__VA_ARGS__} \ - } \ - break; \ - case 3: \ - { \ - static const int NDIM = 3; \ - {__VA_ARGS__} \ - } \ - break; \ - case 4: \ - { \ - static const int NDIM = 4; \ - {__VA_ARGS__} \ - } \ - break; \ - case 5: \ - { \ - static const int NDIM = 5; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - LOG(FATAL) << "Only support ndim=1 to 5."; \ - } \ - } +#define MXNET_SPECIAL_MAX_NDIM 7 //-------------------------------------------------------------- diff --git a/src/operator/broadcast_reduce_op-inl.h b/src/operator/broadcast_reduce_op-inl.h index f43bafbc16da..fa6b7fbf106a 100644 --- a/src/operator/broadcast_reduce_op-inl.h +++ b/src/operator/broadcast_reduce_op-inl.h @@ -103,11 +103,14 @@ void L2Norm(const TBlob &src, OpReqType req, RunContext ctx) { mshadow::Stream *s = ctx.get_stream(); - mshadow::Tensor out = ret->get(s); - mshadow::Tensor in = - src.get_with_shape(mshadow::Shape1(src.shape_.Size()), s); - mshadow::VectorDot(out, in, in); - out = mshadow::expr::F(out); + CHECK_EQ(src.type_flag_, ret->type_flag_); + MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Tensor out = ret->get(s); + mshadow::Tensor in = + src.get_with_shape(mshadow::Shape1(src.shape_.Size()), s); + mshadow::VectorDot(out, in, in); + ASSIGN_DISPATCH(out, req, mshadow::expr::F(out)); + }); } template @@ -117,10 +120,13 @@ void Reduce(const TBlob &src, OpReqType req, RunContext ctx) { mshadow::Stream *s = ctx.get_stream(); - mshadow::Tensor out = ret->get(s); - mshadow::Tensor in = - src.get_with_shape(mshadow::Shape2(1, src.shape_.Size()), s); - out = mshadow::expr::reduce_except_dim<0, Reducer>(in); + CHECK_EQ(src.type_flag_, ret->type_flag_); + MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Tensor out = ret->get(s); + mshadow::Tensor in = + src.get_with_shape(mshadow::Shape2(1, src.shape_.Size()), s); + ASSIGN_DISPATCH(out, req, (mshadow::expr::reduce_except_dim<0, Reducer>(in))); + }); } // backward function that takes input value of the op @@ -135,7 +141,7 @@ void SumBackward_(const OutputGrad& scale, mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(in_grad->type_flag_, scale.data.type_flag_) << "Unary function only support input/output with the same type"; - MSHADOW_TYPE_SWITCH(in_grad->type_flag_, DType, { + MSHADOW_REAL_TYPE_SWITCH(in_grad->type_flag_, DType, { mshadow::Tensor mscale = scale.data.get(s); mshadow::Tensor igrad = in_grad->FlatTo2D(s); ASSIGN_DISPATCH(igrad, req, @@ -143,7 +149,7 @@ void SumBackward_(const OutputGrad& scale, }); } -template +template void ReduceChannel(const TBlob &src, const EnvArguments& env, TBlob *ret, @@ -153,13 +159,17 @@ void ReduceChannel(const TBlob &src, using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - Tensor out = ret->get_with_shape( - Shape2(src.shape_[0], src.Size()/src.shape_[0]/src.shape_[1]), - s); - Tensor in = src.get_with_shape( - Shape3(src.shape_[0], src.shape_[1], src.Size()/src.shape_[0]/src.shape_[1]), + CHECK_EQ(src.type_flag_, ret->type_flag_); + MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, { + Tensor out = ret->get_with_shape( + Shape2(src.shape_[0], src.Size() / src.shape_[0] / src.shape_[1]), s); - out = reduce_with_axis(in, 1); + Tensor in = src.get_with_shape( + Shape3(src.shape_[0], src.shape_[1], src.Size() / src.shape_[0] / src.shape_[1]), + s); + CHECK(req != kAddTo) << "AddTo is not supported"; + 
ASSIGN_DISPATCH(out, req, (reduce_with_axis(in, 1))); + }); } // return a shape of ReduceChannel output @@ -184,13 +194,16 @@ void ReduceAxisImpl_(const TBlob &src, bool keepdims) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); + CHECK_EQ(src.type_flag_, ret->type_flag_); if (-1 == axis) { // Reduce all dimensions if axis == -1 - mshadow::Tensor in = - src.get_with_shape(mshadow::Shape2(1, src.shape_.Size()), s); - mshadow::Tensor out = - ret->get_with_shape(mshadow::Shape1(ret->shape_.Size()), s); - out = mshadow::expr::reduce_except_dim<0, Reducer>(in); + MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Tensor in = + src.get_with_shape(mshadow::Shape2(src.shape_.Size(), 1), s); + mshadow::Tensor out = + ret->get_with_shape(mshadow::Shape1(ret->shape_.Size()), s); + ASSIGN_DISPATCH(out, req, (reduce_except_dim<1, Reducer>(in))); + }); return; } int trailing = 1; @@ -202,11 +215,46 @@ void ReduceAxisImpl_(const TBlob &src, trailing *= src.shape_[i]; } } - mshadow::Tensor in = - src.get_with_shape(mshadow::Shape3(leading, src.shape_[axis], trailing), s); - mshadow::Tensor out = - ret->get_with_shape(mshadow::Shape2(leading, trailing), s); - out = mshadow::expr::reduce_with_axis(in, 1); + if (get_mask) { + // If get_mask is on, we have to use the slower `reduce_with_axis` + // since reduce_except_dim does not support the flag. + MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Tensor in = + src.get_with_shape(mshadow::Shape3(leading, src.shape_[axis], trailing), s); + mshadow::Tensor out = + ret->get_with_shape(mshadow::Shape2(leading, trailing), s); + CHECK(req != kAddTo) << "AddTo is not supported for `get_mask = true`"; + ASSIGN_DISPATCH(out, req, (reduce_with_axis(in, 1))); + }); + return; + } + if (1 == leading) { + MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Tensor in = + src.get_with_shape(mshadow::Shape2(src.shape_[axis], trailing), s); + mshadow::Tensor out = + ret->get_with_shape(mshadow::Shape1(trailing), s); + ASSIGN_DISPATCH(out, req, (reduce_except_dim<1, Reducer>(in))); + }); + } else if (1 == trailing) { + MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Tensor in = + src.get_with_shape(mshadow::Shape2(leading, src.shape_[axis]), s); + mshadow::Tensor out = + ret->get_with_shape(mshadow::Shape1(leading), s); + ASSIGN_DISPATCH(out, req, (reduce_except_dim<1, Reducer>(in.T()))); + }); + } else { + MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Tensor in = + src.get_with_shape(mshadow::Shape3(leading, src.shape_[axis], trailing), s); + mshadow::Tensor out = + ret->get_with_shape(mshadow::Shape1(leading * trailing), s); + ASSIGN_DISPATCH(out, req, + (reduce_except_dim<1, Reducer>(reshape(swapaxis<1, 0>(in), + mshadow::Shape2(src.shape_[axis], leading * trailing))))); + }); + } } // Broadcast the given axis to the given broadcasting size @@ -240,11 +288,13 @@ void BroadcastAxisImpl_(const TBlob &src, trailing *= ret->shape_[i]; } } - mshadow::Tensor in = - src.get_with_shape(mshadow::Shape2(leading, trailing), s); - mshadow::Tensor out = - ret->get_with_shape(mshadow::Shape3(leading, bsize, trailing), s); - out = mshadow::expr::broadcast_with_axis(in, 0, bsize); + MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { + mshadow::Tensor in = + src.get_with_shape(mshadow::Shape2(leading, trailing), s); + mshadow::Tensor out = + ret->get_with_shape(mshadow::Shape3(leading, bsize, trailing), s); + ASSIGN_DISPATCH(out, req, broadcast_with_axis(in, 0, bsize)); + }); } // Forward pass of reduce over the 
given axis @@ -386,7 +436,7 @@ MXNET_REGISTER_SIMPLE_OP(sum_axis, XPU) // argmax channel MXNET_REGISTER_SIMPLE_OP(argmax_channel, XPU) -.set_function(XPU::kDevMask, ReduceChannel, +.set_function(XPU::kDevMask, ReduceChannel, kNoInplace, kNotRegisterSymbolic) .set_shape_function(ReduceChannelShape) .describe("Take argmax indices of each channel of the src." diff --git a/src/operator/broadcast_reduce_op_common.h b/src/operator/broadcast_reduce_op_common.h new file mode 100644 index 000000000000..4ec50d4b3b56 --- /dev/null +++ b/src/operator/broadcast_reduce_op_common.h @@ -0,0 +1,165 @@ +/*! +* Copyright (c) 2016 by Contributors +* \file broadcast_reduce_op_common.h +* \brief common function used for broadcasting and reducing +* \author Xingjian Shi +*/ +#ifndef MXNET_OPERATOR_BROADCAST_REDUCE_OP_COMMON_H_ +#define MXNET_OPERATOR_BROADCAST_REDUCE_OP_COMMON_H_ +#include +#include +#include +#include + +namespace mxnet { +namespace op { + +/*! +* \brief Check if the axes are continuous + get reducing size. E.g (1, 3) -> false, (1,2,3) -> true +* \param is_contiguous_axes whether the axes is contiguous +* \param reducing_size product of source shape in the given axes +* \param axes +* \param src_shape shape of the source tensor +*/ +inline void CheckContiguousAxes_(bool *is_contiguous_axes, index_t *reducing_size, + const mshadow::TShape &axes, const mshadow::TShape &src_shape) { + *is_contiguous_axes = true; + *reducing_size = 1; + for (index_t i = 0; i < axes.ndim(); ++i) { + *reducing_size *= src_shape[axes[i]]; + if (i > 0) { + *is_contiguous_axes = *is_contiguous_axes && (axes[i] == (axes[i - 1] + 1)); + CHECK(axes[i - 1] < axes[i]) << "axes must be in increasing order, received axes=" << axes; + } + } +} + +template +inline void CheckContiguousAxes_(bool *is_contiguous_axes, index_t *reducing_size, + const mshadow::TShape &axes, const mshadow::Shape &src_shape) { + CheckContiguousAxes_(is_contiguous_axes, reducing_size, axes, + TShape(src_shape.shape_, src_shape.shape_ + dimsrc)); +} + +inline TShape GetBroadcastingAxes_(const mshadow::TShape &src_shape, + const mshadow::TShape &target_shape) { + std::vector axes_vec; + CHECK_EQ(target_shape.ndim(), src_shape.ndim()); + for (int i = 0; i < src_shape.ndim(); ++i) { + if (src_shape[i] != target_shape[i]) { + CHECK_EQ(src_shape[i], 1) << "broadcastsing axis must have size 1, received src_shape=" + << src_shape << " target_shape=" << target_shape; + axes_vec.push_back(i); + } + } + TShape axes = TShape(axes_vec.begin(), axes_vec.end()); + return axes; +} + +/*! +* \brief a reduce over multiple axes and assign to the output tensor. +* \param out output tensor, must have dim 1 +* \param src the source expression +* \param axes the given axes, should be in increasing order +* \tparam Reducer type of the reducing operation +* \tparam xpu +* \tparam SrcExp the src expression template +* \tparam etype type of expression +*/ +template +void ReduceAxesAssign(mshadow::Tensor out, const OpReqType req, + const SrcExp &src_, const TShape &axes) { + using namespace mshadow; + using namespace mshadow::expr; + static const int dimsrc = ExpInfo::kDim; + CHECK(axes.ndim() <= dimsrc); + Shape src_shape = ShapeCheck::Check(src_); + + // 1. Check if the axes has size 0, if so, no reducing is needed. + if (0 == axes.ndim()) { + ASSIGN_DISPATCH(out, req, reshape(src_, Shape1(src_shape.ProdShape(0, dimsrc)))); + return; + } + + // 2. Check if we want to reduce over contiguous axes and get the reducing size. + // e.g. 
(1,2,3) --> contiguous, (1,3) --> noncontiguous + bool is_contiguous_axes = true; + index_t reducing_size = 1; + CheckContiguousAxes_(&is_contiguous_axes, &reducing_size, axes, src_shape); + + // 3. For contiguous axes, we can always reshape them to (leading, reducing_size, trailing) + // and we can then simplify the combination of mshadow symbols. + if (is_contiguous_axes) { + index_t leading = 1; + index_t trailing = 1; + for (index_t i = 0; i < dimsrc; ++i) { + if (i < axes[0]) { + leading *= src_shape[i]; + } else if (i > axes[axes.ndim() - 1]) { + trailing *= src_shape[i]; + } + } + if (1 == leading) { + ASSIGN_DISPATCH(out, req, + (reduce_except_dim<1, Reducer>(reshape(src_, Shape2(reducing_size, trailing))))); + } else { + ASSIGN_DISPATCH(out, req, (reduce_except_dim<1, Reducer>( + reshape(swapaxis<1, 0>(reshape(src_, Shape3(leading, reducing_size, trailing))), + Shape2(reducing_size, leading * trailing))))); + } + return; + } + // 4. For non-contiguous axes, we need to push axes to the front of the shape vector then reduce. + // E.g axes = (1, 2), dim = 6 => transpose_shape = (1, 2, 0, 3, 4, 5) + Shape transpose_shape = src_shape; + index_t remaining_size = 1; + for (index_t i = 0; i < axes.ndim(); ++i) { + transpose_shape[i] = axes[i]; + if (i > 0) { + for (index_t j = axes[i - 1] + 1; j < axes[i]; ++j) { + transpose_shape[axes.ndim() - i + j] = j; + remaining_size *= src_shape[j]; + } + } + if (axes.ndim() - 1 == i) { + for (index_t j = axes[axes.ndim() - 1] + 1; j < dimsrc; ++j) { + transpose_shape[j] = j; + remaining_size *= src_shape[j]; + } + } + if (0 == i) { + for (index_t j = 0; j < axes[0]; ++j) { + transpose_shape[axes.ndim() - i + j] = j; + remaining_size *= src_shape[j]; + } + } + } + ASSIGN_DISPATCH(out, req, + (reduce_except_dim<1, Reducer>(reshape(transpose(src_, transpose_shape), + Shape2(reducing_size, remaining_size))))); +} + +/*! +* \brief a reduce to the given shape and assign to the output tensor. +* \param out output tensor, must have dim 1 +* \param src the source expression +* \param target_shape shape of the target tensor, must have size 1 for the reduction axes +* \tparam Reducer type of the reducing operation +* \tparam xpu +* \tparam SrcExp the src expression template +* \tparam etype type of expression +*/ +template +void ReduceToAssign(mshadow::Tensor out, const OpReqType req, + const TShape &target_shape, const SrcExp &src_) { + using namespace mshadow; + using namespace mshadow::expr; + static const int dimsrc = ExpInfo::kDim; + Shape src_shape = ShapeCheck::Check(src_); + TShape axes = GetBroadcastingAxes_(target_shape, + TShape(src_shape.shape_, src_shape.shape_ + dimsrc)); + ReduceAxesAssign(out, req, src_, axes); +} +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_BROADCAST_REDUCE_OP_COMMON_H_ diff --git a/src/operator/elementwise_binary_broadcast_op-inl.h b/src/operator/elementwise_binary_broadcast_op-inl.h index b210998e2775..89fedf5cc0c9 100644 --- a/src/operator/elementwise_binary_broadcast_op-inl.h +++ b/src/operator/elementwise_binary_broadcast_op-inl.h @@ -1,5 +1,5 @@ /*! 
- * Copyright (c) 2015 by Contributors + * Copyright (c) 2016 by Contributors * \file elementwise_binary_broadcast_op-inl.h * \brief Function defintion of elementwise binary operators with broadcast * @@ -26,16 +26,13 @@ * * Here are examples of shapes that do not broadcast: * - * A (3d tensor): 15 x 3 x 5 - * B (3d tensor): 15 x 1 x 5 # the diminsions for broadcasting should be continous - * * A (1d tensor): 3 * B (1d tensor): 4 # trailing dimensions do not match * * A (2d tensor): 1 x 2 x 1 * B (3d tensor): 8 x 4 x 3 # second from last dimensions mismatched * - * When no broadcast is need, it fails back to elementwise_binary_op-inl.h + * When no broadcast is need, it falls back to elementwise_binary_op-inl.h */ #ifndef MXNET_OPERATOR_ELEMENTWISE_BINARY_BROADCAST_OP_INL_H_ #define MXNET_OPERATOR_ELEMENTWISE_BINARY_BROADCAST_OP_INL_H_ @@ -44,6 +41,7 @@ #include #include #include "./mshadow_op.h" +#include "./broadcast_reduce_op_common.h" #if defined(__CUDACC__) #define XPU gpu @@ -56,8 +54,7 @@ namespace op { inline bool IsBroadcastNeeded_(const TShape& lhs, const TShape& rhs) { - // force ndim to be equal. do not smartly padding dims with 1s, which may - // confuse users + // force ndim to be equal. do not smartly padding dims with 1s, which may confuse users CHECK_EQ(lhs.ndim(), rhs.ndim()); for (index_t i = 0; i < lhs.ndim(); ++i) { if (lhs[i] != rhs[i]) return true; @@ -65,7 +62,6 @@ inline bool IsBroadcastNeeded_(const TShape& lhs, return false; } - inline TShape BinaryBroadcastShape_(const TShape& lhs, const TShape& rhs, const EnvArguments& env) { @@ -74,96 +70,66 @@ inline TShape BinaryBroadcastShape_(const TShape& lhs, for (size_t i = 0; i < ret.size(); ++i) { ret[i] = std::max(lhs[i], rhs[i]); } - // check - for (int h = 0; h < 2; ++h) { - const TShape& inp = h == 0 ? lhs : rhs; - int contdim = 0; - for (size_t i = 0; i < inp.ndim(); ++i) { - if (inp[i] != 1) { - CHECK_EQ(inp[i], ret[i]) << "broadcast error on index " << i << ". " - << "lhs = " << lhs << "; rhs = " << rhs; - } - if (inp[i] == ret[i]) { - if (i == 0 || inp[i-1] != ret[i-1]) ++contdim; - } - } - CHECK_LE(contdim, 1) << "broadcast dimensions are not continuous. " - << "lhs = " << lhs << "; rhs = " << rhs; - } return TShape(ret.begin(), ret.end()); } -inline void GetBroadcastShape_(const TShape& lhs, - const TShape& rhs, - TShape* ret_reshaped, - int* lhs_broadcast_axis, - int* rhs_broadcast_axis) { - TShape ret = BinaryBroadcastShape_(lhs, rhs, EnvArguments()); - int n = static_cast(ret.ndim()); - int pos[4] = {0, n, n, n}; - for (int h = 0; h < 2; ++h) { - const TShape& inp = h == 0 ? 
lhs : rhs; - for (int i = 0; i < n; ++i) { - if (inp[i] == ret[i]) { - pos[h*2] = i; break; - } - } - for (int i = n; i > 0; --i) { - if (inp[i-1] == ret[i-1]) { - pos[h*2+1] = i; break; - } - } - } - bool no_broadcast_lhs = pos[0] == 0 && pos[1] == n; - bool no_broadcast_rhs = pos[2] == 0 && pos[3] == n; - int pos_ordered[4] = {0, -1, -1, n}; - if (no_broadcast_lhs && no_broadcast_rhs) { - // no broadcast - LOG(FATAL) << "no broadcast is needed"; - } else if (no_broadcast_lhs && !no_broadcast_rhs) { - // only broadcast rhs - *rhs_broadcast_axis = 1; - *lhs_broadcast_axis = -1; - pos_ordered[1] = pos[2]; - pos_ordered[2] = pos[3]; - } else if (!no_broadcast_lhs && no_broadcast_rhs) { - // only broadcast lhs - *rhs_broadcast_axis = -1; - *lhs_broadcast_axis = 1; - pos_ordered[1] = pos[0]; - pos_ordered[2] = pos[1]; - } else { - // broadcast both lhs and rhs - int p; - if (pos[0] <= pos[2]) { - CHECK(pos[0] == 0 && pos[1] == pos[2] && pos[3] == n) - << "broadcast shape error: lhs = " << lhs << "; rhs = " << rhs; - *lhs_broadcast_axis = 0; - *rhs_broadcast_axis = 1; - p = pos[1]; +inline void InferBroadcastNewShapes_(bool *do_opt, + TShape *new_lhs_shape, TShape *new_rhs_shape, TShape *new_out_shape, + const TShape &lhs_shape, const TShape &rhs_shape, const TShape &out_shape) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK((lhs_shape.ndim() == rhs_shape.ndim()) && (rhs_shape.ndim() == out_shape.ndim())) << + "ndim inconsistency, lhs_shape=" << lhs_shape << ", rhs_shape=" << rhs_shape << + ", out_shape=" << out_shape; + *do_opt = false; + TShape lhs_axes = GetBroadcastingAxes_(lhs_shape, out_shape); + TShape rhs_axes = GetBroadcastingAxes_(rhs_shape, out_shape); + bool lhs_contiguous, rhs_contiguous; + index_t lhs_broadcasting_size, rhs_broadcasting_size; + CheckContiguousAxes_(&lhs_contiguous, &lhs_broadcasting_size, lhs_axes, out_shape); + CheckContiguousAxes_(&rhs_contiguous, &rhs_broadcasting_size, rhs_axes, out_shape); + if (lhs_contiguous && rhs_contiguous && (lhs_axes.ndim() == 0 || rhs_axes.ndim() == 0)) { + *do_opt = true; + if (lhs_axes.ndim() == 0) { + index_t leading = + rhs_shape.ProdShape(0, rhs_axes[0]); + index_t trailing = + rhs_shape.ProdShape(rhs_axes[rhs_axes.ndim() - 1] + 1, rhs_shape.ndim()); + *new_lhs_shape = Shape3(leading, rhs_broadcasting_size, trailing); + *new_rhs_shape = Shape3(leading, 1, trailing); + *new_out_shape = Shape3(leading, rhs_broadcasting_size, trailing); } else { - CHECK(pos[2] == 0 && pos[3] == pos[0] && pos[1] == n) - << "broadcast shape error: lhs = " << lhs << "; rhs = " << rhs; - *lhs_broadcast_axis = 1; - *rhs_broadcast_axis = 0; - p = pos[0]; + index_t leading = + lhs_shape.ProdShape(0, lhs_axes[0]); + index_t trailing = + lhs_shape.ProdShape(lhs_axes[lhs_axes.ndim() - 1] + 1, lhs_shape.ndim()); + *new_lhs_shape = Shape3(leading, 1, trailing); + *new_rhs_shape = Shape3(leading, lhs_broadcasting_size, trailing); + *new_out_shape = Shape3(leading, lhs_broadcasting_size, trailing); } - std::vector dim(2, 1); - for (int i = 0; i < p; ++i) dim[0] *= ret[i]; - for (int i = p; i < n; ++i) dim[1] *= ret[i]; - *ret_reshaped = TShape(dim.begin(), dim.end()); - return; - } - std::vector dim(3, 1); - for (int i = 0; i < 3; ++i) { - for (int j = pos_ordered[i]; j < pos_ordered[i+1]; ++j) { - dim[i] *= ret[j]; + } else { + *do_opt = false; + CHECK(lhs_shape.ndim() <= MXNET_SPECIAL_MAX_NDIM) + << "Only support input dimension up to " << MXNET_SPECIAL_MAX_NDIM + << ", lhs_shape=" << lhs_shape << ", rhs_shape=" << rhs_shape + << ", 
out_shape=" << out_shape; + *new_lhs_shape = TShape(MXNET_SPECIAL_MAX_NDIM); + *new_rhs_shape = TShape(MXNET_SPECIAL_MAX_NDIM); + *new_out_shape = TShape(MXNET_SPECIAL_MAX_NDIM); + for (int i = 0; i < lhs_shape.ndim(); i++) { + (*new_lhs_shape)[i] = lhs_shape[i]; + (*new_rhs_shape)[i] = rhs_shape[i]; + (*new_out_shape)[i] = out_shape[i]; } } - *ret_reshaped = TShape(dim.begin(), dim.end()); + CHECK(((*new_lhs_shape).Size() == lhs_shape.Size()) + && ((*new_rhs_shape).Size() == rhs_shape.Size()) + && ((*new_out_shape).Size() == out_shape.Size())) + << "new_lhs_shape:" << *new_lhs_shape << ",lhs_shape:" << lhs_shape + << "new_rhs_shape:" << *new_rhs_shape << ",rhs_shape:" << rhs_shape + << "new_out_shape:" << *new_out_shape << ",out_shape:" << out_shape; } - template void BinaryBroadcastForward_(const TBlob& lhs, const TBlob& rhs, @@ -171,94 +137,61 @@ void BinaryBroadcastForward_(const TBlob& lhs, TBlob *ret, OpReqType req, RunContext ctx) { + using namespace mshadow; using namespace mshadow::expr; - using mshadow::Shape; - using mshadow::Shape1; - using mshadow::Tensor; - mshadow::Stream *s = ctx.get_stream(); + Stream *s = ctx.get_stream(); CHECK_EQ(ret->type_flag_, lhs.type_flag_) << "Binary function only support input/output with the same type"; CHECK_EQ(ret->type_flag_, rhs.type_flag_) << "Binary function only support input/output with the same type"; - + CHECK_EQ(lhs.shape_.ndim(), rhs.shape_.ndim()) << "the ndim of lhs and rhs must be equal," + " shape of lhs=" << lhs.shape_ << " shape of rhs=" << rhs.shape_; if (!IsBroadcastNeeded_(lhs.shape_, rhs.shape_)) { // no broadcast MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { - Tensor out = ret->FlatTo2D(s); - ASSIGN_DISPATCH(out, req, - F(lhs.FlatTo2D(s), - rhs.FlatTo2D(s))); - }); + mshadow::Tensor out = ret->FlatTo2D(s); + ASSIGN_DISPATCH(out, req, + F(lhs.FlatTo2D(s), + rhs.FlatTo2D(s))); + }); return; } - - TShape ret_reshaped; - int lhs_broadcast_axis; - int rhs_broadcast_axis; - GetBroadcastShape_(lhs.shape_, rhs.shape_, &ret_reshaped, - &lhs_broadcast_axis, &rhs_broadcast_axis); + bool do_opt; + TShape lhs_new_shape_, rhs_new_shape_, out_new_shape_; + InferBroadcastNewShapes_(&do_opt, &lhs_new_shape_, &rhs_new_shape_, &out_new_shape_, + lhs.shape_, rhs.shape_, ret->shape_); MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { - if (lhs_broadcast_axis >= 0) { - // broadcast lhs - Tensor mlhs = - lhs.get_with_shape(Shape1(lhs.shape_.Size()), s); - if (rhs_broadcast_axis >= 0) { - // broadcast both - Tensor mrhs = - rhs.get_with_shape(Shape1(rhs.shape_.Size()), s); - - Shape<2> ret_mshape = ret_reshaped.get<2>(); - Tensor out = - ret->get_with_shape(ret_mshape, s); - if (lhs_broadcast_axis == 0) { - ASSIGN_DISPATCH(out, req, - F(broadcast<0>(mlhs, ret_mshape), - broadcast<1>(mrhs, ret_mshape))); - } else { - ASSIGN_DISPATCH(out, req, - F(broadcast<1>(mlhs, ret_mshape), - broadcast<0>(mrhs, ret_mshape))); - } - } else { - // only lhs - Shape<3> ret_mshape = ret_reshaped.get<3>(); - Tensor out = - ret->get_with_shape(ret_mshape, s); - Tensor mrhs = - rhs.get_with_shape(ret_mshape, s); - if (lhs.shape_.Size() == 1) { - ASSIGN_DISPATCH(out, req, - F(broadcast_scalar(mlhs, ret_mshape), mrhs)); - } else { - ASSIGN_DISPATCH(out, req, - F(broadcast<1>(mlhs, ret_mshape), mrhs)); - } - } - } else { - Tensor mrhs = - rhs.get_with_shape(mshadow::Shape1(rhs.shape_.Size()), s); - if (rhs_broadcast_axis >= 0) { - // only rhs - Shape<3> ret_mshape = ret_reshaped.get<3>(); - Tensor out = - ret->get_with_shape(ret_mshape, s); - Tensor mlhs = - 
lhs.get_with_shape(ret_mshape, s); - if (lhs.shape_.Size() == 1) { - ASSIGN_DISPATCH(out, req, - F(mlhs, broadcast_scalar(mrhs, ret_mshape))); - } else { - ASSIGN_DISPATCH(out, req, - F(mlhs, broadcast<1>(mrhs, ret_mshape))); - } - } else { - LOG(FATAL) << "no broadcast is needed"; - } + if (do_opt) { + Shape<3> lhs_new_shape, rhs_new_shape, out_new_shape; + for (index_t i = 0; i < 3; i++) { + lhs_new_shape[i] = lhs_new_shape_[i]; + rhs_new_shape[i] = rhs_new_shape_[i]; + out_new_shape[i] = out_new_shape_[i]; } - }); + Tensor out = ret->get_with_shape(out_new_shape, s); + Tensor mlhs = lhs.get_with_shape(lhs_new_shape, s); + Tensor mrhs = rhs.get_with_shape(rhs_new_shape, s); + ASSIGN_DISPATCH(out, req, + F(broadcast_to(mlhs, out_new_shape_), broadcast_to(mrhs, out_new_shape_))); + } else { + Shape lhs_new_shape, rhs_new_shape, out_new_shape; + for (index_t i = 0; i < MXNET_SPECIAL_MAX_NDIM; i++) { + lhs_new_shape[i] = lhs_new_shape_[i]; + rhs_new_shape[i] = rhs_new_shape_[i]; + out_new_shape[i] = out_new_shape_[i]; + } + Tensor out = + ret->get_with_shape(out_new_shape, s); + Tensor mlhs = + lhs.get_with_shape(lhs_new_shape, s); + Tensor mrhs = + rhs.get_with_shape(rhs_new_shape, s); + ASSIGN_DISPATCH(out, req, + F(broadcast_to(mlhs, out_new_shape_), broadcast_to(mrhs, out_new_shape_))); + } + }); } - template void BinaryBroadcastBackward_(const OutputGrad& out_grad, const EnvArguments& env, @@ -267,13 +200,16 @@ void BinaryBroadcastBackward_(const OutputGrad& out_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx) { + using namespace mshadow; using namespace mshadow::expr; - using mshadow::Shape; - using mshadow::Shape1; - using mshadow::Shape2; - using mshadow::Tensor; - mshadow::Stream *s = ctx.get_stream(); - + Stream *s = ctx.get_stream(); + CHECK_EQ(out_grad.data.type_flag_, lhs_grad->type_flag_) + << "Binary function only support ingrad/outgrad with the same type"; + CHECK_EQ(out_grad.data.type_flag_, rhs_grad->type_flag_) + << "Binary function only support ingrad/outgrad with the same type"; + CHECK_EQ(rhs_grad->shape_.ndim(), rhs_grad->shape_.ndim()) << + "the ndim of lhs_grad and rhs_grad must be equal," + " shape of lhs_grad=" << lhs_grad->shape_ << " shape of rhs_grad=" << rhs_grad->shape_; if (!IsBroadcastNeeded_(lhs_grad->shape_, rhs_grad->shape_)) { // no broadcast MSHADOW_TYPE_SWITCH(lhs_grad->type_flag_, DType, { @@ -285,63 +221,39 @@ void BinaryBroadcastBackward_(const OutputGrad& out_grad, }); return; } - - TShape ret_reshaped; - int lhs_broadcast_axis; - int rhs_broadcast_axis; - GetBroadcastShape_(lhs_grad->shape_, rhs_grad->shape_, &ret_reshaped, - &lhs_broadcast_axis, &rhs_broadcast_axis); - index_t lhs_size = lhs_grad->shape_.Size(); - index_t rhs_size = rhs_grad->shape_.Size(); - + bool do_opt; + TShape lhs_new_shape_, rhs_new_shape_, out_new_shape_; + InferBroadcastNewShapes_(&do_opt, &lhs_new_shape_, &rhs_new_shape_, &out_new_shape_, + lhs_grad->shape_, rhs_grad->shape_, out_grad.data.shape_); MSHADOW_REAL_TYPE_SWITCH(lhs_grad->type_flag_, DType, { - if (lhs_broadcast_axis >= 0) { - Tensor mlhs_grad = - lhs_grad->get_with_shape(Shape1(lhs_size), s); - if (rhs_broadcast_axis >= 0) { - // broadcast both - Tensor mout_grad = - out_grad.data.get_with_shape(ret_reshaped.get<2>(), s); - Tensor mrhs_grad = - rhs_grad->get_with_shape(Shape1(rhs_size), s); - if (lhs_broadcast_axis == 0) { - ASSIGN_DISPATCH( - mlhs_grad, req_lhs_grad, sumall_except_dim<0>(F(mout_grad))); - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<1>(F(mout_grad))); - 
} else { - ASSIGN_DISPATCH( - mlhs_grad, req_lhs_grad, sumall_except_dim<1>(F(mout_grad))); - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<0>(F(mout_grad))); - } - } else { - // only broadcast lhs - Tensor mout_grad = - out_grad.data.get_with_shape(ret_reshaped.get<3>(), s); - Tensor mrhs_grad = - rhs_grad->get_with_shape(ret_reshaped.get<3>(), s); - ASSIGN_DISPATCH( - mlhs_grad, req_lhs_grad, sumall_except_dim<1>(F(mout_grad))); - ASSIGN_DISPATCH(mrhs_grad, req_rhs_grad, F(mout_grad)); - } - } else { - if (rhs_broadcast_axis >= 0) { - // only broadcast rhs - Tensor mlhs_grad = - lhs_grad->get_with_shape(ret_reshaped.get<3>(), s); - Tensor mrhs_grad = - rhs_grad->get_with_shape(Shape1(rhs_size), s); - Tensor mout_grad = - out_grad.data.get_with_shape(ret_reshaped.get<3>(), s); - ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, F(mout_grad)); - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<1>(F(mout_grad))); - } else { - LOG(FATAL) << "no broadcast is needed"; - } + if (do_opt) { + Shape<3> out_new_shape; + for (index_t i = 0; i < 3; i++) { + out_new_shape[i] = out_new_shape_[i]; } - }); + Tensor mout_grad = + out_grad.data.get_with_shape(out_new_shape, s); + Tensor mlhs_grad = + lhs_grad->get_with_shape(Shape1(lhs_grad->Size()), s); + Tensor mrhs_grad = + rhs_grad->get_with_shape(Shape1(rhs_grad->Size()), s); + ReduceToAssign(mlhs_grad, req_lhs_grad, lhs_new_shape_, F(mout_grad)); + ReduceToAssign(mrhs_grad, req_rhs_grad, rhs_new_shape_, F(mout_grad)); + } else { + Shape out_new_shape; + for (index_t i = 0; i < MXNET_SPECIAL_MAX_NDIM; i++) { + out_new_shape[i] = out_new_shape_[i]; + } + Tensor mout_grad = + out_grad.data.get_with_shape(out_new_shape, s); + Tensor mlhs_grad = + lhs_grad->get_with_shape(Shape1(lhs_grad->Size()), s); + Tensor mrhs_grad = + rhs_grad->get_with_shape(Shape1(rhs_grad->Size()), s); + ReduceToAssign(mlhs_grad, req_lhs_grad, lhs_new_shape_, F(mout_grad)); + ReduceToAssign(mrhs_grad, req_rhs_grad, rhs_new_shape_, F(mout_grad)); + } + }); } template @@ -354,112 +266,71 @@ void BroadcastMulBackward_(const OutputGrad& out_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx) { + using namespace mshadow; using namespace mshadow::expr; - using mshadow::Shape; - using mshadow::Shape1; - using mshadow::Shape2; - using mshadow::Tensor; - mshadow::Stream *s = ctx.get_stream(); - + Stream *s = ctx.get_stream(); if (!IsBroadcastNeeded_(lhs_grad->shape_, rhs_grad->shape_)) { MSHADOW_TYPE_SWITCH(lhs_grad->type_flag_, DType, { - Tensor mout_grad = out_grad.data.FlatTo2D(s); - Tensor mlhs_data = lhs.data.FlatTo2D(s); - Tensor mrhs_data = rhs.data.FlatTo2D(s); - Tensor mlhs_grad = lhs_grad->FlatTo2D(s); - Tensor mrhs_grad = rhs_grad->FlatTo2D(s); - CHECK_NE(req_rhs_grad, kWriteInplace); - ASSIGN_DISPATCH(mrhs_grad, req_rhs_grad, mlhs_data * mout_grad); - ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, mrhs_data * mout_grad); - }); + mshadow::Tensor mout_grad = out_grad.data.FlatTo2D(s); + mshadow::Tensor mlhs_data = lhs.data.FlatTo2D(s); + mshadow::Tensor mrhs_data = rhs.data.FlatTo2D(s); + mshadow::Tensor mlhs_grad = lhs_grad->FlatTo2D(s); + mshadow::Tensor mrhs_grad = rhs_grad->FlatTo2D(s); + CHECK_NE(req_rhs_grad, kWriteInplace); + ASSIGN_DISPATCH(mrhs_grad, req_rhs_grad, mlhs_data * mout_grad); + ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, mrhs_data * mout_grad); + }); return; } - - TShape ret_reshaped; - int lhs_broadcast_axis; - int rhs_broadcast_axis; - GetBroadcastShape_(lhs_grad->shape_, rhs_grad->shape_, &ret_reshaped, - &lhs_broadcast_axis, 
&rhs_broadcast_axis); - index_t lhs_size = lhs_grad->shape_.Size(); - index_t rhs_size = rhs_grad->shape_.Size(); - + bool do_opt; + TShape lhs_new_shape_, rhs_new_shape_, out_new_shape_; + InferBroadcastNewShapes_(&do_opt, &lhs_new_shape_, &rhs_new_shape_, &out_new_shape_, + lhs_grad->shape_, rhs_grad->shape_, out_grad.data.shape_); MSHADOW_REAL_TYPE_SWITCH(lhs_grad->type_flag_, DType, { - if (lhs_broadcast_axis >= 0) { - Tensor mlhs_data = - lhs.data.get_with_shape(Shape1(lhs_size), s); - Tensor mlhs_grad = - lhs_grad->get_with_shape(Shape1(lhs_size), s); - - if (rhs_broadcast_axis >= 0) { - // broadcast both - Tensor mout_grad = - out_grad.data.get_with_shape(ret_reshaped.get<2>(), s); - Tensor mrhs_grad = - rhs_grad->get_with_shape(Shape1(rhs_size), s); - Tensor mrhs_data = - rhs.data.get_with_shape(Shape1(rhs_size), s); - if (lhs_broadcast_axis == 0) { - ASSIGN_DISPATCH( - mlhs_grad, req_lhs_grad, sumall_except_dim<0>( - mout_grad * broadcast<1>(mrhs_data, ret_reshaped.get<2>()))); - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<1>( - mout_grad * broadcast<0>(mlhs_data, ret_reshaped.get<2>()))); - } else { - ASSIGN_DISPATCH( - mlhs_grad, req_lhs_grad, sumall_except_dim<1>( - mout_grad * broadcast<0>(mrhs_data, ret_reshaped.get<2>()))); - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<0>( - mout_grad * broadcast<1>(mlhs_data, ret_reshaped.get<2>()))); - } - } else { - // only broadcast lhs - Tensor mout_grad = - out_grad.data.get_with_shape(ret_reshaped.get<3>(), s); - Tensor mrhs_grad = - rhs_grad->get_with_shape(ret_reshaped.get<3>(), s); - Tensor mrhs_data = - rhs.data.get_with_shape(ret_reshaped.get<3>(), s); - - ASSIGN_DISPATCH( - mlhs_grad, req_lhs_grad, sumall_except_dim<1>(mout_grad * mrhs_data)); - if (lhs_size == 1) { - ASSIGN_DISPATCH(mrhs_grad, req_rhs_grad, - mout_grad * broadcast_scalar(mlhs_data, ret_reshaped.get<3>())); - } else { - ASSIGN_DISPATCH(mrhs_grad, req_rhs_grad, - mout_grad * broadcast<1>(mlhs_data, ret_reshaped.get<3>())); - } - } - } else { - if (rhs_broadcast_axis >= 0) { - // only broadcast rhs - Tensor mlhs_grad = - lhs_grad->get_with_shape(ret_reshaped.get<3>(), s); - Tensor mlhs_data = - lhs.data.get_with_shape(ret_reshaped.get<3>(), s); - Tensor mrhs_grad = - rhs_grad->get_with_shape(Shape1(rhs_size), s); - Tensor mrhs_data = - rhs.data.get_with_shape(Shape1(rhs_size), s); - Tensor mout_grad = - out_grad.data.get_with_shape(ret_reshaped.get<3>(), s); - - if (rhs_size == 1) { - ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, - mout_grad * broadcast_scalar(mrhs_data, ret_reshaped.get<3>())); - } else { - ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, - mout_grad * broadcast<1>(mrhs_data, ret_reshaped.get<3>())); - } - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<1>(mout_grad * mlhs_data)); - } else { - LOG(FATAL) << "no broadcast is needed"; - } + if (do_opt) { + Shape<3> lhs_new_shape, rhs_new_shape, out_new_shape; + for (index_t i = 0; i < 3; i++) { + lhs_new_shape[i] = lhs_new_shape_[i]; + rhs_new_shape[i] = rhs_new_shape_[i]; + out_new_shape[i] = out_new_shape_[i]; } - }); + mshadow::Tensor mout_grad = + out_grad.data.get_with_shape(out_new_shape, s); + mshadow::Tensor mlhs_data = + lhs.data.get_with_shape(lhs_new_shape, s); + mshadow::Tensor mrhs_data = + rhs.data.get_with_shape(rhs_new_shape, s); + mshadow::Tensor mlhs_grad = + lhs_grad->get_with_shape(Shape1(lhs_grad->Size()), s); + mshadow::Tensor mrhs_grad = + rhs_grad->get_with_shape(Shape1(rhs_grad->Size()), s); + ReduceToAssign(mrhs_grad, req_rhs_grad, 
rhs_new_shape_, + broadcast_to(mlhs_data, out_new_shape_) * mout_grad); + ReduceToAssign(mlhs_grad, req_lhs_grad, lhs_new_shape_, + broadcast_to(mrhs_data, out_new_shape_) * mout_grad); + } else { + Shape lhs_new_shape, rhs_new_shape, out_new_shape; + for (index_t i = 0; i < MXNET_SPECIAL_MAX_NDIM; i++) { + lhs_new_shape[i] = lhs_new_shape_[i]; + rhs_new_shape[i] = rhs_new_shape_[i]; + out_new_shape[i] = out_new_shape_[i]; + } + mshadow::Tensor mout_grad = + out_grad.data.get_with_shape(out_new_shape, s); + mshadow::Tensor mlhs_data = + lhs.data.get_with_shape(lhs_new_shape, s); + mshadow::Tensor mrhs_data = + rhs.data.get_with_shape(rhs_new_shape, s); + mshadow::Tensor mlhs_grad = + lhs_grad->get_with_shape(Shape1(lhs_grad->Size()), s); + mshadow::Tensor mrhs_grad = + rhs_grad->get_with_shape(Shape1(rhs_grad->Size()), s); + ReduceToAssign(mrhs_grad, req_rhs_grad, rhs_new_shape_, + broadcast_to(mlhs_data, out_new_shape_) * mout_grad); + ReduceToAssign(mlhs_grad, req_lhs_grad, lhs_new_shape_, + broadcast_to(mrhs_data, out_new_shape_) * mout_grad); + } + }); } template @@ -472,122 +343,73 @@ void BroadcastDivBackward_(const OutputGrad& out_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx) { + using namespace mshadow; using namespace mshadow::expr; - using mshadow::Shape; - using mshadow::Shape1; - using mshadow::Shape2; - using mshadow::Tensor; - mshadow::Stream *s = ctx.get_stream(); - + Stream *s = ctx.get_stream(); if (!IsBroadcastNeeded_(lhs_grad->shape_, rhs_grad->shape_)) { MSHADOW_TYPE_SWITCH(lhs_grad->type_flag_, DType, { - Tensor mout_grad = out_grad.data.FlatTo2D(s); - Tensor mlhs_data = lhs.data.FlatTo2D(s); - Tensor mrhs_data = rhs.data.FlatTo2D(s); - Tensor mlhs_grad = lhs_grad->FlatTo2D(s); - Tensor mrhs_grad = rhs_grad->FlatTo2D(s); + mshadow::Tensor mout_grad = out_grad.data.FlatTo2D(s); + mshadow::Tensor mlhs_data = lhs.data.FlatTo2D(s); + mshadow::Tensor mrhs_data = rhs.data.FlatTo2D(s); + mshadow::Tensor mlhs_grad = lhs_grad->FlatTo2D(s); + mshadow::Tensor mrhs_grad = rhs_grad->FlatTo2D(s); CHECK_NE(req_rhs_grad, kWriteInplace); ASSIGN_DISPATCH(mrhs_grad, req_rhs_grad, - F(mout_grad * mlhs_data)/ - F(mrhs_data)); - ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, mout_grad / mrhs_data); }); + F(mout_grad * mlhs_data) / + F(mrhs_data)); + ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, mout_grad / mrhs_data); + }); return; } - - TShape ret_reshaped; - int lhs_broadcast_axis; - int rhs_broadcast_axis; - GetBroadcastShape_(lhs_grad->shape_, rhs_grad->shape_, &ret_reshaped, - &lhs_broadcast_axis, &rhs_broadcast_axis); - index_t lhs_size = lhs_grad->shape_.Size(); - index_t rhs_size = rhs_grad->shape_.Size(); - + bool do_opt; + TShape lhs_new_shape_, rhs_new_shape_, out_new_shape_; + InferBroadcastNewShapes_(&do_opt, &lhs_new_shape_, &rhs_new_shape_, &out_new_shape_, + lhs_grad->shape_, rhs_grad->shape_, out_grad.data.shape_); MSHADOW_REAL_TYPE_SWITCH(lhs_grad->type_flag_, DType, { - if (lhs_broadcast_axis >= 0) { - Tensor mlhs_data = - lhs.data.get_with_shape(Shape1(lhs_size), s); - Tensor mlhs_grad = - lhs_grad->get_with_shape(Shape1(lhs_size), s); - - if (rhs_broadcast_axis >= 0) { - // broadcast both - Shape<2> rshape = ret_reshaped.get<2>(); - Tensor mout_grad = - out_grad.data.get_with_shape(rshape, s); - Tensor mrhs_grad = - rhs_grad->get_with_shape(Shape1(rhs_size), s); - Tensor mrhs_data = - rhs.data.get_with_shape(Shape1(rhs_size), s); - if (lhs_broadcast_axis == 0) { - ASSIGN_DISPATCH( - mlhs_grad, req_lhs_grad, sumall_except_dim<0>( - mout_grad / 
broadcast<1>(mrhs_data, rshape))); - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<1>( - F(mout_grad * broadcast<0>(mlhs_data, rshape)) / - F(broadcast<1>(mrhs_data, rshape)))); - } else { - ASSIGN_DISPATCH( - mlhs_grad, req_lhs_grad, sumall_except_dim<1>( - mout_grad / broadcast<0>(mrhs_data, rshape))); - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<0>( - F(mout_grad * broadcast<1>(mlhs_data, rshape)) / - F(broadcast<0>(mrhs_data, rshape)))); - } - } else { - // only broadcast lhs - Shape<3> rshape = ret_reshaped.get<3>(); - Tensor mout_grad = - out_grad.data.get_with_shape(rshape, s); - Tensor mrhs_grad = - rhs_grad->get_with_shape(rshape, s); - Tensor mrhs_data = - rhs.data.get_with_shape(rshape, s); - - ASSIGN_DISPATCH( - mlhs_grad, req_lhs_grad, sumall_except_dim<1>(mout_grad / mrhs_data)); - if (lhs_size == 1) { - ASSIGN_DISPATCH(mrhs_grad, req_rhs_grad, - F(mout_grad * broadcast_scalar(mlhs_data, rshape)) / - F(mrhs_data)); - } else { - ASSIGN_DISPATCH(mrhs_grad, req_rhs_grad, - F(mout_grad * broadcast<1>(mlhs_data, rshape)) / - F(mrhs_data)); - } + if (do_opt) { + Shape<3> lhs_new_shape, rhs_new_shape, out_new_shape; + for (index_t i = 0; i < 3; i++) { + lhs_new_shape[i] = lhs_new_shape_[i]; + rhs_new_shape[i] = rhs_new_shape_[i]; + out_new_shape[i] = out_new_shape_[i]; } + mshadow::Tensor mout_grad = + out_grad.data.get_with_shape(out_new_shape, s); + mshadow::Tensor mlhs_data = + lhs.data.get_with_shape(lhs_new_shape, s); + mshadow::Tensor mrhs_data = + rhs.data.get_with_shape(rhs_new_shape, s); + mshadow::Tensor mlhs_grad = + lhs_grad->get_with_shape(Shape1(lhs_grad->Size()), s); + mshadow::Tensor mrhs_grad = + rhs_grad->get_with_shape(Shape1(rhs_grad->Size()), s); + ReduceToAssign(mrhs_grad, req_rhs_grad, rhs_new_shape_, + F(mout_grad * broadcast_to(mlhs_data, out_new_shape_)) / + F(broadcast_to(mrhs_data, out_new_shape_))); + ReduceToAssign(mlhs_grad, req_lhs_grad, lhs_new_shape_, mout_grad / + broadcast_to(mrhs_data, out_new_shape_)); } else { - if (rhs_broadcast_axis >= 0) { - // only broadcast rhs - Shape<3> rshape = ret_reshaped.get<3>(); - Tensor mlhs_grad = lhs_grad->get_with_shape(rshape, s); - Tensor mlhs_data = lhs.data.get_with_shape(rshape, s); - Tensor mrhs_grad = - rhs_grad->get_with_shape(Shape1(rhs_size), s); - Tensor mrhs_data = - rhs.data.get_with_shape(Shape1(rhs_size), s); - Tensor mout_grad = - out_grad.data.get_with_shape(rshape, s); - - if (rhs_size == 1) { - ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, - mout_grad / broadcast_scalar(mrhs_data, rshape)); - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<1>( - F(mout_grad * mlhs_data) / - F(broadcast_scalar(mrhs_data, rshape)))); - } else { - ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, - mout_grad / broadcast<1>(mrhs_data, rshape)); - ASSIGN_DISPATCH( - mrhs_grad, req_rhs_grad, sumall_except_dim<1>( - F(mout_grad * mlhs_data) / - F(broadcast<1>(mrhs_data, rshape)))); - } - } else { - LOG(FATAL) << "no broadcast is needed"; + Shape lhs_new_shape, rhs_new_shape, out_new_shape; + for (index_t i = 0; i < MXNET_SPECIAL_MAX_NDIM; i++) { + lhs_new_shape[i] = lhs_new_shape_[i]; + rhs_new_shape[i] = rhs_new_shape_[i]; + out_new_shape[i] = out_new_shape_[i]; } + mshadow::Tensor mout_grad = + out_grad.data.get_with_shape(out_new_shape, s); + mshadow::Tensor mlhs_data = + lhs.data.get_with_shape(lhs_new_shape, s); + mshadow::Tensor mrhs_data = + rhs.data.get_with_shape(rhs_new_shape, s); + mshadow::Tensor mlhs_grad = + lhs_grad->get_with_shape(Shape1(lhs_grad->Size()), s); + 
mshadow::Tensor mrhs_grad = + rhs_grad->get_with_shape(Shape1(rhs_grad->Size()), s); + ReduceToAssign(mrhs_grad, req_rhs_grad, rhs_new_shape_, + F(mout_grad * broadcast_to(mlhs_data, out_new_shape_)) / + F(broadcast_to(mrhs_data, out_new_shape_))); + ReduceToAssign(mlhs_grad, req_lhs_grad, lhs_new_shape_, mout_grad / + broadcast_to(mrhs_data, out_new_shape_)); } }); } diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 8a8049590b34..b0273f288091 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -205,7 +205,7 @@ def test_dot(): assert reldiff(c, C.asnumpy()) < 1e-5 def test_reduce(): - sample_num = 1000 + sample_num = 200 def test_reduce_inner(numpy_reduce_func, nd_reduce_func): for i in range(sample_num): ndim = np.random.randint(1, 8) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 4638f9c905d1..2e34869d92ba 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -733,24 +733,27 @@ def test_convolution_grouping(): np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3) def _gen_broadcast_data(): - testing_shapes = [(2, 3, 4), (3, 5, 7), (4, 2, 6)] - shape_pairs = [] - for n, m, k in testing_shapes: - shape_pairs += [((1,), (1,)), - ((n,), (n,)), - ((n,m), (n,m)), - ((n,m,k), (n,m,k)), - ((n,1), (1,n)), - ((n,m,k), (n,1,1)), - ((n,m,k), (1,m,1)), - ((n,m,k), (1,m,k)), - ((n,m,k), (n,m,1)), - ((n,m,k), (1,1,k))] - shape_pairs += [(v, u) for (u, v) in shape_pairs] - return [(np.random.random(u), np.random.random(v)) for (u,v) in shape_pairs] + # Generate random data that has ndim between 1-7 and all the shape dims between 1-10 + ndim = np.random.randint(1, 8) + shape = np.random.randint(1, 11, size=(ndim,)) + l_same_dim = np.random.randint(0, 5) + r_same_dim = np.random.randint(0, 5) + l_axis_flags = np.random.randint(0, 2, size=ndim) + r_axis_flags = np.random.randint(0, 2, size=ndim) + if l_same_dim == 4: + l_axis_flags = np.ones(ndim) + if r_same_dim == 4: + r_axis_flags = np.ones(ndim) + l_shape = shape.copy() + r_shape = shape.copy() + l_shape[np.where(l_axis_flags == 0)] = 1 + r_shape[np.where(r_axis_flags == 0)] = 1 + return [np.random.random(l_shape), np.random.random(r_shape)] def _check_broadcast_op_forward(symbol, baseline): - for d in _gen_broadcast_data(): + sample_num = 200 + for i in range(sample_num): + d = _gen_broadcast_data() x = baseline(d[0], d[1]) y = symbol.bind(mx.cpu(), args={'a': mx.nd.array(d[0]), 'b' : mx.nd.array(d[1])}) y.forward() @@ -759,8 +762,10 @@ def _check_broadcast_op_forward(symbol, baseline): err, d[0].shape, d[1].shape) def _check_broadcast_op_backward(symbol, baseline): - for d in _gen_broadcast_data(): - out = d[0] + d[1] + sample_num = 200 + for i in range(sample_num): + d = _gen_broadcast_data() + out = np.random.random((d[0] + d[1]).shape) def reduce_op(shape, x): if shape == x.shape: return x @@ -782,7 +787,7 @@ def reduce_op(shape, x): err = lambda x, y: np.sum(np.abs(x-y)) / np.sum(np.abs(x)) err_1 = err(x_1, y_1.asnumpy()) err_2 = err(x_2, y_2.asnumpy()) - assert err_1 < 1e-6 and err_2 < 1e-6, 'lhs error %f, rhs error %f, shapes are %s %s' % ( + assert err_1 < 1e-5 and err_2 < 1e-5, 'lhs error %f, rhs error %f, shapes are %s %s' % ( err_1, err_2, d[0].shape, d[1].shape) def test_broadcast_binary_op(): @@ -927,7 +932,7 @@ def test_reshape_new(src_shape, shape_args, dst_shape): assert(output_shape[0] == (2, 75)) def test_reduce(): - sample_num = 
1000 + sample_num = 200 def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym): for i in range(sample_num): # Generate random data that has ndim between 1-7 and all the shape dims between 1-10 @@ -969,7 +974,7 @@ def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym): mx.symbol.sum) def test_broadcast(): - sample_num = 1000 + sample_num = 200 def test_broadcast_axis(): for i in range(sample_num): # Generate random data that has ndim between 1-7 and all the shape dims between 1-10
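
Note (not part of the patch): the snippet below is a minimal sketch of how the extended broadcasting support (ndim up to 7, arbitrary non-contiguous broadcast axes) can be exercised from Python, modeled on `_gen_broadcast_data` and `_check_broadcast_op_forward` in tests/python/unittest/test_operator.py above. The operator name `broadcast_plus` and the `gen_broadcast_pair` helper are illustrative assumptions, not part of this change; NumPy serves as the baseline because it implements the same broadcasting rule.

    # Sketch: random lhs/rhs with matching ndim in [1, 7]; each dim is either
    # the shared size or 1, mirroring _gen_broadcast_data in the patch.
    import numpy as np
    import mxnet as mx

    def gen_broadcast_pair():
        ndim = np.random.randint(1, 8)                    # ndim between 1 and 7
        shape = np.random.randint(1, 11, size=(ndim,))    # dims between 1 and 10
        l_shape, r_shape = shape.copy(), shape.copy()
        l_shape[np.random.randint(0, 2, size=ndim) == 0] = 1
        r_shape[np.random.randint(0, 2, size=ndim) == 0] = 1
        return np.random.random(l_shape), np.random.random(r_shape)

    a = mx.symbol.Variable('a')
    b = mx.symbol.Variable('b')
    c = mx.symbol.broadcast_plus(a, b)   # assumed operator name for illustration

    lhs, rhs = gen_broadcast_pair()
    exe = c.bind(mx.cpu(), args={'a': mx.nd.array(lhs), 'b': mx.nd.array(rhs)})
    exe.forward()
    out = exe.outputs[0].asnumpy()
    # NumPy broadcasting is the reference; shapes agree because every
    # mismatched dimension on either side has size 1.
    np.testing.assert_allclose(out, lhs + rhs, rtol=1e-5)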