Pad Attention
hanhanW committed May 22, 2024
1 parent 4d03c70 · commit bd3b916
Showing 1 changed file with 58 additions and 0 deletions.
@@ -8,6 +8,8 @@
#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -512,6 +514,54 @@ static void padContractionLikeOp(RewriterBase &rewriter,
offsets, sizes, strides);
}

// Query and output have shape (B x N x d); key and value have shape
// (B x L x d).
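// For illustration (hypothetical shapes): query tensor<2x4200x64xf32> with
// key/value tensor<2x1024x64xf32>; N = 4200 gets padded below.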
static void padAttentionOp(RewriterBase &rewriter,
IREE::LinalgExt::AttentionOp attentionOp) {
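// Pad the sequence length (N) to the next multiple of 128 so the padded op
// lines up with the intrinsic-friendly sizes this pass targets.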
SmallVector<int64_t> attentionPadSizes = {1, 128};
Location loc = attentionOp.getLoc();
Value key = attentionOp.getKey();
Value value = attentionOp.getValue();
Value query = attentionOp.getQuery();
Value scale = attentionOp.getScale();
Value result = attentionOp.getOutput();
ArrayRef<int64_t> queryShape = attentionOp.getQueryType().getShape();
if (queryShape[1] % attentionPadSizes[1] == 0)
return;
auto getPadding = [](int64_t value, int64_t padTo) {
return llvm::divideCeil(value, padTo) * padTo - value;
};
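// e.g., getPadding(4200, 128) = divideCeil(4200, 128) * 128 - 4200
//                             = 4224 - 4200 = 24.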
auto nPadding =
rewriter.getIndexAttr(getPadding(queryShape[1], attentionPadSizes[1]));
OpFoldResult zero = rewriter.getIndexAttr(0);
if (!isConstantIntValue(nPadding, 0)) {
// Pad only the sequence-length (N) dimension of the query; the batch and
// head dimensions stay intact.
query = getPaddedValue(rewriter, loc, query, {zero, nPadding, zero});
if (result.getDefiningOp<tensor::EmptyOp>()) {
auto paddedType = llvm::cast<ShapedType>(query.getType());
result = rewriter.create<tensor::EmptyOp>(loc, paddedType.getShape(),
paddedType.getElementType());
} else {
result = getPaddedValue(rewriter, loc, result, {zero, nPadding, zero});
}
}
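// Re-create the attention op on the (possibly padded) operands; the clone
// keeps the original attributes and produces the padded result type.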
auto paddedAttentionOp =
mlir::clone(rewriter, attentionOp, {result.getType()},
ArrayRef<Value>{query, key, value, scale, result});
// Extract slice.
IntegerAttr one = rewriter.getI64IntegerAttr(1);
SmallVector<OpFoldResult> offsets(3, zero);
SmallVector<OpFoldResult> strides(3, one);
ArrayRef<int64_t> resultShape = attentionOp.getOutputType().getShape();
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(resultShape[0]),
rewriter.getIndexAttr(resultShape[1]),
rewriter.getIndexAttr(resultShape[2])};
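// Slice the padded result back to the original output shape, e.g. (with the
// hypothetical shapes above) tensor<2x4224x64xf32> to tensor<2x4200x64xf32>.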
Value extracted = rewriter.createOrFold<tensor::ExtractSliceOp>(
loc, paddedAttentionOp->getResults()[0], offsets, sizes, strides);
rewriter.replaceOp(attentionOp, extracted);
}

struct PadToIntrinsicsPass
: public impl::PadToIntrinsicsBase<PadToIntrinsicsPass> {
using Base::Base;
@@ -546,6 +596,10 @@ void PadToIntrinsicsPass::runOnOperation() {
targetContractOps.push_back(linalgOp);
}
});
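// Collect attention ops into their own worklist, mirroring the contract-op
// collection above, so the rewrites below do not invalidate the walk.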
SmallVector<IREE::LinalgExt::AttentionOp> targetAttentionOps;
funcOp.walk([&](IREE::LinalgExt::AttentionOp attentionOp) {
targetAttentionOps.push_back(attentionOp);
});

// Iterate through and pad ops in the worklists.
IRRewriter rewriter(context);
@@ -557,6 +611,10 @@ void PadToIntrinsicsPass::runOnOperation() {
rewriter.setInsertionPoint(contractOp);
padContractionLikeOp(rewriter, contractOp);
}
for (auto attentionOp : targetAttentionOps) {
rewriter.setInsertionPoint(attentionOp);
padAttentionOp(rewriter, attentionOp);
}
}

} // namespace mlir::iree_compiler::Preprocessing
