[mlir][tensor] Make getMixedPadImpl return static values when possible. (llvm#85016)

If low and high are constants (i.e., SSA values produced by constant ops, not
attributes), users still prefer attributes; keeping them as values can cause
failures in type inference. One such failure was introduced by
llvm@60e562d;
see the drop_known_unit_constant_low_high test for more details.
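
Conceptually, the fix folds a pad amount that arrives as a constant-defined SSA value back into an attribute before type inference sees it. A minimal sketch of that folding, assuming MLIR's standard matcher helpers (the helper name foldIfConstant is hypothetical; the patch itself reuses the existing getAsOpFoldResult):

#include "llvm/ADT/APInt.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"

using namespace mlir;

// Hypothetical helper mirroring the intended folding: prefer an Attribute
// when the Value is a known integer constant, so shape/type inference sees
// a static pad amount; otherwise keep the SSA value (a dynamic amount).
static OpFoldResult foldIfConstant(Value value, Builder &builder) {
  llvm::APInt constValue;
  if (matchPattern(value, m_ConstantInt(&constValue)))
    return builder.getI64IntegerAttr(constValue.getSExtValue());
  return value;
}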
hanhanW authored Mar 13, 2024
1 parent 9a3000c commit bb82092
Showing 4 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1364,7 +1364,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
       unsigned count = staticAttrs.size();
       for (unsigned idx = 0; idx < count; ++idx) {
         if (ShapedType::isDynamic(staticAttrs[idx]))
-          res.push_back(values[numDynamic++]);
+          res.push_back(getAsOpFoldResult(values[numDynamic++]));
         else
           res.push_back(builder.getI64IntegerAttr(staticAttrs[idx]));
       }
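
With this change, accessors built on getMixedPadImpl (such as tensor::PadOp::getMixedLowPad) return an IntegerAttr for a pad amount defined by, e.g., arith.constant 1 : index, instead of the SSA value itself. An illustrative check of that behavior (not part of the patch):

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/raw_ostream.h"

// Illustrative only: report whether each low-pad amount is static or dynamic.
// After the fix, constant-defined pad amounts show up here as attributes.
static void printLowPadKinds(mlir::tensor::PadOp padOp) {
  for (mlir::OpFoldResult ofr : padOp.getMixedLowPad()) {
    if (auto attr = llvm::dyn_cast_if_present<mlir::Attribute>(ofr))
      llvm::outs() << "static pad: "
                   << llvm::cast<mlir::IntegerAttr>(attr).getInt() << "\n";
    else
      llvm::outs() << "dynamic pad (SSA value)\n";
  }
}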
@@ -22,7 +22,6 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK-LABEL: func @generalize_pad_tensor_dynamic_shape(
 // CHECK-SAME: %[[IN:.*]]: tensor<4x?x2x?xf32>,
 // CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
@@ -33,7 +32,7 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index
 // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : tensor<4x?x?x?xf32>
 // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32>
-// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
+// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 0, %[[OFFSET]], 0] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
 // CHECK: return %[[PADDED]] : tensor<4x?x?x?xf32>
 // CHECK: }
 func.func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1033,3 +1033,23 @@ func.func @do_not_drop_non_constant_padding(%arg0: tensor<1x1x3x1x1xf32>, %pad:
 // CHECK-SLICES-LABEL: func @do_not_drop_non_constant_padding
 // CHECK-SLICES: tensor.pad %{{.*}} low[0, 1, 0, %c0, 0] high[0, 0, 0, %c0, 2]
 // CHECK-SLICES: } : tensor<1x1x3x1x1xf32> to tensor<1x2x3x1x3xf32>
+
+// -----
+
+func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> tensor<1x384x128xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[%c0, %c1, %c0] high[%c0, %c0, %c0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    tensor.yield %cst : f32
+  } : tensor<1x383x128xf32> to tensor<1x384x128xf32>
+  return %padded : tensor<1x384x128xf32>
+}
+// CHECK-LABEL: func @drop_known_unit_constant_low_high
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
+// CHECK-SAME: {{\[}}[0, 1], [2]] : tensor<1x383x128xf32> into tensor<383x128xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0] high[0, 0]
+// CHECK: } : tensor<383x128xf32> to tensor<384x128xf32>
+// CHECK: tensor.expand_shape %[[PADDED]]
+// CHECK-SAME: {{\[}}[0, 1], [2]] : tensor<384x128xf32> into tensor<1x384x128xf32>
3 changes: 1 addition & 2 deletions mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
@@ -19,7 +19,6 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK-LABEL: func @generalize_pad_tensor_dynamic_shape(
 // CHECK-SAME: %[[IN:.*]]: tensor<4x?x2x?xf32>,
 // CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -32,7 +31,7 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32>
 // CHECK: %[[DIM1_1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
 // CHECK: %[[DIM3_1:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
-// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1_1]], 2, %[[DIM3_1]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
+// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 0, %[[OFFSET]], 0] [4, %[[DIM1_1]], 2, %[[DIM3_1]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
 // CHECK: return %[[PADDED]] : tensor<4x?x?x?xf32>
 // CHECK: }
 func.func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
