Skip to content

Commit

Permalink
[mlir][MemRef] Move narrow type emulation common methods to MemRefUtils.
Browse files Browse the repository at this point in the history
It also unifies the computation of StridedLayoutAttr. If the stride is
a statically known value, we can use it directly.

Differential Revision: https://reviews.llvm.org/D155017
  • Loading branch information
hanhanW committed Jul 13, 2023
1 parent a48f32d commit 8fc433f
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 215 deletions.
33 changes: 33 additions & 0 deletions mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#ifndef MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
#define MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H

#include "mlir/Dialect/MemRef/IR/MemRef.h"

namespace mlir {

class MemRefType;
Expand All @@ -26,6 +28,37 @@ namespace memref {
/// contiguous chunk of memory.
bool isStaticShapeAndContiguousRowMajor(MemRefType type);

/// Returns the flattened 1-D memref and linearized offset for narrow type
/// emulation.
///
/// The emulation only works on 1D memref types. To make this work on an N-D
/// memref, we need to linearize the offset.
///
/// For example, to emulate i4 to i8, the following op:
///
/// %0 = memref.load %arg0[%v0, %v1] :
/// memref<?x?xi4, strided<[?, ?], offset: ?>>
///
/// can be replaced with
///
/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
///
/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
/// %linearized_size = %size0 * %size1
/// %scaled_linear_offset = %linearized_offset / 8 * 4
/// %scaled_base_offset = %offset / 8 * 4
///
/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
/// sizes = [%linearized_size], strides = [%stride#1]
///
/// %new_load = memref.load %linearized[%scaled_linear_offset] :
/// memref<?xi8, strided<[?], offset: ?>>
std::pair<Value, Value>
getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
int dstBits, SmallVector<Value> indices,
memref::ExtractStridedMetadataOp stridedMetadata,
OpBuilder &builder);

} // namespace memref
} // namespace mlir

Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
Expand Down
106 changes: 8 additions & 98 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
Expand All @@ -27,102 +28,6 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//

/// Flattens an N-D strided memref load into a load from a 1-D memref for
/// narrow-type emulation, and returns the new (wider-element) load value.
///
/// The emulation only works on 1D memref types.
/// To make this work on N-D memref, we need to linearize the offset.
///
/// For example, to emulate i4 to i8, the following op:
///
/// %0 = memref.load %arg0[%v0, %v1] :
///                  memref<?x?xi4, strided<[?, ?], offset: ?>>
///
/// can be replaced with
///
/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
///
/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
/// %linearized_size = %size0 * %size1
/// %scaled_linear_offset = %linearized_offset / 8 * 4
/// %scaled_base_offset = %offset / 8 * 4
///
/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
///                      sizes = [%linearized_size], strides = [%stride#1]
///
/// %new_load = memref.load %linearized[%scaled_linear_offset] :
///                         memref<?xi8, strided<[?], offset: ?>>
static Value
linearizeMemrefLoad(Location loc, MemRefType sourceType, int srcBits,
                    int dstBits, SmallVector<Value> indices,
                    memref::ExtractStridedMetadataOp stridedMetadata,
                    OpBuilder &builder) {
  auto srcElementType = sourceType.getElementType();
  unsigned sourceRank = indices.size();

  // Runtime strides/sizes/offset of the source memref, taken from the
  // already-created extract_strided_metadata op.
  Value baseBuffer = stridedMetadata.getBaseBuffer();
  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
  Value baseOffset = stridedMetadata.getOffset();
  assert(indices.size() == baseStrides.size());

  // Create the affine symbols and values for linearization.
  // Symbol layout: s0 is a seed (bound to 0 for the offset expression and to
  // 1 for the size expression); for dim i, s(2i+1)/s(2i+2) carry the index
  // and stride in the offset expression; the last symbol is the scale factor.
  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
  symbols[0] = builder.getAffineSymbolExpr(0);
  AffineExpr addMulMap = symbols.front();
  AffineExpr mulMap = symbols.front();

  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
  offsetValues[0] = builder.getIndexAttr(0);
  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
  sizeValues[0] = builder.getIndexAttr(1);

  // Accumulate sum(index[i] * stride[i]) for the linearized offset and
  // prod(size[i]) for the linearized size. The size expression reuses the low
  // symbol numbers; each makeComposedFoldedAffineApply call below binds its
  // own operand list, so the two expressions do not interfere.
  for (unsigned i = 0; i < sourceRank; ++i) {
    unsigned offsetIdx = 2 * i + 1;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indices[i];
    offsetValues[offsetIdx + 1] = baseStrides[i];

    unsigned sizeIdx = i + 1;
    mulMap = mulMap * symbols[sizeIdx];
    sizeValues[sizeIdx] = baseSizes[i];
  }

  // Adjust linearizedOffset by the scale factor (dstBits / srcBits), i.e. the
  // number of narrow elements packed into one wide element. Note: this is a
  // floor-div of the element offset; srcBits is assumed to divide dstBits.
  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
  offsetValues.back() = scaler;

  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, scaledAddMulMap, offsetValues);
  OpFoldResult linearizedSize =
      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);

  // Adjust baseOffset by the scale factor (dstBits / srcBits).
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});

  // Flatten n-D MemRef to 1-D MemRef. Both the offset and the stride of the
  // resulting layout are left dynamic here.
  auto layoutAttr = StridedLayoutAttr::get(
      sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
  int64_t staticShape = sourceType.hasStaticShape()
                            ? sourceType.getNumElements()
                            : ShapedType::kDynamic;
  auto flattenMemrefType = MemRefType::get(
      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());

  // NOTE(review): the 1-D stride is taken from the innermost source stride
  // (baseStrides.back()) — presumably the innermost dimension is expected to
  // be contiguous in the emulated packing; confirm for non-unit inner strides.
  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
      loc, flattenMemrefType, baseBuffer,
      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
      baseStrides.back());

  // Load the wide element at the scaled linearized offset.
  return builder.create<memref::LoadOp>(
      loc, srcElementType, reinterpret.getResult(),
      getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
}

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
Expand Down Expand Up @@ -239,8 +144,13 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {

lastIdx = stridedMetadata.getOffset();
} else {
newLoad = linearizeMemrefLoad(loc, sourceType, srcBits, dstBits, indices,
stridedMetadata, rewriter);
auto [reinterpret, linearizedOffset] =
memref::getLinearizeMemRefAndOffset(loc, sourceType, srcBits, dstBits,
adaptor.getIndices(),
stridedMetadata, rewriter);

newLoad = rewriter.create<memref::LoadOp>(loc, srcElementType,
reinterpret, linearizedOffset);

lastIdx = adaptor.getIndices().back();
}
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ add_mlir_dialect_library(MLIRMemRefUtils

LINK_LIBS PUBLIC
MLIRIR
MLIRAffineDialect
MLIRArithUtils
)
77 changes: 77 additions & 0 deletions mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"

namespace mlir {
Expand Down Expand Up @@ -44,5 +46,80 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
return curDim < 0;
}

// Flattens an N-D strided memref into a 1-D memref for narrow-type
// emulation. Returns the flattened memref (a reinterpret_cast result) and
// the scaled, linearized access offset into it. See the declaration in
// MemRefUtils.h for a worked i4->i8 example.
std::pair<Value, Value>
getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
                            int dstBits, SmallVector<Value> indices,
                            memref::ExtractStridedMetadataOp stridedMetadata,
                            OpBuilder &builder) {
  auto srcElementType = sourceType.getElementType();
  unsigned sourceRank = indices.size();

  // Runtime strides/sizes/offset of the source memref, taken from the
  // already-created extract_strided_metadata op.
  Value baseBuffer = stridedMetadata.getBaseBuffer();
  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
  Value baseOffset = stridedMetadata.getOffset();
  assert(indices.size() == baseStrides.size());

  // Create the affine symbols and values for linearization.
  // Symbol layout: s0 is a seed (bound to 0 for the offset expression and to
  // 1 for the size expression); for dim i, s(2i+1)/s(2i+2) carry the index
  // and stride in the offset expression; the last symbol is the scale factor.
  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
  symbols[0] = builder.getAffineSymbolExpr(0);
  AffineExpr addMulMap = symbols.front();
  AffineExpr mulMap = symbols.front();

  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
  offsetValues[0] = builder.getIndexAttr(0);
  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
  sizeValues[0] = builder.getIndexAttr(1);

  // Accumulate sum(index[i] * stride[i]) for the linearized offset and
  // prod(size[i]) for the linearized size. The size expression reuses the low
  // symbol numbers; each makeComposedFoldedAffineApply call below binds its
  // own operand list, so the two expressions do not interfere.
  for (unsigned i = 0; i < sourceRank; ++i) {
    unsigned offsetIdx = 2 * i + 1;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indices[i];
    offsetValues[offsetIdx + 1] = baseStrides[i];

    unsigned sizeIdx = i + 1;
    mulMap = mulMap * symbols[sizeIdx];
    sizeValues[sizeIdx] = baseSizes[i];
  }

  // Adjust linearizedOffset by the scale factor (dstBits / srcBits), i.e. the
  // number of narrow elements packed into one wide element. Note: this is a
  // floor-div of the element offset; srcBits is assumed to divide dstBits.
  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
  offsetValues.back() = scaler;

  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, scaledAddMulMap, offsetValues);
  OpFoldResult linearizedSize =
      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);

  // Adjust baseOffset by the scale factor (dstBits / srcBits).
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});

  // Flatten n-D MemRef to 1-D MemRef. If the innermost stride folds to a
  // static constant, bake it into the StridedLayoutAttr; otherwise leave it
  // dynamic. The offset is always left dynamic.
  std::optional<int64_t> stride =
      getConstantIntValue(stridedMetadata.getConstifiedMixedStrides().back());
  auto layoutAttr =
      StridedLayoutAttr::get(sourceType.getContext(), ShapedType::kDynamic,
                             {stride ? stride.value() : ShapedType::kDynamic});
  int64_t staticShape = sourceType.hasStaticShape()
                            ? sourceType.getNumElements()
                            : ShapedType::kDynamic;
  auto flattenMemrefType = MemRefType::get(
      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());

  // NOTE(review): the 1-D stride is taken from the innermost source stride
  // (baseStrides.back()) — presumably the innermost dimension is expected to
  // be contiguous in the emulated packing; confirm for non-unit inner strides.
  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
      loc, flattenMemrefType, baseBuffer,
      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
      baseStrides.back());

  // Callers materialize the access (e.g. memref.load/store) themselves from
  // the returned flattened memref and offset.
  return std::make_pair(reinterpret, getValueOrCreateConstantIndexOp(
                                         builder, loc, linearizedOffset));
}

} // namespace memref
} // namespace mlir
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
MLIRIR
MLIRLinalgDialect
MLIRMemRefDialect
MLIRMemRefUtils
MLIRSCFDialect
MLIRSideEffectInterfaces
MLIRTensorDialect
Expand Down
Loading

0 comments on commit 8fc433f

Please sign in to comment.