[OP] Support softmax with probability label (apache#2456)
tqchen authored Jun 17, 2016
1 parent e721aa5 commit 8d57211
Showing 6 changed files with 43 additions and 22 deletions.
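This patch lets SoftmaxOutput accept a label with the same shape as its input data. Such a label is treated as a per-row probability distribution (a soft label), and the backward pass computes grad = (softmax(data) - label) * grad_scale, the standard cross-entropy gradient. A minimal NumPy sketch of that math (illustration only, not part of the commit; np_softmax mirrors the helper the updated unit test adds):

import numpy as np

def np_softmax(x):
    # row-wise softmax, subtracting the row max for numerical stability
    x = x - np.max(x, axis=1, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=1, keepdims=True)

rng = np.random.RandomState(0)
data = rng.uniform(-1, 1, (3, 4))
label = np_softmax(rng.uniform(-1, 1, (3, 4)))  # soft label: same shape as data, rows sum to 1

out = np_softmax(data)               # forward output of SoftmaxOutput
grad_scale = 1.0                     # the operator's grad_scale parameter
grad = (out - label) * grad_scale    # backward gradient w.r.t. data in the new code path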
1 change: 1 addition & 0 deletions .gitignore
@@ -108,3 +108,4 @@ scala-package/*/*/target/
*.project
*.settings
!scala-package/*/bin
*.bak
4 changes: 2 additions & 2 deletions src/operator/broadcast_reduce_op_common.h
@@ -18,7 +18,7 @@ namespace op {
* \brief Check if the axes are continuous + get reducing size. E.g (1, 3) -> false, (1,2,3) -> true
* \param is_contiguous_axes whether the axes is contiguous
* \param reducing_size product of source shape in the given axes
* \param axes
* \param axes
* \param src_shape shape of the source tensor
*/
inline void CheckContiguousAxes_(bool *is_contiguous_axes, index_t *reducing_size,
@@ -45,7 +45,7 @@ inline TShape GetBroadcastingAxes_(const mshadow::TShape &src_shape,
const mshadow::TShape &target_shape) {
std::vector<index_t> axes_vec;
CHECK_EQ(target_shape.ndim(), src_shape.ndim());
for (int i = 0; i < src_shape.ndim(); ++i) {
for (index_t i = 0; i < src_shape.ndim(); ++i) {
if (src_shape[i] != target_shape[i]) {
CHECK_EQ(src_shape[i], 1) << "broadcastsing axis must have size 1, received src_shape="
<< src_shape << " target_shape=" << target_shape;
2 changes: 1 addition & 1 deletion src/operator/elementwise_binary_broadcast_op-inl.h
@@ -116,7 +116,7 @@ inline void InferBroadcastNewShapes_(bool *do_opt,
*new_lhs_shape = TShape(MXNET_SPECIAL_MAX_NDIM);
*new_rhs_shape = TShape(MXNET_SPECIAL_MAX_NDIM);
*new_out_shape = TShape(MXNET_SPECIAL_MAX_NDIM);
for (int i = 0; i < lhs_shape.ndim(); i++) {
for (index_t i = 0; i < lhs_shape.ndim(); i++) {
(*new_lhs_shape)[i] = lhs_shape[i];
(*new_rhs_shape)[i] = rhs_shape[i];
(*new_out_shape)[i] = out_shape[i];
29 changes: 20 additions & 9 deletions src/operator/softmax_output-inl.h
@@ -103,7 +103,14 @@ class SoftmaxOutputOp : public Operator {
CHECK_GE(req.size(), 1);
Stream<xpu> *s = ctx.get_stream<xpu>();

if (param_.multi_output) {
if (out_data[softmaxout_enum::kOut].shape_ ==
in_data[softmaxout_enum::kLabel].shape_) {
// use probability as label
Tensor<xpu, 2, DType> label = in_data[softmaxout_enum::kLabel].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = out_data[softmaxout_enum::kOut].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = in_grad[softmaxout_enum::kData].FlatTo2D<xpu, DType>(s);
grad = (out - label) * scalar<DType>(param_.grad_scale);
} else if (param_.multi_output) {
int n = out_data[softmaxout_enum::kOut].size(0);
int k = out_data[softmaxout_enum::kOut].size(1);
Shape<3> s3 = Shape3(n, k, static_cast<int>(out_data[softmaxout_enum::kOut].Size()/n/k));
@@ -204,14 +211,18 @@ class SoftmaxOutputProp : public OperatorProperty {
CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]";
const TShape &dshape = in_shape->at(0);
if (dshape.ndim() == 0) return false;
if (param_.multi_output) {
SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel,
Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1]));
} else {
TShape label_shape(dshape.ndim() - 1);
for (index_t i = 0; i + 1 < dshape.ndim(); ++i)
label_shape[i] = dshape[i];
SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape);

// label.shape == data.shape: use probability as label
if (dshape != (*in_shape)[softmaxout_enum::kLabel]) {
if (param_.multi_output) {
SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel,
Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1]));
} else {
TShape label_shape(dshape.ndim() - 1);
for (index_t i = 0; i + 1 < dshape.ndim(); ++i)
label_shape[i] = dshape[i];
SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape);
}
}
out_shape->clear();
out_shape->push_back(dshape);
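With this change, SoftmaxOutputProp::InferShape only forces the label into the class-index layout when its shape differs from the data shape, so a full-shape probability label passes through untouched. A hypothetical Python restatement of which label shapes the operator now expects (the helper name is illustrative, not MXNet API):

def expected_label_shapes(data_shape, multi_output=False):
    # Label shapes SoftmaxOutput accepts for a given data shape after this change.
    n, k = data_shape[0], data_shape[1]
    size = 1
    for d in data_shape:
        size *= d
    shapes = [tuple(data_shape)]               # probability label: identical to the data shape
    if multi_output:
        shapes.append((n, size // (n * k)))    # Shape2(n, Size()/n/k), as in the hunk above
    else:
        shapes.append(tuple(data_shape[:-1]))  # class-index label: data shape without the class axis
    return shapes

print(expected_label_shapes((3, 4)))                        # [(3, 4), (3,)]
print(expected_label_shapes((3, 4, 5), multi_output=True))  # [(3, 4, 5), (3, 5)]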
4 changes: 2 additions & 2 deletions src/operator/softmax_output.cc
@@ -32,7 +32,8 @@ DMLC_REGISTER_PARAMETER(SoftmaxOutputParam);
MXNET_REGISTER_OP_PROPERTY(SoftmaxOutput, SoftmaxOutputProp)
.describe("Perform a softmax transformation on input, backprop with logloss.")
.add_argument("data", "Symbol", "Input data to softmax.")
.add_argument("label", "Symbol", "Label data.")
.add_argument("label", "Symbol", "Label data, can also be "\
"probability value with same shape as data")
.add_arguments(SoftmaxOutputParam::__FIELDS__());

MXNET_REGISTER_OP_PROPERTY(Softmax, DeprecatedSoftmaxProp)
@@ -42,4 +43,3 @@ MXNET_REGISTER_OP_PROPERTY(Softmax, DeprecatedSoftmaxProp)

} // namespace op
} // namespace mxnet
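For completeness, a usage sketch of the new soft-label path from Python, adapted from the updated unit test at the bottom of this diff (MXNet's 2016-era symbolic API; variable names here are illustrative):

import mxnet as mx
import numpy as np

X = mx.symbol.Variable('X')
L = mx.symbol.Variable('L')
Y = mx.symbol.SoftmaxOutput(data=X, label=L)

shape = (3, 4)
x = mx.random.uniform(-1, 1, shape, ctx=mx.cpu())
l = mx.random.uniform(-1, 1, shape, ctx=mx.cpu())   # label has the same shape as the data
p = np.exp(l.asnumpy())
l[:] = p / p.sum(axis=1, keepdims=True)             # turn it into row-wise probabilities
grad = mx.nd.empty(shape, ctx=mx.cpu())

exe = Y.bind(mx.cpu(), args=[x, l], args_grad={'X': grad})
exe.forward()
exe.backward()    # grad now holds softmax(x) - l, as the test asserts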

25 changes: 17 additions & 8 deletions tests/python/unittest/test_operator.py
@@ -10,6 +10,13 @@
def same(a, b):
return np.sum(a != b) == 0

def np_softmax(x):
x = x - np.max(x, axis=1).reshape(x.shape[0], 1)
x = np.exp(x)
x /= np.sum(x, axis=1).reshape(x.shape[0], 1)
return x


def check_elementwise_sum_with_shape(shape, n):
# forward
inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
@@ -235,20 +242,23 @@ def check_softmax_with_ignore_label(xpu):
assert(reldiff(grad0[int(shape[0]/2):], grad1[int(shape[0]/2):]) < 1e-5)

def check_softmax_with_shape(shape, xpu):
# bind with label
X = mx.symbol.Variable('X')
L = mx.symbol.Variable('L')
Y = mx.symbol.SoftmaxOutput(data=X, label=L)
x = mx.random.uniform(-1, 1, shape, ctx = xpu)
l = mx.nd.empty((shape[0],), ctx = xpu)
l[:] = np.random.randint(0, shape[1]-1, (shape[0],))
l = mx.random.uniform(-1, 1, shape, ctx = xpu)
l[:] = np_softmax(l.asnumpy())
grad = mx.nd.empty(shape, ctx = xpu)

exec1 = Y.bind(xpu, args = [x, l], args_grad = {'X': grad})
print('foward')
exec1.forward()
print(exec1.outputs[0].asnumpy())
out = exec1.outputs[0].asnumpy()
assert_allclose(out, np_softmax(x.asnumpy()))
exec1.backward()
print(grad.asnumpy())
assert_allclose(grad.asnumpy(), np_softmax(x.asnumpy()) - l.asnumpy())

def test_softmax():
check_softmax_with_shape((3, 4), mx.cpu())

def check_multi_softmax_with_shape(shape, xpu):
X = mx.symbol.Variable('X')
@@ -1047,6 +1057,7 @@ def test_flip():
assert_allclose(x.asnumpy()[idx], y.asnumpy())

if __name__ == '__main__':
test_softmax()
test_broadcast_binary_op()
test_flip()
test_crop()
@@ -1077,5 +1088,3 @@ def test_flip():
test_reshape()
test_reduce()
test_broadcast()
#check_softmax_with_shape((3,4), mx.cpu())
#check_multi_softmax_with_shape((3,4,5), mx.cpu())
