Skip to content

Commit

Permalink
[mlir][tensor] Enhance SimplifyUnPackToCollapseShape for unit dim cases.
Browse files Browse the repository at this point in the history
…es. (llvm#79262)
  • Loading branch information
hanhanW authored Jan 25, 2024
1 parent ca0e241 commit ad3cda7
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 28 deletions.
69 changes: 41 additions & 28 deletions mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Returns success() if there is only 1 dimension size in non-packed domain
/// being greater than 1 and packing only happens on the dimension.
/// Note: this method should only be used by pack/unpack to reshape conversion.
/// It assumes that non-unit inner tile size must be used by the non-unit
/// dimension.
/// Returns success() if there is only 1 dimension size in non-packed domain
/// being greater than 1 and packing only happens on the dimension.
/// Note: this method should only be used by pack/unpack to reshape conversion.
/// It assumes that non-unit inner tile size must be used by the non-unit
/// dimension.
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  // The non-packed domain may hold at most a single non-unit dimension.
  if (getNumGtOneDims(srcShape) > 1)
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dims");

  // Likewise at most one inner tile may be non-unit; otherwise computing the
  // reassociation maps for the reshape would fail.
  if (getNumGtOneDims(innerPackTileSize) > 1)
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tiles");

  return success();
}

/// Packing one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
Expand Down Expand Up @@ -59,40 +80,18 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
return success();
}

/// Returns success() if there is only 1 dimension size in source being
/// greater than 1 and packing only happens on the dimension. It assumes that
/// the pack op does not have padding value.
LogicalResult isPack1DSrc(RewriterBase &rewriter, PackOp packOp) const {
assert(!packOp.getPaddingValue() &&
"expect the op does not have padding value.");
ArrayRef<int64_t> srcShape = packOp.getSourceType().getShape();
if (getNumGtOneDims(srcShape) > 1) {
return rewriter.notifyMatchFailure(
packOp, "expects source to have at most one non-unit dims");
}

// The pack op does not have padding value. Non-unit inner tile size must be
// be used by the non-unit dimension.
SmallVector<int64_t> innerTiles = packOp.getStaticTiles();
if (getNumGtOneDims(innerTiles) > 1) {
return rewriter.notifyMatchFailure(
packOp, "expects at most one non-unit inner tiles");
}

return success();
}

LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp, "expects no padding value");

RankedTensorType sourceType = packOp.getSourceType();
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPack1DSrc(rewriter, packOp))) {
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
packOp.getStaticTiles()))) {
return failure();
}

RankedTensorType sourceType = packOp.getSourceType();
RankedTensorType destType = packOp.getDestType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
Expand All @@ -117,8 +116,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
operand, reassociation);
}

LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
/// Returns success() if it is unpacking on the innermost dimension.
LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
UnPackOp unpackOp) const {
auto outerDimsPerm = unpackOp.getOuterDimsPerm();
if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
return rewriter.notifyMatchFailure(
Expand All @@ -134,9 +134,22 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
return rewriter.notifyMatchFailure(
unpackOp, "expects unpacking at the innermost dimension");
unpackOp, "expects unpacking on the innermost dimension");
}

return success();
}

LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
RankedTensorType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
unpackOp.getStaticTiles()))) {
return failure();
}

RankedTensorType sourceType = unpackOp.getSourceType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,54 @@ func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor
%0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
return %0 : tensor<256x5xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_1x32x1x1_to_1x32
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK: return %[[COLLAPSED]]
func.func @unpack_1x32x1x1_to_1x32(%arg0 : tensor<1x32x1x1xf32>) -> tensor<1x32xf32> {
  // Both inner tiles are unit-sized, so this unpack reduces to a reshape.
  %dest = tensor.empty() : tensor<1x32xf32>
  %result = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %dest
      : tensor<1x32x1x1xf32> -> tensor<1x32xf32>
  return %result : tensor<1x32xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_1x2x1x16_to_1x32
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK: return %[[COLLAPSED]]
func.func @unpack_1x2x1x16_to_1x32(%arg0 : tensor<1x2x1x16xf32>) -> tensor<1x32xf32> {
  // Only one non-unit dimension and one non-unit inner tile are involved.
  %dest = tensor.empty() : tensor<1x32xf32>
  %result = tensor.unpack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [1, 16] into %dest
      : tensor<1x2x1x16xf32> -> tensor<1x32xf32>
  return %result : tensor<1x32xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_16x1x2x1_to_32x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK: return %[[COLLAPSED]]
func.func @unpack_16x1x2x1_to_32x1(%arg0 : tensor<1x16x2x1xf32>) -> tensor<32x1xf32> {
  // Source is 1x16x2x1 with outer dims permuted; a single non-unit inner tile.
  %dest = tensor.empty() : tensor<32x1xf32>
  %result = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 1] into %dest
      : tensor<1x16x2x1xf32> -> tensor<32x1xf32>
  return %result : tensor<32x1xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_16x1x1x2_to_32x1
// CHECK-NOT: tensor.collapse_shape
// CHECK: tensor.unpack
func.func @unpack_16x1x1x2_to_32x1(%arg0 : tensor<16x1x1x2xf32>) -> tensor<32x1xf32> {
  // Negative case (see CHECK-NOT above): the op must remain a tensor.unpack.
  %dest = tensor.empty() : tensor<32x1xf32>
  %result = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [1, 2] into %dest
      : tensor<16x1x1x2xf32> -> tensor<32x1xf32>
  return %result : tensor<32x1xf32>
}

0 comments on commit ad3cda7

Please sign in to comment.