Fix onnx.GatherND and onnx.ScatterND issues with dynamic indices (onnx#2550)

* Support onnx.GatherND, onnx.ScatterND and onnx.Gather ops with dynamic indices.

Signed-off-by: Yasushi Negishi <[email protected]>
negiyas authored Oct 26, 2023
1 parent 4c07886 commit ee36a16
Showing 5 changed files with 133 additions and 50 deletions.
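
For context, onnx.GatherND gathers slices of data at coordinate tuples taken from the last axis of indices, with the first batch_dims axes acting as shared batch axes; what this commit adds is support for those shapes being unknown until runtime. A minimal NumPy sketch of the operator's semantics (a reference for reading the diff, not the onnx-mlir code path):

import numpy as np

def gather_nd(data, indices, batch_dims=0):
    # Reference semantics for onnx.GatherND (sketch only).
    b = batch_dims
    ild = indices.shape[-1]  # must satisfy 1 <= ild <= data.ndim - b
    out = []
    for batch in np.ndindex(*indices.shape[:b]):   # iterate the batch axes
        rows = indices[batch].reshape(-1, ild)     # flatten the index tuples
        for row in rows:
            out.append(data[batch + tuple(row)])   # gather one slice/element
    out_shape = indices.shape[:-1] + data.shape[b + ild:]
    return np.asarray(out).reshape(out_shape)

# e.g. gather_nd(np.array([[0, 1], [2, 3]]), np.array([[0, 0], [1, 1]]))
# yields [0, 3], matching the backend test test_gathernd_example_int32.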
80 changes: 33 additions & 47 deletions src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp
@@ -64,68 +64,60 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern<ONNXGatherNDOp> {
Value data = adaptor.getData();
Value indices = adaptor.getIndices();
int64_t b = adaptor.getBatchDims();
auto indicesType = indices.getType().cast<ShapedType>();
DimsExpr dataDims, indicesDims;
create.krnlIE.getShapeAsDims(data, dataDims);
create.krnlIE.getShapeAsDims(indices, indicesDims);
auto dataType = data.getType().cast<ShapedType>();
ArrayRef<int64_t> indicesShape = indicesType.getShape();
ArrayRef<int64_t> dataShape = dataType.getShape();
int64_t dataRank = dataShape.size();
int64_t indicesRank = indicesShape.size();
int64_t dataRank = dataDims.size();
int64_t indicesRank = indicesDims.size();
auto indicesType = indices.getType().cast<ShapedType>();
ArrayRef<int64_t> indicesShape = indicesType.getShape();
int64_t indicesLastDim = indicesShape[indicesRank - 1];
assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
"indices.shape[-1] must be in the range [1, dataRank - b]");

// Convert the output type to MemRefType.
Type convertedType = typeConverter->convertType(*op->result_type_begin());
assert(convertedType && convertedType.isa<MemRefType>() &&
"Failed to convert type to MemRefType");
MemRefType outputMemRefType = convertedType.cast<MemRefType>();
ArrayRef<int64_t> outputShape = outputMemRefType.getShape();
int64_t outputRank = outputShape.size();

// Ensure the operation constraints are satisfied.
assert(dataRank >= 1 && "The rank of 'data' must be >= 1");
assert(indicesRank >= 1 && "The rank of 'indices' must be >= 1");
assert((outputRank == dataRank + indicesRank - indicesLastDim - 1 - b) &&
"Incorrect outut rank");
assert(b >= 0 && "batch_dim should not be negative");
assert(b < std::min(dataRank, indicesRank) &&
"batch_dims must be smaller than the min(dataRank, indicesRank)");
assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
"indices.shape[-1] must be in the range [1, dataRank - b]");
DimsExpr outputDims = shapeHelper.getOutputDims();

// Reshape 'indices' to the 3D shape:
// [batchDimsSize, indicesDimsSize, indices.shape[-1]].
const int64_t batchDimsSize = std::accumulate(indicesShape.begin(),
indicesShape.begin() + b, 1, std::multiplies<int64_t>());
const int64_t indicesDimsSize = std::accumulate(indicesShape.begin(),
indicesShape.end(), 1, std::multiplies<int64_t>());
assert(batchDimsSize >= 0 && "batchDimsSize must be non-negative");
assert(indicesDimsSize >= 0 && "indicesDimsSize must be non-negative");

LiteralIndexExpr BDS(batchDimsSize),
IDS(indicesDimsSize / (batchDimsSize * indicesLastDim)),
ILD(indicesLastDim);
LiteralIndexExpr oneIE(1);
IndexExpr batchDimsSize = oneIE;
for (int64_t i = 0; i < b; i++)
batchDimsSize = batchDimsSize * indicesDims[i];
IndexExpr indicesDimsSize = oneIE;
for (int64_t i = b; i < indicesRank - 1; i++)
indicesDimsSize = indicesDimsSize * indicesDims[i];
IndexExpr BDS(batchDimsSize), IDS(indicesDimsSize);
LiteralIndexExpr ILD(indicesLastDim);
DimsExpr newIndicesShape = {BDS, IDS, ILD};
Value reshapedIndices =
create.mem.reinterpretCast(indices, newIndicesShape);
LLVM_DEBUG(llvm::dbgs() << "reshapedIndices: " << reshapedIndices << "\n");

// Reshape 'data' to shape [batchDimsSize, data.shape[b:]]
DimsExpr newDataShape = {BDS};
DimsExpr newDataDims = {BDS};
for (int64_t i = b; i < dataRank; ++i) {
assert(dataShape[i] != ShapedType::kDynamic &&
"Cannot support data with dynamic dimensions");
LiteralIndexExpr dataDim(dataShape[i]);
newDataShape.emplace_back(dataDim);
newDataDims.emplace_back(dataDims[i]);
}
int64_t reshapedDataRank = newDataShape.size();
Value reshapedData = create.mem.reinterpretCast(data, newDataShape);
int64_t reshapedDataRank = newDataDims.size();
Value reshapedData = create.mem.reinterpretCast(data, newDataDims);
LLVM_DEBUG(llvm::dbgs() << "reshapedData: " << reshapedData << "\n");

// Allocate a 1D output buffer.
const int64_t outputDimsSize = std::accumulate(
outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
Value outputDataBuffer = create.mem.alloc(
MemRefType::get({outputDimsSize}, outputMemRefType.getElementType()));

IndexExpr outputDimsSize = oneIE;
for (uint64_t i = 0; i < outputDims.size(); i++)
outputDimsSize = outputDimsSize * outputDims[i];
SmallVector<IndexExpr> outputIndexExpr = {outputDimsSize};
int64_t dim = outputDimsSize.isLiteral() ? outputDimsSize.getLiteral()
: ShapedType::kDynamic;
Type outputType = dataType.getElementType();
Value outputDataBuffer =
create.mem.alloc(MemRefType::get({dim}, outputType), outputIndexExpr);
// Initialize the index used to store the result values.
Value iZero = create.math.constantIndex(0);
Value iOne = create.math.constantIndex(1);
@@ -247,14 +239,8 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern<ONNXGatherNDOp> {
});

// Finally reshape 'outputDataBuffer' to the shape of the output.
DimsExpr newOutputShape;
for (int64_t dim : outputShape) {
LiteralIndexExpr outputDim(dim);
newOutputShape.emplace_back(outputDim);
}

Value reshapedOutput =
create.mem.reinterpretCast(outputDataBuffer, newOutputShape);
create.mem.reinterpretCast(outputDataBuffer, outputDims);
LLVM_DEBUG(llvm::dbgs() << "reshapedOutput: " << reshapedOutput << "\n");

rewriter.replaceOp(op, reshapedOutput);
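The substance of the GatherND change: batchDimsSize, indicesDimsSize, and the output size are now accumulated as IndexExpr products over runtime dimensions (dataDims, indicesDims, outputDims) instead of std::accumulate over static shapes, so the flatten-gather-reshape strategy survives dynamic dimensions in data, indices, or the result. A rough NumPy analogue of that flattened strategy, assuming row-major layout (a sketch, not the generated Krnl IR):

import numpy as np
from functools import reduce
from operator import mul

def gather_nd_flattened(data, indices, b=0):
    ild = indices.shape[-1]
    bds = reduce(mul, indices.shape[:b], 1)         # batchDimsSize, a runtime product
    ids = reduce(mul, indices.shape[b:-1], 1)       # indicesDimsSize
    idx3 = indices.reshape(bds, ids, ild)           # indices as [BDS, IDS, ILD]
    data2 = data.reshape((bds,) + data.shape[b:])   # data as [BDS, data.shape[b:]]
    slice_shape = data.shape[b + ild:]
    step = reduce(mul, slice_shape, 1)
    out_shape = indices.shape[:-1] + slice_shape
    flat = np.empty(reduce(mul, out_shape, 1), dtype=data.dtype)  # 1D output buffer
    pos = 0
    for bi in range(bds):                           # the two-deep gather loop
        for ii in range(ids):
            elem = data2[(bi,) + tuple(idx3[bi, ii])]
            flat[pos:pos + step] = np.ravel(elem)
            pos += step
    return flat.reshape(out_shape)                  # analogue of the final reinterpret_cast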
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp
@@ -97,7 +97,8 @@ struct ONNXScatterNDOpLowering : public OpConversionPattern<ONNXScatterNDOp> {
IndexExpr index = NonAffineIndexExpr(indexVal);
outputAccessFct.emplace_back(index);
} else {
IndexExpr index = SymbolIndexExpr(loopInd[i]);
IndexExpr index = SymbolIndexExpr(
loopInd[std::min<unsigned>(i, loopInd.size() - 1)]);
outputAccessFct.emplace_back(index);
}
}
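The ScatterND change is a one-line guard: when the lowering builds the output access function, the induction-variable lookup loopInd[i] could evidently run past the end of loopInd for the rank combinations that arise with dynamic indices, so it is now clamped with std::min to the last valid position. For reference, the operator itself with reduction "none" is just the update loop below (a semantics sketch, not the lowering):

import numpy as np

def scatter_nd(data, indices, updates):
    # Reference semantics for onnx.ScatterND, reduction="none" (sketch only).
    out = np.copy(data)                    # the input itself is not mutated
    ild = indices.shape[-1]
    rows = indices.reshape(-1, ild)        # one coordinate tuple per row
    upd = updates.reshape((len(rows),) + data.shape[ild:])
    for row, u in zip(rows, upd):
        out[tuple(row)] = u                # later rows win on duplicate indices
    return out

With data of shape (2, 1), indices of shape (?, 2), and two scalar updates, as in the new ScatterND.mlir test below, each row of indices is a full coordinate and the loop reduces to two scalar stores.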
13 changes: 11 additions & 2 deletions test/backend/inference_backend.py
@@ -1023,10 +1023,19 @@ def get_test_models():
},
# ==OP== GatherND
# ==MIN== 11
"test_gathernd_example_int32_cpu": {STATIC_SHAPE: {}, CONSTANT_INPUT: {-1}},
"test_gathernd_example_float32_cpu": {STATIC_SHAPE: {}, CONSTANT_INPUT: {-1}},
"test_gathernd_example_int32_cpu": {
STATIC_SHAPE: {},
DYNAMIC_SHAPE: {-1: {-1}},
CONSTANT_INPUT: {-1},
},
"test_gathernd_example_float32_cpu": {
STATIC_SHAPE: {},
DYNAMIC_SHAPE: {-1: {-1}},
CONSTANT_INPUT: {-1},
},
"test_gathernd_example_int32_batch_dim1_cpu": {
STATIC_SHAPE: {},
DYNAMIC_SHAPE: {-1: {-1}},
CONSTANT_INPUT: {-1},
},
# ==OP== Gemm
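The three gathernd backend tests above now also opt into the dynamic-shape run. As a reading aid, and as an assumption about the registry convention rather than something stated in this diff, -1 appears to act as a wildcard, so DYNAMIC_SHAPE: {-1: {-1}} requests a rerun with every dimension of every input made dynamic. A hedged sketch of that interpretation:

# Hypothetical expansion of an entry like {-1: {-1}}; the real logic lives in
# inference_backend.py, and the wildcard convention here is an assumption.
def dynamic_dims_for_input(dyn_spec, input_idx, rank):
    dims = dyn_spec.get(input_idx, dyn_spec.get(-1))  # key -1: applies to all inputs
    if dims is None:
        return []
    if -1 in dims:                                    # value -1: all dimensions
        return list(range(rank))
    return sorted(dims)

# For test_gathernd_example_int32_cpu, a rank-2 input would then have both of
# its dimensions rewritten as dynamic ('?') before compilation.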
58 changes: 58 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/Tensor/GatherND.mlir
@@ -73,3 +73,61 @@ func.func @test_gather_nd_2(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x1x2xi64>
// CHECK: [[RES:%.+]] = memref.reinterpret_cast [[RES_BUFFER]] to offset: [0], sizes: [2, 1, 2], strides: [2, 2, 1] : memref<4xf32> to memref<2x1x2xf32>
// CHECK: return [[RES]] : memref<2x1x2xf32>
}

// -----

// COM: Test GatherND with dynamic shape
func.func @test_gather_nd_with_dynamic_shape_int(%arg0 : tensor<2x2xi32>, %arg1 : tensor<?x2xi64>) -> tensor<?xi32> {
%0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xi32>, tensor<?x2xi64>) -> tensor<?xi32>
"func.return"(%0) : (tensor<?xi32>) -> ()
// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @test_gather_nd_with_dynamic_shape_int
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x2xi32>, [[PARAM_1_:%.+]]: memref<?x2xi64>) -> memref<?xi32> {
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_1_]], [[CST_0_]] : memref<?x2xi64>
// CHECK-DAG: [[CST_2_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_2_3_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_2_4_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_dim_5_:%.+]] = memref.dim [[PARAM_1_]], [[CST_0_1_]] : memref<?x2xi64>
// CHECK-DAG: [[CST_2_5_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_2_6_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index
// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]]([[VAR_dim_5_]])
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1, [[VAR_dim_5_]], 2], strides: {{.}}[[VAR_0_]], 2, 1] : memref<?x2xi64> to memref<1x?x2xi64>
// CHECK-DAG: [[CST_1_2_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[VAR_reinterpret_cast_10_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1, 2, 2], strides: [4, 2, 1] : memref<2x2xi32> to memref<1x2x2xi32>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) : memref<?xi32>
// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_1_3_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() : memref<index>
// CHECK: krnl.store [[CST_0_2_]], [[RES_1_]][] : memref<index>
// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_5_]])){
// CHECK-DAG: [[VAR_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK-DAG: [[CST_0_4_:%.+]] = arith.constant 0 : index
// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[CST_0_4_]]{{.}} : memref<1x?x2xi64>
// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[LOAD_VAR_reinterpret_cast_MEM_]] : i64 to index
// CHECK-DAG: [[CST_1_4_:%.+]] = arith.constant 1 : index
// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[CST_1_4_]]{{.}} : memref<1x?x2xi64>
// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[LOAD_VAR_reinterpret_cast_MEM_1_]] : i64 to index
// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_10_MEM_:%.+]] = krnl.load [[VAR_reinterpret_cast_10_]]{{.}}[[VAR_2_]]#0, [[VAR_4_]], [[VAR_6_]]{{.}} : memref<1x2x2xi32>
// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]][] : memref<index>
// CHECK: krnl.store [[LOAD_VAR_reinterpret_cast_10_MEM_]], [[RES_]]{{.}}[[LOAD_RES_1_MEM_]]{{.}} : memref<?xi32>
// CHECK: [[VAR_9_:%.+]] = arith.addi [[LOAD_RES_1_MEM_]], [[CST_1_3_]] : index
// CHECK: krnl.store [[VAR_9_]], [[RES_1_]][] : memref<index>
// CHECK: }
// CHECK-DAG: [[CST_1_5_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_reinterpret_cast_15_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: [0], sizes: {{.}}[[VAR_dim_]]{{.}}, strides: [1] : memref<?xi32> to memref<?xi32>
// CHECK: return [[VAR_reinterpret_cast_15_]] : memref<?xi32>
}

29 changes: 29 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/Tensor/ScatterND.mlir
@@ -24,3 +24,32 @@ func.func @test_scatter_nd1(%arg0: tensor<4x4x4xf32>, %arg1: tensor<2x1xi64>, %a
// CHECK: return [[RES]] : memref<4x4x4xf32>
}

// -----

// COM: Test ScatterND with dynamic indices
func.func @test_scatter_nd_with_dynamic_indices(%arg0: tensor<2x1xi64>, %arg1: tensor<?x2xi64>, %arg2: tensor<2xi64>) -> tensor<2x1xi64> {
%0 = "onnx.ScatterND"(%arg0, %arg1, %arg2) {reduction = "none"} : (tensor<2x1xi64>, tensor<?x2xi64>, tensor<2xi64>) -> tensor<2x1xi64>
return %0 : tensor<2x1xi64>
// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_scatter_nd_with_dynamic_indices
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x1xi64>, [[PARAM_1_:%.+]]: memref<?x2xi64>, [[PARAM_2_:%.+]]: memref<2xi64>) -> memref<2x1xi64> {
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x1xi64>
// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : i64
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK: "krnl.memcpy"([[RES_]], [[PARAM_0_]], [[CST_2_1_]], [[CST_0_]], [[CST_0_]]) : (memref<2x1xi64>, memref<2x1xi64>, i64, index, index) -> ()
// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_2_2_:%.+]] = arith.constant 2 : index
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 2){
// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index
// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_1_]], [[CST_0_2_]]{{.}} : memref<?x2xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_]] : i64 to index
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_1_]]{{.}} : memref<2xi64>
// CHECK: krnl.store [[LOAD_PARAM_2_MEM_]], [[RES_]]{{.}}[[VAR_3_]], [[VAR_1_]]{{.}} : memref<2x1xi64>
// CHECK: }
// CHECK: return [[RES_]] : memref<2x1xi64>
// CHECK: }
}
