[Winograd] Allow for specifying different input tile dimensions
Max191 authored and hanhanW committed May 21, 2024
1 parent cc1449b commit 575b4cb
Showing 5 changed files with 174 additions and 54 deletions.
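In brief, each of the three Winograd transform ops gains an input_tile_dimensions attribute recording where the inputTileSize x inputTileSize tile dimensions sit in the transformed tensor, and the conv-to-Winograd rewrite now places them innermost rather than outermost. A hedged sketch of how a caller requests the new innermost layout, mirroring the pass changes further down (surrounding names and values are illustrative, not from the diff):

// Tile dims as the two innermost dims ({4, 5}) of the rank-6 result of an
// NHWC input transform; omitting the attribute keeps the old {0, 1} layout.
const std::array<int64_t, 2> imageDims = {1, 2};
const std::array<int64_t, 2> inputTileDims = {4, 5};
Value transformed =
    rewriter
        .create<IREE::LinalgExt::WinogradInputTransformOp>(
            loc, inputTfInit.getType(), ValueRange{input},
            ValueRange{inputTfInit}, outputTileSize, kernelSize, imageDims,
            inputTileDims)
        .getResults()[0];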
53 changes: 49 additions & 4 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -5,12 +5,16 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include <cstdint>

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
@@ -21,6 +25,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -946,10 +951,33 @@ LogicalResult UnPackOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// Winograd op utilities
//===----------------------------------------------------------------------===//

template <typename WinogradOp>
static SmallVector<int64_t> getNonInputTileDims(WinogradOp op) {
static_assert(llvm::is_one_of<WinogradOp, WinogradInputTransformOp,
WinogradFilterTransformOp,
WinogradOutputTransformOp>::value,
"applies to only winograd transform operations");
SetVector<int64_t> inputTileDims(op.getInputTileDimensions().begin(),
op.getInputTileDimensions().end());
SmallVector<int64_t> dims = llvm::to_vector(
llvm::seq<int64_t>(op.getTransformedOperandType().getRank()));
SetVector<int64_t> dimSet(dims.begin(), dims.end());
dimSet.set_subtract(inputTileDims);
return dimSet.takeVector();
}
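A quick standalone illustration (not part of the change) of what the helper returns: the complement of input_tile_dimensions within [0, rank), in ascending order.

// Sketch on plain values for a rank-6 transformed operand with input tile
// dims {4, 5}.
SmallVector<int64_t> dims = llvm::to_vector(llvm::seq<int64_t>(6));
SetVector<int64_t> dimSet(dims.begin(), dims.end()); // {0, 1, 2, 3, 4, 5}
SetVector<int64_t> tileDims;
tileDims.insert(4);
tileDims.insert(5);
dimSet.set_subtract(tileDims);
SmallVector<int64_t> nonTile = dimSet.takeVector(); // {0, 1, 2, 3}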

//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//

SmallVector<int64_t> WinogradInputTransformOp::getNonInputTileDims() {
return LinalgExt::getNonInputTileDims(*this);
}

LogicalResult WinogradInputTransformOp::verify() {
Operation *op = getOperation();
if (getNumDpsInputs() != 1) {
@@ -1026,7 +1054,10 @@ LogicalResult WinogradInputTransformOp::verify() {
if (isNchw()) {
permute<Permutation::TTNCHW_TO_TTNHWC>(expectedOutputShape);
}
ArrayRef<int64_t> outputShape = outputType.getShape();
SmallVector<int64_t> outputShape(outputType.getShape());
SmallVector<int64_t> perm(getInputTileDimensions());
perm.append(getNonInputTileDims());
applyPermutationToVector(outputShape, perm);
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
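A worked example of the new check (editorial, values illustrative): for an NHWC input of shape 1x10x10x4 with output_tile_size 6 and kernel_size 3, inputTileSize is 8 and each image dim yields ceil((10 - 3 + 1) / 6) = 2 tiles. With input_tile_dimensions = [4, 5] the result is laid out (N, H', W', C, T, T) = (1, 2, 2, 4, 8, 8), which the permutation rotates back to the tile-first order used by expectedOutputShape:

SmallVector<int64_t> shape = {1, 2, 2, 4, 8, 8};
SmallVector<int64_t> perm = {4, 5, 0, 1, 2, 3}; // tile dims, then the rest
applyPermutationToVector(shape, perm);          // shape == {8, 8, 1, 2, 2, 4}
// Compared against expectedOutputShape = {8, 8, 1, 2, 2, 4}: compatible.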
@@ -1048,6 +1079,10 @@ LogicalResult WinogradInputTransformOp::reifyResultShapes(
// WinogradFilterTransformOp
//===----------------------------------------------------------------------===//

SmallVector<int64_t> WinogradFilterTransformOp::getNonInputTileDims() {
return LinalgExt::getNonInputTileDims(*this);
}

LogicalResult WinogradFilterTransformOp::verify() {
Operation *op = getOperation();
if (getNumDpsInputs() != 1) {
@@ -1119,7 +1154,10 @@ LogicalResult WinogradFilterTransformOp::verify() {
if (isFchw()) {
permute<Permutation::TTFC_TO_TTCF>(expectedOutputShape);
}
ArrayRef<int64_t> outputShape = outputType.getShape();
SmallVector<int64_t> outputShape(outputType.getShape());
SmallVector<int64_t> perm(getInputTileDimensions());
perm.append(getNonInputTileDims());
applyPermutationToVector(outputShape, perm);
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
@@ -1141,6 +1179,10 @@ LogicalResult WinogradFilterTransformOp::reifyResultShapes(
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//

SmallVector<int64_t> WinogradOutputTransformOp::getNonInputTileDims() {
return LinalgExt::getNonInputTileDims(*this);
}

LogicalResult WinogradOutputTransformOp::verify() {
Operation *op = getOperation();
if (getNumDpsInputs() != 1) {
@@ -1177,7 +1219,6 @@ LogicalResult WinogradOutputTransformOp::verify() {
}
return success();
}
ArrayRef<int64_t> outputShape = outputType.getShape();
if (outputType.getElementType() != inputType.getElementType()) {
return op->emitOpError(
"expected input/output element types to be identical");
@@ -1197,6 +1238,9 @@
"expect image dimensions to be either [1, 2] or [2, 3]");
}
SmallVector<int64_t> inputShape(inputType.getShape());
SmallVector<int64_t> perm(getInputTileDimensions());
perm.append(getNonInputTileDims());
applyPermutationToVector(inputShape, perm);
if (isNchw()) {
permute<Permutation::TTNHWC_TO_TTNCHW>(inputShape);
}
@@ -1214,7 +1258,8 @@
expectedOutputShape[outputIndex] = getOutputTileSize() * inputShape[i];
}
}
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
if (failed(
verifyCompatibleShape(expectedOutputShape, outputType.getShape()))) {
return op->emitOpError("incompatible output shape");
}
return success();
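Note the symmetry (an editorial aside): the output transform consumes the Winograd-domain tensor, so the permutation is applied to the input shape here, rotating it back to tile-first order before the NCHW handling and expected-shape computation, while the op's image-domain result shape is compared untouched.

// E.g. an input laid out (N, H', W', F, T, T) with input_tile_dimensions
// = [4, 5] is rotated to (T, T, N, H', W', F) before the checks above.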
27 changes: 21 additions & 6 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -986,13 +986,15 @@ def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_tr
Variadic<AnyShaped>:$outputs,
I64Attr:$output_tile_size,
I64Attr:$kernel_size,
DenseI64ArrayAttr:$image_dimensions
DenseI64ArrayAttr:$image_dimensions,
DenseI64ArrayAttr:$input_tile_dimensions
);

let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)>
CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions,
CArg<"ArrayRef<int64_t>", "{0, 1}">:$input_tile_dimensions)>
];

let results = (outs Variadic<AnyRankedTensor>:$result);
@@ -1002,6 +1004,7 @@ def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_tr
`output_tile_size` `(` $output_tile_size `)`
`kernel_size` `(` $kernel_size `)`
`image_dimensions` `(` $image_dimensions `)`
`input_tile_dimensions` `(` $input_tile_dimensions `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
@@ -1061,6 +1064,8 @@ def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_tr
int getChannelDim() {
return isNhwc() ? 3 : 1;
}
// Utility for mapping non-input-tile dims to the actual result dims
SmallVector<int64_t> getNonInputTileDims();
int64_t getIterationDomainRank() {
return getOutputRank();
}
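Worth noting (a hedged compatibility sketch, not from the diff): the builder's CArg default of {0, 1} means pre-existing call sites that omit the attribute keep the old tile-leading layout.

// Both calls are equivalent under the declared defaults; the tile dims
// stay outermost as before this change.
builder.create<IREE::LinalgExt::WinogradInputTransformOp>(
    loc, ValueRange{input}, ValueRange{init});
builder.create<IREE::LinalgExt::WinogradInputTransformOp>(
    loc, ValueRange{input}, ValueRange{init}, /*output_tile_size=*/8,
    /*kernel_size=*/3, /*image_dimensions=*/ArrayRef<int64_t>{1, 2},
    /*input_tile_dimensions=*/ArrayRef<int64_t>{0, 1});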
@@ -1098,13 +1103,15 @@ def IREELinalgExt_WinogradFilterTransformOp : IREELinalgExt_Op<"winograd.filter_
Variadic<AnyShaped>:$outputs,
I64Attr:$output_tile_size,
I64Attr:$kernel_size,
DenseI64ArrayAttr:$kernel_dimensions
DenseI64ArrayAttr:$kernel_dimensions,
DenseI64ArrayAttr:$input_tile_dimensions
);

let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
CArg<"ArrayRef<int64_t>", "{0, 1}">:$kernel_dimensions)>
CArg<"ArrayRef<int64_t>", "{0, 1}">:$kernel_dimensions,
CArg<"ArrayRef<int64_t>", "{0, 1}">:$input_tile_dimensions)>
];

let results = (outs Variadic<AnyRankedTensor>:$result);
@@ -1114,6 +1121,7 @@ def IREELinalgExt_WinogradFilterTransformOp : IREELinalgExt_Op<"winograd.filter_
`output_tile_size` `(` $output_tile_size `)`
`kernel_size` `(` $kernel_size `)`
`kernel_dimensions` `(` $kernel_dimensions `)`
`input_tile_dimensions` `(` $input_tile_dimensions `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
@@ -1178,6 +1186,8 @@ def IREELinalgExt_WinogradFilterTransformOp : IREELinalgExt_Op<"winograd.filter_
int getFilterDim() {
return isHwcf() ? 3 : 0;
}
// Utility for mapping non-input-tile dims to the actual result dims
SmallVector<int64_t> getNonInputTileDims();
int64_t getIterationDomainRank() {
return getOutputRank();
}
@@ -1218,13 +1228,15 @@ def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_
Variadic<AnyShaped>:$outputs,
I64Attr:$output_tile_size,
I64Attr:$kernel_size,
DenseI64ArrayAttr:$image_dimensions
DenseI64ArrayAttr:$image_dimensions,
DenseI64ArrayAttr:$input_tile_dimensions
);

let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)>
CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions,
CArg<"ArrayRef<int64_t>", "{0, 1}">:$input_tile_dimensions)>
];

let results = (outs Variadic<AnyRankedTensor>:$result);
@@ -1234,6 +1246,7 @@ def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_
`output_tile_size` `(` $output_tile_size `)`
`kernel_size` `(` $kernel_size `)`
`image_dimensions` `(` $image_dimensions `)`
`input_tile_dimensions` `(` $input_tile_dimensions `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
@@ -1290,6 +1303,8 @@ def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_
int64_t getOutputRank() {
return getOutputType().getRank();
}
// Utility for mapping non-input-tile dims to the actual result dims
SmallVector<int64_t> getNonInputTileDims();
int64_t getIterationDomainRank() {
return getInputRank();
}
@@ -84,6 +84,20 @@ createExpand(Value tensor, Location loc, PatternRewriter &rewriter,
reassociations);
}

static Value createTranspose(OpBuilder &builder, Value source,
SmallVector<int64_t> perm) {
SmallVector<OpFoldResult> mixedSizes =
tensor::getMixedSizes(builder, source.getLoc(), source);
applyPermutationToVector(mixedSizes, perm);
Type elemType = cast<RankedTensorType>(source.getType()).getElementType();
Value empty =
builder.create<tensor::EmptyOp>(source.getLoc(), mixedSizes, elemType)
.getResult();
return builder
.create<linalg::TransposeOp>(source.getLoc(), source, empty, perm)
->getResult(0);
}
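The new helper pairs a tensor.empty carrying the permuted sizes with a linalg.transpose into it; a hedged usage sketch (names illustrative):

// Move the innermost (tile) dim of a rank-3 tensor to the front, e.g.
// tensor<?x?x64xf32> -> tensor<64x?x?xf32>.
SmallVector<int64_t> perm = {2, 0, 1};
Value transposed = createTranspose(rewriter, collapsedWinogradInput, perm);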

namespace {

/// Convert conv2d to a sequence of ops that implement the
@@ -196,11 +210,12 @@ class ConvertConvToWinograd final : public OpRewritePattern<ConvOp> {
const int64_t inputTileSize = outputTileSize + kernelSize - 1;

Location loc = convOp.getLoc();
const std::array<int64_t, 2> inputTileDimsKernel = {2, 3};
const std::array<int64_t, 2> hwcfKernelDims = {0, 1};
const std::array<int64_t, 2> fchwKernelDims = {2, 3};
SmallVector<int64_t> filterResultShape(4, inputTileSize);
filterResultShape[2] = isNchwFchw ? kernelShape[1] : kernelShape[2];
filterResultShape[3] = isNchwFchw ? kernelShape[0] : kernelShape[3];
filterResultShape[0] = isNchwFchw ? kernelShape[1] : kernelShape[2];
filterResultShape[1] = isNchwFchw ? kernelShape[0] : kernelShape[3];
Value kernelInit =
rewriter.create<tensor::EmptyOp>(loc, filterResultShape, inElemType);
const std::array<int64_t, 2> kernelDims =
@@ -209,15 +224,16 @@ class ConvertConvToWinograd final : public OpRewritePattern<ConvOp> {
rewriter
.create<IREE::LinalgExt::WinogradFilterTransformOp>(
loc, kernelInit.getType(), ValueRange{kernel},
ValueRange{kernelInit}, outputTileSize, kernelSize, kernelDims)
ValueRange{kernelInit}, outputTileSize, kernelSize, kernelDims,
inputTileDimsKernel)
.getResults()[0];

// Add collapse shape
SmallVector<int64_t> collapsedFilterShape;
collapsedFilterShape.push_back(filterResultShape[0] * filterResultShape[1]);
collapsedFilterShape.push_back(filterResultShape[2]);
collapsedFilterShape.push_back(filterResultShape[3]);
SmallVector<ReassociationIndices> filterReassociations = {{0, 1}, {2}, {3}};
collapsedFilterShape.push_back(filterResultShape[0]);
collapsedFilterShape.push_back(filterResultShape[1]);
collapsedFilterShape.push_back(filterResultShape[2] * filterResultShape[3]);
SmallVector<ReassociationIndices> filterReassociations = {{0}, {1}, {2, 3}};
Value collapsedWinogradFilter =
createCollapse(winogradFilter, loc, rewriter, collapsedFilterShape,
filterReassociations);
@@ -237,19 +253,17 @@ class ConvertConvToWinograd final : public OpRewritePattern<ConvOp> {
permute<IREE::LinalgExt::Permutation::NCHW_TO_NHWC>(inputShape);
}

const std::array<int64_t, 2> inputTileDimsImage = {4, 5};
const std::array<int64_t, 2> nhwcImageDims = {1, 2};
const std::array<int64_t, 2> nchwImageDims = {2, 3};
const size_t numImageDims = nhwcImageDims.size();
SmallVector<int64_t> resultShape(6, inputTileSize);
llvm::SmallSetVector<int64_t, 2> imageDimsSet(nhwcImageDims.begin(),
nhwcImageDims.end());
int outputIndex;
for (int i = 0; i < inputShape.size(); i++) {
outputIndex = i + numImageDims;
if (!imageDimsSet.contains(i)) {
resultShape[outputIndex] = inputShape[i];
resultShape[i] = inputShape[i];
} else {
resultShape[outputIndex] =
resultShape[i] =
std::ceil((float)(inputShape[i] - kernelSize + 1) / outputTileSize);
}
}
@@ -261,19 +275,27 @@ class ConvertConvToWinograd final : public OpRewritePattern<ConvOp> {
rewriter
.create<IREE::LinalgExt::WinogradInputTransformOp>(
loc, inputTfInit.getType(), ValueRange{input},
ValueRange{inputTfInit}, outputTileSize, kernelSize, imageDims)
ValueRange{inputTfInit}, outputTileSize, kernelSize, imageDims,
inputTileDimsImage)
.getResults()[0];

// Add collapse shape
SmallVector<int64_t> collapsedShape = {
resultShape[0] * resultShape[1],
resultShape[2] * resultShape[3] * resultShape[4], resultShape[5]};
SmallVector<ReassociationIndices> reassociations = {{0, 1}, {2, 3, 4}, {5}};
resultShape[0] * resultShape[1] * resultShape[2], resultShape[3],
resultShape[4] * resultShape[5]};
SmallVector<ReassociationIndices> reassociations = {{0, 1, 2}, {3}, {4, 5}};
Value collapsedWinogradInput = createCollapse(
winogradInput, loc, rewriter, collapsedShape, reassociations);
SmallVector<int64_t> perm = {2, 0, 1};
Value permutedWinogradInput =
createTranspose(rewriter, collapsedWinogradInput, perm);
Value permutedWinogradFilter =
createTranspose(rewriter, collapsedWinogradFilter, perm);

// Add BatchMatmulOp
SmallVector<int64_t> bmmShape(collapsedShape.begin(), collapsedShape.end());
SmallVector<int64_t> bmmShape = {
resultShape[4] * resultShape[5],
resultShape[0] * resultShape[1] * resultShape[2], resultShape[3]};
SmallVector<int64_t> outputShape(outputType.getShape());
if (isNchwFchw) {
permute<IREE::LinalgExt::Permutation::NCHW_TO_NHWC>(outputShape);
@@ -288,25 +310,27 @@ class ConvertConvToWinograd final : public OpRewritePattern<ConvOp> {
ValueRange{bmmInit});
auto bmmOp = rewriter.create<linalg::BatchMatmulOp>(
loc, bmmOutputType,
ValueRange({collapsedWinogradInput, collapsedWinogradFilter}),
ValueRange({permutedWinogradInput, permutedWinogradFilter}),
ValueRange({fillOp.result()}));
Value bmmResult = bmmOp.getResult(0);
SmallVector<int64_t> resultPerm = {1, 2, 0};
Value permutedBmmResult = createTranspose(rewriter, bmmResult, resultPerm);

// Add expand shape
SmallVector<int64_t> expandedShape = {resultShape[0], resultShape[1],
resultShape[2], resultShape[3],
resultShape[4], outputShape[3]};
reassociations = {{0, 1}, {2, 3, 4}, {5}};
Value expandedBmmResult =
createExpand(bmmResult, loc, rewriter, expandedShape, reassociations);
resultShape[2], outputShape[3],
resultShape[4], resultShape[5]};
reassociations = {{0, 1, 2}, {3}, {4, 5}};
Value expandedBmmResult = createExpand(permutedBmmResult, loc, rewriter,
expandedShape, reassociations);
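// Editorial walkthrough (not part of the diff): with the tile dims now
// innermost, the shapes through this sequence line up as follows for an
// NHWC conv with N x H' x W' tiles, C input channels, F filters, and
// T = inputTileSize:
//   input:  (N, H', W', C, T, T) -collapse-> (N*H'*W', C, T*T)
//           -transpose {2,0,1}-> (T*T, N*H'*W', C)
//   filter: (C, F, T, T) -collapse-> (C, F, T*T)
//           -transpose {2,0,1}-> (T*T, C, F)
//   bmm:    (T*T, N*H'*W', C) x (T*T, C, F) -> (T*T, N*H'*W', F)
//   result: -transpose {1,2,0}-> (N*H'*W', F, T*T)
//           -expand-> (N, H', W', F, T, T)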

// Convert back into original domain
SmallVector<int64_t> paddedResultShape(outputShape.size(), 0);
for (int i = 0; i < outputShape.size(); i++) {
if (!imageDimsSet.contains(i)) {
paddedResultShape[i] = outputShape[i];
} else {
paddedResultShape[i] = resultShape[i + numImageDims] * outputTileSize;
paddedResultShape[i] = resultShape[i] * outputTileSize;
}
}
if (isNchwFchw) {
@@ -318,7 +342,8 @@ class ConvertConvToWinograd final : public OpRewritePattern<ConvOp> {
rewriter
.create<IREE::LinalgExt::WinogradOutputTransformOp>(
loc, outputTfInit.getType(), ValueRange{expandedBmmResult},
ValueRange{outputTfInit}, outputTileSize, kernelSize, imageDims)
ValueRange{outputTfInit}, outputTileSize, kernelSize, imageDims,
inputTileDimsImage)
.getResults()[0];

// Extract slice
(Diffs for the remaining two changed files are not included here.)
