Skip to content

Commit

Permalink
!45221 [MS][OP]sparse segment sum /grad cpu support dynamic shape
Browse files Browse the repository at this point in the history
Merge pull request !45221 from mengyuanli/ds_sparse_segment_sum
  • Loading branch information
it-is-a-robot authored and gitee-org committed Nov 8, 2022
2 parents 477bb05 + 28ee026 commit 1088b85
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,113 +19,32 @@

namespace mindspore {
namespace kernel {
namespace {
// Tensor counts that CheckParam compares the node's actual input/output counts against.
constexpr size_t kSparseSegmentSumInputsNum{3};
constexpr size_t kSparseSegmentSumOutputsNum{1};

// Expands to a KernelAttr carrying three input attrs followed by one output attr.
// Every argument is a type-name suffix that is token-pasted onto `kNumberType`
// (for example `Int32` becomes `kNumberTypeInt32`).
#define ADD_KERNEL(in0, in1, in2, out0)                                              \
  KernelAttr().AddInputAttr(kNumberType##in0).AddInputAttr(kNumberType##in1).AddInputAttr( \
    kNumberType##in2).AddOutputAttr(kNumberType##out0)
}  // namespace

void SparseSegmentSumCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
segment_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
// One-time setup for the kernel.
// Rejects a null operator handle, records the operator name, checks that exactly three
// inputs and one output were supplied, and finally returns whatever MatchKernelFunc reports.
bool SparseSegmentSumCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                        const std::vector<KernelTensorPtr> &inputs,
                                        const std::vector<KernelTensorPtr> &outputs) {
  constexpr size_t kExpectedInputs = 3;
  constexpr size_t kExpectedOutputs = 1;

  MS_EXCEPTION_IF_NULL(base_operator);
  // The name is stored before the count checks because they take kernel_name_ as an argument.
  kernel_name_ = base_operator->name();

  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kExpectedInputs, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kExpectedOutputs, kernel_name_);

  const bool matched = MatchKernelFunc(base_operator, inputs, outputs);
  return matched;
}

bool SparseSegmentSumCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
switch (x_dtype_) {
case (kNumberTypeInt8):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<int8_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<int8_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeInt16):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<int16_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<int16_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeInt32):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<int32_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<int32_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeInt64):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<int64_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<int64_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeUInt8):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<uint8_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<uint8_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeUInt16):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<uint16_t, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<uint16_t, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeFloat16):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<float16, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<float16, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeFloat32):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<float, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<float, int64_t>(inputs, outputs);
break;
}
case (kNumberTypeFloat64):
if (indices_dtype_ == kNumberTypeInt32) {
LaunchKernel<double, int32_t>(inputs, outputs);
break;
} else {
LaunchKernel<double, int64_t>(inputs, outputs);
break;
}
default:
MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(x_dtype_)
<< " which is not supported.";
// Refreshes the shape-dependent state of the kernel.
// First delegates to KernelMod::Resize; when that call does not yield KRET_OK its result is
// returned unchanged. The remaining statements store the device shapes of input 0 and input 2
// in x_shape_ and segment_shape_ and return KRET_OK.
// NOTE(review): this span is taken from a diff view. The bare `return true;` below looks like a
// line of the removed Launch body rather than part of Resize - TODO confirm against the real file.
// Read literally, it would return before the shapes are cached.
int SparseSegmentSumCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
// Propagate any non-OK status from the base implementation.
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
return true;
// Input 0 is read into x_shape_, input 2 into segment_shape_.
x_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
segment_shape_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
return KRET_OK;
}

template <typename T1, typename T2>
void SparseSegmentSumCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
bool SparseSegmentSumCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
constexpr size_t kMultiply = 1;
size_t n = std::accumulate(x_shape_.begin(), x_shape_.end(), kMultiply, std::multiplies<int>()) / x_shape_[kIndex0];
Expand Down Expand Up @@ -163,27 +82,32 @@ void SparseSegmentSumCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &i
yptr[j + oldindex * n] += dataptr[j + indicesptr[i] * n];
}
}
return true;
}

// Compares the node's input and output tensor counts with the expected
// kSparseSegmentSumInputsNum / kSparseSegmentSumOutputsNum values.
void SparseSegmentSumCpuKernelMod::CheckParam(const CNodePtr &kernel_node) {
  // Inputs are queried and checked before outputs are queried at all.
  const size_t actual_input_count = common::AnfAlgo::GetInputTensorNum(kernel_node);
  CHECK_KERNEL_INPUTS_NUM(actual_input_count, kSparseSegmentSumInputsNum, kernel_name_);

  const size_t actual_output_count = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(actual_output_count, kSparseSegmentSumOutputsNum, kernel_name_);
}
// Expands to the two comma-separated members of one table entry: a KernelAttr whose first input
// and output use DATA_ID and whose second and third inputs use INDEX_ID, followed by the address
// of the LaunchKernel<DATA_T, INDEX_T> instantiation.
#define SPARSE_SEGMENT_SUM_CPU_REG(DATA_ID, INDEX_ID, DATA_T, INDEX_T)                                        \
  KernelAttr().AddInputAttr(DATA_ID).AddInputAttr(INDEX_ID).AddInputAttr(INDEX_ID).AddOutputAttr(DATA_ID), \
    &SparseSegmentSumCpuKernelMod::LaunchKernel<DATA_T, INDEX_T>

// Expands to two brace-enclosed entries for one data type: the first with kNumberTypeInt32 /
// int32_t indices, the second with kNumberTypeInt64 / int64_t indices.
#define SPARSE_SEGMENT_SUM_CPU_INDEX_REG(DATA_ID, DATA_T)                   \
  {SPARSE_SEGMENT_SUM_CPU_REG(DATA_ID, kNumberTypeInt32, DATA_T, int32_t)}, \
  {                                                                         \
    SPARSE_SEGMENT_SUM_CPU_REG(DATA_ID, kNumberTypeInt64, DATA_T, int64_t)  \
  }

std::vector<KernelAttr> SparseSegmentSumCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
ADD_KERNEL(Int8, Int32, Int32, Int8), ADD_KERNEL(Float16, Int32, Int32, Float16),
ADD_KERNEL(Int16, Int32, Int32, Int16), ADD_KERNEL(Float32, Int32, Int32, Float32),
ADD_KERNEL(Int32, Int32, Int32, Int32), ADD_KERNEL(Float64, Int32, Int32, Float64),
ADD_KERNEL(Int64, Int32, Int32, Int64), ADD_KERNEL(UInt8, Int32, Int32, UInt8),
ADD_KERNEL(UInt16, Int32, Int32, UInt16), ADD_KERNEL(Int8, Int64, Int64, Int8),
ADD_KERNEL(Float16, Int64, Int64, Float16), ADD_KERNEL(Int16, Int64, Int64, Int16),
ADD_KERNEL(Float32, Int64, Int64, Float32), ADD_KERNEL(Int32, Int64, Int64, Int32),
ADD_KERNEL(Float64, Int64, Int64, Float64), ADD_KERNEL(Int64, Int64, Int64, Int64),
ADD_KERNEL(UInt8, Int64, Int64, UInt8), ADD_KERNEL(UInt16, Int64, Int64, UInt16)};
return kernel_attr_list;
// Returns a reference to the function-local static table of (KernelAttr, launch function) pairs.
// Each SPARSE_SEGMENT_SUM_CPU_INDEX_REG line contributes the entries for one data type.
const std::vector<std::pair<KernelAttr, SparseSegmentSumCpuKernelMod::KernelRunFunc>>
  &SparseSegmentSumCpuKernelMod::GetFuncList() const {
  using FuncEntry = std::pair<KernelAttr, SparseSegmentSumCpuKernelMod::KernelRunFunc>;
  static const std::vector<FuncEntry> kFuncTable = {
    // Signed integer data types.
    SPARSE_SEGMENT_SUM_CPU_INDEX_REG(kNumberTypeInt8, int8_t),
    SPARSE_SEGMENT_SUM_CPU_INDEX_REG(kNumberTypeInt16, int16_t),
    SPARSE_SEGMENT_SUM_CPU_INDEX_REG(kNumberTypeInt32, int32_t),
    SPARSE_SEGMENT_SUM_CPU_INDEX_REG(kNumberTypeInt64, int64_t),
    // Unsigned integer data types.
    SPARSE_SEGMENT_SUM_CPU_INDEX_REG(kNumberTypeUInt8, uint8_t),
    SPARSE_SEGMENT_SUM_CPU_INDEX_REG(kNumberTypeUInt16, uint16_t),
    // Floating-point data types.
    SPARSE_SEGMENT_SUM_CPU_INDEX_REG(kNumberTypeFloat16, float16),
    SPARSE_SEGMENT_SUM_CPU_INDEX_REG(kNumberTypeFloat32, float),
    SPARSE_SEGMENT_SUM_CPU_INDEX_REG(kNumberTypeFloat64, double),
  };
  return kFuncTable;
}

// Factory registration for this kernel class.
// NOTE(review): presumably makes SparseSegmentSumCpuKernelMod creatable under the name
// "SparseSegmentSum" through the NativeCpuKernelMod factory - the macro is defined elsewhere; confirm.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentSum, SparseSegmentSumCpuKernelMod);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,52 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SUM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SUM_CPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SPARSE_SEGMENT_SUM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SPARSE_SEGMENT_SUM_CPU_KERNEL_H_
#include <functional>
#include <numeric>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
// CPU kernel class for the SparseSegmentSum operator.
// NOTE(review): this span comes from a diff view and shows two class heads plus paired
// declarations (e.g. two Launch, GetOpSupport and LaunchKernel signatures). They appear to be the
// old and new revisions interleaved - TODO confirm which lines belong to the current header.
class SparseSegmentSumCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class SparseSegmentSumCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<SparseSegmentSumCpuKernelMod> {
public:
SparseSegmentSumCpuKernelMod() = default;

~SparseSegmentSumCpuKernelMod() override = default;

// CNode-based initialization entry point.
void InitKernel(const CNodePtr &kernel_node) override;

// Launch entry point. The inline variant raises if kernel_func_ is null and otherwise forwards
// this object together with inputs, workspace and outputs to kernel_func_, returning its result.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
MS_EXCEPTION_IF_NULL(kernel_func_);
return kernel_func_(this, inputs, workspace, outputs);
}

// Typed worker (T1, T2 template parameters) taking inputs and outputs only.
template <typename T1, typename T2>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
// Operator-based initialization entry point; returns a bool status.
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;

// Shape refresh entry point; returns an int status code. The last parameter is unnamed here.
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

// Accessor for the list of (KernelAttr, KernelRunFunc) pairs.
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

protected:
// Supported-attribute query; the inline variant returns OpSupport().
std::vector<KernelAttr> GetOpSupport() override;
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }

// Typed worker (T1, T2 template parameters) taking inputs, workspace and outputs; returns bool.
template <typename T1, typename T2>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);

private:
// Helper that validates a CNode's parameters.
void CheckParam(const CNodePtr &kernel_node);
// Stored shape values for the x input and the segment input.
ShapeVector x_shape_;
ShapeVector segment_shape_;
// Stored type ids for x and indices; both start as kTypeUnknown.
TypeId x_dtype_{kTypeUnknown};
TypeId indices_dtype_{kTypeUnknown};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SUM_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SPARSE_SEGMENT_SUM_CPU_KERNEL_H_
8 changes: 5 additions & 3 deletions mindspore/core/ops/grad/sparse_segment_sum_grad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ abstract::ShapePtr SparseSegmentSumGradInferShape(const PrimitivePtr &prim,
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto output_dim0_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
// support dynamic rank
if (IsDynamicRank(grad_shape) || IsDynamicRank(indices_shape) || IsDynamicRank(segment_ids_shape) ||
IsDynamicRank(output_dim0_shape)) {
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
}
(void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
prim->name());
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
Expand All @@ -53,9 +58,6 @@ abstract::ShapePtr SparseSegmentSumGradInferShape(const PrimitivePtr &prim,
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
}
if (IsDynamicRank(grad_shape)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
}
if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex3]->BuildValue()->isa<None>()) {
auto output_dim0_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
Expand Down
7 changes: 4 additions & 3 deletions mindspore/core/ops/sparse_segment_sum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ abstract::ShapePtr SparseSegmentSumInferShape(const PrimitivePtr &prim,
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto segment_ids_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
// support dynamic rank
if (IsDynamicRank(x_shape) || IsDynamicRank(indices_shape) || IsDynamicRank(segment_ids_shape)) {
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
}
(void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
prim->name());
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
Expand All @@ -44,9 +48,6 @@ abstract::ShapePtr SparseSegmentSumInferShape(const PrimitivePtr &prim,
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
}
if (IsDynamicRank(x_shape)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
}
if (!input_args[kInputIndex2]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex2]->BuildValue()->isa<None>()) {
auto segment_ids_value_ptr = input_args[kInputIndex2]->BuildValue();
Expand Down

0 comments on commit 1088b85

Please sign in to comment.