Support ndim up to 7 for binary broadcasting operators + Accelerate reducing OPs by calling reduce_except_dim if possible. + Add `/bigobj` to CMakeList (apache#2418)

Reshape the lhs and rhs to ndim=3 if possible; otherwise reshape them into ndim=7.
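
The reshape strategy described above can be illustrated with a small self-contained sketch. This is not the MXNet implementation; the helper name CollapseBroadcastShapes and the exact merging rule are assumptions made for illustration. The idea is that adjacent axes sharing the same broadcast pattern can be fused, and the fused rank decides between the cheap ndim=3 kernel and the ndim=7 fallback capped by MXNET_SPECIAL_MAX_NDIM:

    // Standalone sketch (not the actual MXNet code) of the idea in this commit:
    // collapse adjacent axes of two broadcast-compatible shapes whenever they
    // share the same broadcast pattern. If the collapsed rank fits in 3 we can
    // use a cheap 3-D kernel; otherwise we pad up to MXNET_SPECIAL_MAX_NDIM = 7.
    #include <cassert>
    #include <cstdio>
    #include <vector>

    constexpr int kSpecialMaxNdim = 7;  // mirrors MXNET_SPECIAL_MAX_NDIM

    // Returns collapsed lhs/rhs shapes with the same (reduced) rank.
    void CollapseBroadcastShapes(std::vector<int> lhs, std::vector<int> rhs,
                                 std::vector<int>* out_lhs, std::vector<int>* out_rhs) {
      assert(lhs.size() == rhs.size());  // assume shapes already padded to equal rank
      out_lhs->clear();
      out_rhs->clear();
      for (size_t i = 0; i < lhs.size(); ++i) {
        assert(lhs[i] == rhs[i] || lhs[i] == 1 || rhs[i] == 1);  // broadcast-compatible
        // Pattern of this axis: 0 = equal, 1 = lhs broadcasts, 2 = rhs broadcasts.
        int pat = (lhs[i] == rhs[i]) ? 0 : (lhs[i] == 1 ? 1 : 2);
        int prev_pat = -1;
        if (!out_lhs->empty()) {
          int pl = out_lhs->back(), pr = out_rhs->back();
          prev_pat = (pl == pr) ? 0 : (pl == 1 ? 1 : 2);
        }
        if (pat == prev_pat) {            // same pattern: merge into previous axis
          out_lhs->back() *= lhs[i];
          out_rhs->back() *= rhs[i];
        } else {                          // pattern changes: start a new axis
          out_lhs->push_back(lhs[i]);
          out_rhs->push_back(rhs[i]);
        }
      }
    }

    int main() {
      std::vector<int> cl, cr;
      // (2, 3, 4, 5) op (2, 3, 1, 1): collapses to (6, 20) vs (6, 1) -> ndim=3 path.
      CollapseBroadcastShapes({2, 3, 4, 5}, {2, 3, 1, 1}, &cl, &cr);
      std::printf("collapsed rank = %zu -> %s\n", cl.size(),
                  cl.size() <= 3 ? "use ndim=3 kernel" : "pad to ndim=7 kernel");
      assert(static_cast<int>(cl.size()) <= kSpecialMaxNdim);
      return 0;
    }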
sxjscience authored and tqchen committed Jun 17, 2016
1 parent 887491d commit 4b88c19
Showing 7 changed files with 532 additions and 524 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -29,7 +29,7 @@ if(MSVC)
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-DMXNET_EXPORTS)
set(CMAKE_C_FLAGS "/MP")
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} /bigobj")
else(MSVC)
include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11)
46 changes: 6 additions & 40 deletions include/mxnet/operator_util.h
@@ -11,6 +11,10 @@
#ifndef MXNET_OPERATOR_UTIL_H_
#define MXNET_OPERATOR_UTIL_H_

#ifdef _MSC_VER
#pragma warning(disable:4503) // disable warning: decorated name length exceeded.
#endif

#include <dmlc/registry.h>
#include <dmlc/parameter.h>
#include <map>
@@ -412,47 +416,9 @@ class SimpleOpRegistry {
}

/*!
* \brief cast dynamic range variable into static variable
* \param var the source value, constrained to be between 1 and 5
* \param NDIM the const NDIM that can be used in the template
* \brief Maximum ndim supported for special operators like broadcasting with non contiguous lhs/rhs
*/
#define MXNET_RANGE_SWITCH(var, NDIM, ...) \
{ \
switch (var) { \
case 1: \
{ \
static const int NDIM = 1; \
{__VA_ARGS__} \
} \
break; \
case 2: \
{ \
static const int NDIM = 2; \
{__VA_ARGS__} \
} \
break; \
case 3: \
{ \
static const int NDIM = 3; \
{__VA_ARGS__} \
} \
break; \
case 4: \
{ \
static const int NDIM = 4; \
{__VA_ARGS__} \
} \
break; \
case 5: \
{ \
static const int NDIM = 5; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Only support ndim=1 to 5."; \
} \
}
#define MXNET_SPECIAL_MAX_NDIM 7


//--------------------------------------------------------------
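For context on the change above: a macro like the removed MXNET_RANGE_SWITCH turns a runtime ndim into a compile-time template argument, and the new MXNET_SPECIAL_MAX_NDIM caps that range at 7. A minimal sketch of the dispatch pattern, assuming a hypothetical PrintRank kernel in place of code templated on mshadow::Tensor<xpu, NDIM>:

    // Minimal sketch of runtime-to-compile-time ndim dispatch, in the spirit of
    // the removed MXNET_RANGE_SWITCH. PrintRank is a hypothetical stand-in for a
    // kernel that needs ndim as a template parameter.
    #include <cstdio>
    #include <stdexcept>

    template <int NDIM>
    void PrintRank() { std::printf("instantiated kernel for ndim=%d\n", NDIM); }

    void DispatchByNdim(int ndim) {
      switch (ndim) {
        case 1: PrintRank<1>(); break;
        case 2: PrintRank<2>(); break;
        case 3: PrintRank<3>(); break;
        case 4: PrintRank<4>(); break;
        case 5: PrintRank<5>(); break;
        case 6: PrintRank<6>(); break;
        case 7: PrintRank<7>(); break;  // MXNET_SPECIAL_MAX_NDIM caps the range at 7
        default: throw std::runtime_error("ndim > 7 not supported");
      }
    }

    int main() { DispatchByNdim(3); return 0; }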
116 changes: 83 additions & 33 deletions src/operator/broadcast_reduce_op-inl.h
@@ -103,11 +103,14 @@ void L2Norm(const TBlob &src,
OpReqType req,
RunContext ctx) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
mshadow::Tensor<xpu, 1> out = ret->get<xpu, 1, real_t>(s);
mshadow::Tensor<xpu, 1> in =
src.get_with_shape<xpu, 1, real_t>(mshadow::Shape1(src.shape_.Size()), s);
mshadow::VectorDot(out, in, in);
out = mshadow::expr::F<mxnet::op::mshadow_op::square_root>(out);
CHECK_EQ(src.type_flag_, ret->type_flag_);
MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, {
mshadow::Tensor<xpu, 1, DType> out = ret->get<xpu, 1, DType>(s);
mshadow::Tensor<xpu, 1, DType> in =
src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(src.shape_.Size()), s);
mshadow::VectorDot(out, in, in);
ASSIGN_DISPATCH(out, req, mshadow::expr::F<mxnet::op::mshadow_op::square_root>(out));
});
}

template<typename xpu, typename Reducer>
@@ -117,10 +120,13 @@ void Reduce(const TBlob &src,
OpReqType req,
RunContext ctx) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
mshadow::Tensor<xpu, 1> out = ret->get<xpu, 1, real_t>(s);
mshadow::Tensor<xpu, 2> in =
src.get_with_shape<xpu, 2, real_t>(mshadow::Shape2(1, src.shape_.Size()), s);
out = mshadow::expr::reduce_except_dim<0, Reducer>(in);
CHECK_EQ(src.type_flag_, ret->type_flag_);
MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, {
mshadow::Tensor<xpu, 1, DType> out = ret->get<xpu, 1, DType>(s);
mshadow::Tensor<xpu, 2, DType> in =
src.get_with_shape<xpu, 2, DType>(mshadow::Shape2(1, src.shape_.Size()), s);
ASSIGN_DISPATCH(out, req, (mshadow::expr::reduce_except_dim<0, Reducer>(in)));
});
}

// backward function that takes input value of the op
@@ -135,15 +141,15 @@ void SumBackward_(const OutputGrad& scale,
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(in_grad->type_flag_, scale.data.type_flag_)
<< "Unary function only support input/output with the same type";
MSHADOW_TYPE_SWITCH(in_grad->type_flag_, DType, {
MSHADOW_REAL_TYPE_SWITCH(in_grad->type_flag_, DType, {
mshadow::Tensor<xpu, 1, DType> mscale = scale.data.get<xpu, 1, DType>(s);
mshadow::Tensor<xpu, 2, DType> igrad = in_grad->FlatTo2D<xpu, DType>(s);
ASSIGN_DISPATCH(igrad, req,
broadcast_scalar(mscale, igrad.shape_));
});
}

template<typename xpu, typename Reducer, bool get_mask>
template<typename xpu, typename Reducer>
void ReduceChannel(const TBlob &src,
const EnvArguments& env,
TBlob *ret,
@@ -153,13 +159,17 @@ void ReduceChannel(const TBlob &src,
using namespace mshadow;
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 2> out = ret->get_with_shape<xpu, 2, real_t>(
Shape2(src.shape_[0], src.Size()/src.shape_[0]/src.shape_[1]),
s);
Tensor<xpu, 3> in = src.get_with_shape<xpu, 3, real_t>(
Shape3(src.shape_[0], src.shape_[1], src.Size()/src.shape_[0]/src.shape_[1]),
CHECK_EQ(src.type_flag_, ret->type_flag_);
MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, {
Tensor<xpu, 2, DType> out = ret->get_with_shape<xpu, 2, DType>(
Shape2(src.shape_[0], src.Size() / src.shape_[0] / src.shape_[1]),
s);
out = reduce_with_axis<Reducer, get_mask>(in, 1);
Tensor<xpu, 3, DType> in = src.get_with_shape<xpu, 3, DType>(
Shape3(src.shape_[0], src.shape_[1], src.Size() / src.shape_[0] / src.shape_[1]),
s);
CHECK(req != kAddTo) << "AddTo is not supported";
ASSIGN_DISPATCH(out, req, (reduce_with_axis<Reducer, true>(in, 1)));
});
}
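
The ReduceChannel change above feeds argmax_channel (registered further down), which picks the channel index of the maximum value at every (batch, spatial) position. A tiny plain C++ illustration with hypothetical data, not MXNet code:

    // Tiny illustration of what argmax_channel computes: for a
    // (batch, channel, spatial) tensor, pick the channel index of the maximum
    // value at every (batch, spatial) position, matching
    // reduce_with_axis<maximum, true>(in, 1).
    #include <cstdio>

    int main() {
      const int B = 1, C = 3, S = 2;
      // data[b][c][s], laid out like the Shape3(batch, channel, spatial) view above
      float data[B][C][S] = {{{0.1f, 0.9f},
                              {0.7f, 0.2f},
                              {0.3f, 0.4f}}};
      for (int b = 0; b < B; ++b) {
        for (int s = 0; s < S; ++s) {
          int best = 0;
          for (int c = 1; c < C; ++c) {
            if (data[b][c][s] > data[b][best][s]) best = c;
          }
          std::printf("argmax_channel[b=%d, s=%d] = %d\n", b, s, best);  // -> 1 and 0
        }
      }
      return 0;
    }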

// return a shape of ReduceChannel output
@@ -184,13 +194,16 @@ void ReduceAxisImpl_(const TBlob &src,
bool keepdims) {
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(src.type_flag_, ret->type_flag_);
if (-1 == axis) {
// Reduce all dimensions if axis == -1
mshadow::Tensor<xpu, 2> in =
src.get_with_shape<xpu, 2, real_t>(mshadow::Shape2(1, src.shape_.Size()), s);
mshadow::Tensor<xpu, 1> out =
ret->get_with_shape<xpu, 1, real_t>(mshadow::Shape1(ret->shape_.Size()), s);
out = mshadow::expr::reduce_except_dim<0, Reducer>(in);
MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, {
mshadow::Tensor<xpu, 2, DType> in =
src.get_with_shape<xpu, 2, DType>(mshadow::Shape2(src.shape_.Size(), 1), s);
mshadow::Tensor<xpu, 1, DType> out =
ret->get_with_shape<xpu, 1, DType>(mshadow::Shape1(ret->shape_.Size()), s);
ASSIGN_DISPATCH(out, req, (reduce_except_dim<1, Reducer>(in)));
});
return;
}
int trailing = 1;
@@ -202,11 +215,46 @@ void ReduceAxisImpl_(const TBlob &src,
trailing *= src.shape_[i];
}
}
mshadow::Tensor<xpu, 3> in =
src.get_with_shape<xpu, 3, real_t>(mshadow::Shape3(leading, src.shape_[axis], trailing), s);
mshadow::Tensor<xpu, 2> out =
ret->get_with_shape<xpu, 2, real_t>(mshadow::Shape2(leading, trailing), s);
out = mshadow::expr::reduce_with_axis<Reducer, get_mask>(in, 1);
if (get_mask) {
// If get_mask is on, we have to use the slower `reduce_with_axis`
// since reduce_except_dim does not support the flag.
MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, {
mshadow::Tensor<xpu, 3, DType> in =
src.get_with_shape<xpu, 3, DType>(mshadow::Shape3(leading, src.shape_[axis], trailing), s);
mshadow::Tensor<xpu, 2, DType> out =
ret->get_with_shape<xpu, 2, DType>(mshadow::Shape2(leading, trailing), s);
CHECK(req != kAddTo) << "AddTo is not supported for `get_mask = true`";
ASSIGN_DISPATCH(out, req, (reduce_with_axis<Reducer, true>(in, 1)));
});
return;
}
if (1 == leading) {
MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, {
mshadow::Tensor<xpu, 2, DType> in =
src.get_with_shape<xpu, 2, DType>(mshadow::Shape2(src.shape_[axis], trailing), s);
mshadow::Tensor<xpu, 1, DType> out =
ret->get_with_shape<xpu, 1, DType>(mshadow::Shape1(trailing), s);
ASSIGN_DISPATCH(out, req, (reduce_except_dim<1, Reducer>(in)));
});
} else if (1 == trailing) {
MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, {
mshadow::Tensor<xpu, 2, DType> in =
src.get_with_shape<xpu, 2, DType>(mshadow::Shape2(leading, src.shape_[axis]), s);
mshadow::Tensor<xpu, 1, DType> out =
ret->get_with_shape<xpu, 1, DType>(mshadow::Shape1(leading), s);
ASSIGN_DISPATCH(out, req, (reduce_except_dim<1, Reducer>(in.T())));
});
} else {
MSHADOW_REAL_TYPE_SWITCH(src.type_flag_, DType, {
mshadow::Tensor<xpu, 3, DType> in =
src.get_with_shape<xpu, 3, DType>(mshadow::Shape3(leading, src.shape_[axis], trailing), s);
mshadow::Tensor<xpu, 1, DType> out =
ret->get_with_shape<xpu, 1, DType>(mshadow::Shape1(leading * trailing), s);
ASSIGN_DISPATCH(out, req,
(reduce_except_dim<1, Reducer>(reshape(swapaxis<1, 0>(in),
mshadow::Shape2(src.shape_[axis], leading * trailing)))));
});
}
}
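
To make the leading/trailing factorization above concrete: for a hypothetical source of shape (2, 3, 4, 5) reduced over axis=2, leading = 2*3 = 6 and trailing = 5, so the source is viewed as (6, 4, 5) and reduced to (6, 5). A plain C++ sketch of that arithmetic (not MXNet code):

    // Worked example of the leading/trailing factorization used by ReduceAxisImpl_.
    // Shape and axis are hypothetical; this just mirrors the index arithmetic.
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> shape = {2, 3, 4, 5};
      int axis = 2;
      int leading = 1, trailing = 1;
      for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
        if (i < axis) leading *= shape[i];        // axes before the reduced one
        else if (i > axis) trailing *= shape[i];  // axes after the reduced one
      }
      // The source is viewed as (leading, shape[axis], trailing) = (6, 4, 5) and
      // the result as (leading, trailing) = (6, 5); reduce_except_dim then
      // reduces over the middle (length-4) axis.
      std::printf("view src as (%d, %d, %d), reduce to (%d, %d)\n",
                  leading, shape[axis], trailing, leading, trailing);
      return 0;
    }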

// Broadcast the given axis to the given broadcasting size
@@ -240,11 +288,13 @@ void BroadcastAxisImpl_(const TBlob &src,
trailing *= ret->shape_[i];
}
}
mshadow::Tensor<xpu, 2> in =
src.get_with_shape<xpu, 2, real_t>(mshadow::Shape2(leading, trailing), s);
mshadow::Tensor<xpu, 3> out =
ret->get_with_shape<xpu, 3, real_t>(mshadow::Shape3(leading, bsize, trailing), s);
out = mshadow::expr::broadcast_with_axis(in, 0, bsize);
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
mshadow::Tensor<xpu, 2, DType> in =
src.get_with_shape<xpu, 2, DType>(mshadow::Shape2(leading, trailing), s);
mshadow::Tensor<xpu, 3, DType> out =
ret->get_with_shape<xpu, 3, DType>(mshadow::Shape3(leading, bsize, trailing), s);
ASSIGN_DISPATCH(out, req, broadcast_with_axis(in, 0, bsize));
});
}

// Forward pass of reduce over the given axis
@@ -386,7 +436,7 @@ MXNET_REGISTER_SIMPLE_OP(sum_axis, XPU)

// argmax channel
MXNET_REGISTER_SIMPLE_OP(argmax_channel, XPU)
.set_function(XPU::kDevMask, ReduceChannel<XPU, mshadow::red::maximum, true>,
.set_function(XPU::kDevMask, ReduceChannel<XPU, mshadow::red::maximum>,
kNoInplace, kNotRegisterSymbolic)
.set_shape_function(ReduceChannelShape)
.describe("Take argmax indices of each channel of the src."