From ee36a1649685141ad6d2a285eac135999d44830b Mon Sep 17 00:00:00 2001 From: Yasushi Negishi Date: Thu, 26 Oct 2023 13:00:57 +0900 Subject: [PATCH] Fix onnx.GatherND and onnx.ScatterND issues with dynamic indices (#2550) * Support onnx.GatherND, onnx.ScatterND and onnx.Gather ops with dynamic indices. Signed-off-by: Yasushi Negishi --- src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp | 80 ++++++++----------- .../ONNXToKrnl/Tensor/ScatterND.cpp | 3 +- test/backend/inference_backend.py | 13 ++- .../onnx_to_krnl/Tensor/GatherND.mlir | 58 ++++++++++++++ .../onnx_to_krnl/Tensor/ScatterND.mlir | 29 +++++++ 5 files changed, 133 insertions(+), 50 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp index 2de30bbdc6..8ac952067b 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp @@ -64,68 +64,60 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { Value data = adaptor.getData(); Value indices = adaptor.getIndices(); int64_t b = adaptor.getBatchDims(); - auto indicesType = indices.getType().cast(); + DimsExpr dataDims, indicesDims; + create.krnlIE.getShapeAsDims(data, dataDims); + create.krnlIE.getShapeAsDims(indices, indicesDims); auto dataType = data.getType().cast(); - ArrayRef indicesShape = indicesType.getShape(); ArrayRef dataShape = dataType.getShape(); - int64_t dataRank = dataShape.size(); - int64_t indicesRank = indicesShape.size(); + int64_t dataRank = dataDims.size(); + int64_t indicesRank = indicesDims.size(); + auto indicesType = indices.getType().cast(); + ArrayRef indicesShape = indicesType.getShape(); int64_t indicesLastDim = indicesShape[indicesRank - 1]; + assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) && + "indices.shape[-1] must be in the range [1, dataRank - b]"); // Convert the output type to MemRefType. 
Type convertedType = typeConverter->convertType(*op->result_type_begin()); assert(convertedType && convertedType.isa() && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); - ArrayRef outputShape = outputMemRefType.getShape(); - int64_t outputRank = outputShape.size(); - - // Ensure the operation constains are satisfied. - assert(dataRank >= 1 && "The rank of 'data' must be >= 1"); - assert(indicesRank >= 1 && "The rank of 'indices' must be >= 1"); - assert((outputRank == dataRank + indicesRank - indicesLastDim - 1 - b) && - "Incorrect outut rank"); - assert(b >= 0 && "batch_dim should not be negative"); - assert(b < std::min(dataRank, indicesRank) && - "batch_dims must be smaller than the min(dataRank, indicesRank)"); - assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) && - "indices.shape[-1] must be in the range [1, dataRank - b]"); + DimsExpr outputDims = shapeHelper.getOutputDims(); // Reshape 'indices' to the 3D shape: // [batchDimSize, indicesDimsSize, indices.shape[-1]]. 
- const int64_t batchDimsSize = std::accumulate(indicesShape.begin(), - indicesShape.begin() + b, 1, std::multiplies()); - const int64_t indicesDimsSize = std::accumulate(indicesShape.begin(), - indicesShape.end(), 1, std::multiplies()); - assert(batchDimsSize >= 0 && "batchDimsSize must be non-negative"); - assert(indicesDimsSize >= 0 && "indicesDimsSize must be non-negative"); - - LiteralIndexExpr BDS(batchDimsSize), - IDS(indicesDimsSize / (batchDimsSize * indicesLastDim)), - ILD(indicesLastDim); + LiteralIndexExpr oneIE(1); + IndexExpr batchDimsSize = oneIE; + for (int64_t i = 0; i < b; i++) + batchDimsSize = batchDimsSize * indicesDims[i]; + IndexExpr indicesDimsSize = oneIE; + for (int64_t i = b; i < indicesRank - 1; i++) + indicesDimsSize = indicesDimsSize * indicesDims[i]; + IndexExpr BDS(batchDimsSize), IDS(indicesDimsSize); + LiteralIndexExpr ILD(indicesLastDim); DimsExpr newIndicesShape = {BDS, IDS, ILD}; Value reshapedIndices = create.mem.reinterpretCast(indices, newIndicesShape); LLVM_DEBUG(llvm::dbgs() << "reshapedIndices: " << reshapedIndices << "\n"); // Reshape 'data' to shape [batchDimSize, data.shape[b:]] - DimsExpr newDataShape = {BDS}; + DimsExpr newDataDims = {BDS}; for (int64_t i = b; i < dataRank; ++i) { - assert(dataShape[i] != ShapedType::kDynamic && - "Cannot support data with dynamic dimensions"); - LiteralIndexExpr dataDim(dataShape[i]); - newDataShape.emplace_back(dataDim); + newDataDims.emplace_back(dataDims[i]); } - int64_t reshapedDataRank = newDataShape.size(); - Value reshapedData = create.mem.reinterpretCast(data, newDataShape); + int64_t reshapedDataRank = newDataDims.size(); + Value reshapedData = create.mem.reinterpretCast(data, newDataDims); LLVM_DEBUG(llvm::dbgs() << "reshapedData: " << reshapedData << "\n"); // Allocate a 1D output buffer. 
- const int64_t outputDimsSize = std::accumulate( - outputShape.begin(), outputShape.end(), 1, std::multiplies()); - Value outputDataBuffer = create.mem.alloc( - MemRefType::get({outputDimsSize}, outputMemRefType.getElementType())); - + IndexExpr outputDimsSize = oneIE; + for (uint64_t i = 0; i < outputDims.size(); i++) + outputDimsSize = outputDimsSize * outputDims[i]; + SmallVector outputIndexExpr = {outputDimsSize}; + int64_t dim = outputDimsSize.isLiteral() ? outputDimsSize.getLiteral() + : ShapedType::kDynamic; + Type outputType = dataType.getElementType(); + Value outputDataBuffer = + create.mem.alloc(MemRefType::get({dim}, outputType), outputIndexExpr); // Initialize the index used to store the result values. Value iZero = create.math.constantIndex(0); Value iOne = create.math.constantIndex(1); @@ -247,14 +239,8 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { }); // Finally reshape 'outputDataBuffer' to the shape of the output. - DimsExpr newOutputShape; - for (int64_t dim : outputShape) { - LiteralIndexExpr outputDim(dim); - newOutputShape.emplace_back(outputDim); - } - Value reshapedOutput = - create.mem.reinterpretCast(outputDataBuffer, newOutputShape); + create.mem.reinterpretCast(outputDataBuffer, outputDims); LLVM_DEBUG(llvm::dbgs() << "reshapedOutput: " << reshapedOutput << "\n"); rewriter.replaceOp(op, reshapedOutput); diff --git a/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp b/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp index dd82884f5c..64974d983e 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp @@ -97,7 +97,8 @@ struct ONNXScatterNDOpLowering : public OpConversionPattern { IndexExpr index = NonAffineIndexExpr(indexVal); outputAccessFct.emplace_back(index); } else { - IndexExpr index = SymbolIndexExpr(loopInd[i]); + IndexExpr index = SymbolIndexExpr( + loopInd[std::min(i, loopInd.size() - 1)]); outputAccessFct.emplace_back(index); } } diff --git 
a/test/backend/inference_backend.py b/test/backend/inference_backend.py index 93518b3796..04e735e38b 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -1023,10 +1023,19 @@ def get_test_models(): }, # ==OP== GatherND # ==MIN== 11 - "test_gathernd_example_int32_cpu": {STATIC_SHAPE: {}, CONSTANT_INPUT: {-1}}, - "test_gathernd_example_float32_cpu": {STATIC_SHAPE: {}, CONSTANT_INPUT: {-1}}, + "test_gathernd_example_int32_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, + "test_gathernd_example_float32_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, "test_gathernd_example_int32_batch_dim1_cpu": { STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, # ==OP== Gemm diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/GatherND.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/GatherND.mlir index 9fc0ed4f58..a170d7a654 100644 --- a/test/mlir/conversion/onnx_to_krnl/Tensor/GatherND.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Tensor/GatherND.mlir @@ -73,3 +73,61 @@ func.func @test_gather_nd_2(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x1x2xi64> // CHECK: [[RES:%.+]] = memref.reinterpret_cast [[RES_BUFFER]] to offset: [0], sizes: [2, 1, 2], strides: [2, 2, 1] : memref<4xf32> to memref<2x1x2xf32> // CHECK: return [[RES]] : memref<2x1x2xf32> } + +// ----- + +// COM: Test GatherND with dynamic shape +func.func @test_gather_nd_with_dynamic_shape_int(%arg0 : tensor<2x2xi32>, %arg1 : tensor) -> tensor { + %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xi32>, tensor) -> tensor + "func.return"(%0) : (tensor) -> () +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @test_gather_nd_with_dynamic_shape_int +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x2xi32>, [[PARAM_1_:%.+]]: memref) -> memref { +// CHECK-DAG: 
[[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_1_]], [[CST_0_]] : memref +// CHECK-DAG: [[CST_2_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_2_3_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_2_4_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_dim_5_:%.+]] = memref.dim [[PARAM_1_]], [[CST_0_1_]] : memref +// CHECK-DAG: [[CST_2_5_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_2_6_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index +// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]]([[VAR_dim_5_]]) +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1, [[VAR_dim_5_]], 2], strides: {{.}}[[VAR_0_]], 2, 1] : memref to memref<1x?x2xi64> +// CHECK-DAG: [[CST_1_2_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[VAR_reinterpret_cast_10_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1, 2, 2], strides: [4, 2, 1] : memref<2x2xi32> to memref<1x2x2xi32> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) : memref +// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_3_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() : memref +// CHECK: krnl.store [[CST_0_2_]], [[RES_1_]][] : memref +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 
[[MAP_1_]]([[VAR_dim_5_]])){ +// CHECK-DAG: [[VAR_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[CST_0_4_:%.+]] = arith.constant 0 : index +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[CST_0_4_]]{{.}} : memref<1x?x2xi64> +// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[LOAD_VAR_reinterpret_cast_MEM_]] : i64 to index +// CHECK-DAG: [[CST_1_4_:%.+]] = arith.constant 1 : index +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[CST_1_4_]]{{.}} : memref<1x?x2xi64> +// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[LOAD_VAR_reinterpret_cast_MEM_1_]] : i64 to index +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_10_MEM_:%.+]] = krnl.load [[VAR_reinterpret_cast_10_]]{{.}}[[VAR_2_]]#0, [[VAR_4_]], [[VAR_6_]]{{.}} : memref<1x2x2xi32> +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK: krnl.store [[LOAD_VAR_reinterpret_cast_10_MEM_]], [[RES_]]{{.}}[[LOAD_RES_1_MEM_]]{{.}} : memref +// CHECK: [[VAR_9_:%.+]] = arith.addi [[LOAD_RES_1_MEM_]], [[CST_1_3_]] : index +// CHECK: krnl.store [[VAR_9_]], [[RES_1_]][] : memref +// CHECK: } +// CHECK-DAG: [[CST_1_5_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_reinterpret_cast_15_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: [0], sizes: {{.}}[[VAR_dim_]]{{.}}, strides: [1] : memref to memref +// CHECK: return [[VAR_reinterpret_cast_15_]] : memref +} + diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/ScatterND.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/ScatterND.mlir index c7e5c609f0..08139256ab 100644 --- a/test/mlir/conversion/onnx_to_krnl/Tensor/ScatterND.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Tensor/ScatterND.mlir @@ -24,3 +24,32 @@ func.func @test_scatter_nd1(%arg0: tensor<4x4x4xf32>, %arg1: tensor<2x1xi64>, %a // CHECK: return [[RES]] : 
memref<4x4x4xf32> } +// ----- + +// COM: Test ScatterND with dynamic shape +func.func @test_scatter_nd_with_dynamic_indices(%arg0: tensor<2x1xi64>, %arg1: tensor, %arg2: tensor<2xi64>) -> tensor<2x1xi64> { + %0 = "onnx.ScatterND"(%arg0, %arg1, %arg2) {reduction = "none"} : (tensor<2x1xi64>, tensor, tensor<2xi64>) -> tensor<2x1xi64> + return %0 : tensor<2x1xi64> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_scatter_nd_with_dynamic_indices +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x1xi64>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref<2xi64>) -> memref<2x1xi64> { +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x1xi64> +// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : i64 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK: "krnl.memcpy"([[RES_]], [[PARAM_0_]], [[CST_2_1_]], [[CST_0_]], [[CST_0_]]) : (memref<2x1xi64>, memref<2x1xi64>, i64, index, index) -> () +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_2_2_:%.+]] = arith.constant 2 : index +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 2){ +// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index +// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_1_]], [[CST_0_2_]]{{.}} : memref +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_]] : i64 to index +// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_1_]]{{.}} : memref<2xi64> +// CHECK: krnl.store [[LOAD_PARAM_2_MEM_]], [[RES_]]{{.}}[[VAR_3_]], [[VAR_1_]]{{.}} : memref<2x1xi64> +// CHECK: } +// CHECK: return [[RES_]] : memref<2x1xi64> +// CHECK: } +}