Skip to content

Commit

Permalink
Reassemble 32-bit float complex into a 64-bit float for splats. (iree…
Browse files Browse the repository at this point in the history
…-org#13363)

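Splats cannot handle `complex<f32>` directly, so the real and imaginary
fields are extracted, packed together into a single 64-bit value (the real
part in the upper 32 bits, the imaginary part in the lower 32 bits), and
passed through the 64-bit splat.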

`Complex<f64>` is currently being rejected as not-implemented.

---------

Co-authored-by: Ben Vanik <[email protected]>
  • Loading branch information
bviyer and benvanik authored May 9, 2023
1 parent f1efd44 commit a97c63f
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 13 deletions.
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class FLOW_Op<string mnemonic, list<Trait> traits = []> :
// Flow dialect types
//===----------------------------------------------------------------------===//

def FLOW_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat]>;
def FLOW_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat, AnyComplex]>;

def FLOW_Dim : TypeAlias<Index>;
def FLOW_ShapeDynamicDims : Variadic<FLOW_Dim>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class CommandBufferFillBufferOpConversion

// Record the original pattern length then extend it to a 32 bit integer.
auto originalPatternType = op.getPattern().getType();
auto patternBitWidth = originalPatternType.getIntOrFloatBitWidth();
unsigned patternBitWidth = IREE::Util::getTypeBitWidth(originalPatternType);
// The pattern length (in bytes) will be used at runtime to issue the fill
// command. While the pattern itself will be stored in a 32 bit integer,
// the fill operation will use this length to slice a potentially smaller
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def HAL_HostSizeAttr : Util_IndexAttrBase<"size_t">;

def HAL_TimelineValue : TypeAlias<I64>;

def HAL_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat]>;
def HAL_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat, AnyComplex]>;
def HAL_FillPatternType : AnyTypeOf<[I8, I16, I32]>;

def HAL_GlobalRefAttr : Util_AliasedSymbolRefAttr;
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class Stream_Op<string mnemonic, list<Trait> traits = []> :
// Stream dialect types
//===----------------------------------------------------------------------===//

def Stream_PrimitiveType : AnyTypeOf<[Index, AnyInteger, AnyFloat]>;
def Stream_PrimitiveType : AnyTypeOf<[Index, AnyInteger, AnyFloat, AnyComplex]>;
def Stream_FillPatternType : AnyTypeOf<[I8, I16, I32]>;

def Stream_Offset : TypeAlias<Index>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ iree_cc_library(
MLIRAffineDialect
MLIRAnalysis
MLIRArithDialect
MLIRComplexDialect
MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -66,7 +67,7 @@ static Type alignElementType(Type originalType) {
// Align the element type to a power of two byte size.
auto alignedBitWidth =
IREE::Util::getRoundedElementByteWidth(elementType) * 8;
if (elementType.getIntOrFloatBitWidth() == alignedBitWidth) {
if (IREE::Util::getTypeBitWidth(elementType) == alignedBitWidth) {
// Already aligned.
return originalType;
}
Expand Down Expand Up @@ -335,18 +336,39 @@ static Value canonicalizeFillPattern(Value pattern, PatternRewriter &rewriter) {

// Get floats into integer form.
auto patternType = pattern.getType();
unsigned bitWidth = patternType.getIntOrFloatBitWidth();
unsigned elementBitWidth = IREE::Util::getTypeBitWidth(patternType);
elementBitWidth =
(isa<ComplexType>(patternType) ? elementBitWidth / 2 : elementBitWidth);
if (patternType.isa<FloatType>()) {
pattern = rewriter.createOrFold<arith::BitcastOp>(
loc, rewriter.getIntegerType(bitWidth), pattern);
loc, rewriter.getIntegerType(elementBitWidth), pattern);
}

if (isa<ComplexType>(patternType)) {
int64_t complexBitWidth = elementBitWidth;
Type bwElemType = rewriter.getIntegerType(elementBitWidth);
Type bwType = rewriter.getIntegerType(elementBitWidth * 2);
Value shiftAmount = rewriter.create<arith::ConstantOp>(
loc, bwType, rewriter.getIntegerAttr(bwType, complexBitWidth));

Value real = rewriter.create<mlir::complex::ReOp>(loc, pattern);
Value realInt = rewriter.create<arith::BitcastOp>(loc, bwElemType, real);
Value imag = rewriter.create<mlir::complex::ImOp>(loc, pattern);
Value imagInt = rewriter.create<arith::BitcastOp>(loc, bwElemType, imag);
realInt = rewriter.create<arith::IndexCastOp>(loc, bwType, realInt);
imagInt = rewriter.create<arith::IndexCastOp>(loc, bwType, imagInt);
Value shiftReal =
rewriter.create<arith::ShLIOp>(loc, bwType, realInt, shiftAmount);
Value orImag = rewriter.create<arith::OrIOp>(loc, shiftReal, imagInt);
return orImag;
}

// HACK: extend i1 to i8. This is really not something we should be doing here
// in optimized programs as this is a super shady operation.
if (patternType.isInteger(1)) {
return rewriter.createOrFold<arith::ExtUIOp>(loc, rewriter.getI8Type(),
pattern);
} else if ((bitWidth % 8) != 0) {
} else if ((elementBitWidth % 8) != 0) {
// We'd need some policy to determine how to handle non-byte-aligned widths.
return {};
}
Expand All @@ -373,7 +395,9 @@ struct EncodeTensorSplatOp
if (!pattern) {
return rewriter.notifyMatchFailure(
op, "unsupported pattern width; encoding policy required");
} else if (pattern.getType().getIntOrFloatBitWidth() > 32) {
}
unsigned bitWidth = IREE::Util::getTypeBitWidth(pattern.getType());
if (bitWidth > 32) {
// We emulate 64-bit support with a stream.builtin.splat.i64.
rewriter.replaceOpWithNewOp<IREE::Stream::BuiltinSplatI64Op>(
op, op.getResult().getType(), pattern, op.getResultSize(),
Expand Down Expand Up @@ -478,7 +502,18 @@ struct EncodeTensorFillOp
if (!pattern) {
return rewriter.notifyMatchFailure(
op, "unsupported pattern width; encoding policy required");
} else if (pattern.getType().getIntOrFloatBitWidth() > 32) {
}
unsigned bitWidth = IREE::Util::getTypeBitWidth(pattern.getType());
if (bitWidth > 64) {
// This happens mostly when complex<f64> is used as an input type for
// splat. complex<type> is broken down into a 2xtype value with the real
// field occupying the sizeof(type) MSB bits and the imaginary field
// occupying the rest. At this moment, splats with size > 64 are not
// implemented so we error out here.
return rewriter.notifyMatchFailure(
op, "unsupported bitWidth greater than 64; encoding policy required");
}
if (bitWidth > 32) {
rewriter.replaceOpWithNewOp<IREE::Stream::BuiltinFillI64Op>(
op, op.getResult().getType(), op.getTarget(), op.getTargetSize(),
targetOffset, targetEnd, targetLength, pattern, op.getAffinityAttr());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ static void updateDispatchOp(IREE::Stream::CmdDispatchOp dispatchOp,

// i1-i31 -> i32 and i33-i63 -> i64
// TODO(benvanik): don't extend here but instead pack as we can fit 4 i8's
// into a single i32 and save 4x our push constant capacity.
unsigned bitWidth = operand.getType().getIntOrFloatBitWidth();
// into a single i32 and save 4x our push constant capacity
unsigned bitWidth = IREE::Util::getTypeBitWidth(operand.getType());
if (bitWidth < 31) {
operand = builder.createOrFold<arith::ExtUIOp>(loc, builder.getI32Type(),
operand);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,18 @@ func.func @denseTensorSplatI64(%arg0: i64, %arg1: index, %arg2: index) -> !strea

// -----

// Verifies that splatting a complex<f32> constant is encoded as a single
// 64-bit integer splat. The expected constant packs the two f32 bit patterns
// with the real part in the upper 32 bits and the imaginary part in the lower
// 32 bits:
//   real 3.0f -> 0x40400000, imag 10.0f -> 0x41200000
//   0x4040000041200000 == 4629700418029486080
// CHECK-LABEL: @denseTensorSplatComplexF32
func.func @denseTensorSplatComplexF32(%arg0: !stream.resource<*>) -> (!stream.resource<*>) {
%cst = complex.constant [3.000000e+00 : f32, 1.000000e+01 : f32] : complex<f32>
%0 = stream.tensor.sizeof tensor<6xcomplex<f32>> : index
// CHECK: %[[I64NUMBER:.+]] = arith.constant 4629700418029486080
// CHECK: %[[SPLAT_RES:.+]] = stream.builtin.splat.i64 %[[I64NUMBER]]
%1 = stream.tensor.splat %cst : complex<f32> -> tensor<6xcomplex<f32>> in !stream.resource<*>{%0}
return %1 : !stream.resource<*>
}

// -----

// NOTE: clone likes to fold; the fills ensure it doesn't.

// CHECK-LABEL: @denseTensorClone
Expand Down Expand Up @@ -239,3 +251,4 @@ func.func @denseTensorStoreRank0(%arg0: !stream.resource<staging>, %arg1: index,
// CHECK: return %[[RET]]
return %0 : !stream.resource<staging>
}

9 changes: 9 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,15 @@ static inline int32_t getRoundedElementByteWidth(Type type) {
return llvm::PowerOf2Ceil(byteAligned);
}

// Returns the bit-width of the scalar |type|.
//
// For a complex type the result is twice the bit-width of its element type
// (one element for the real part and one for the imaginary part), e.g.
// complex<f32> yields 64. Any other type is forwarded to
// Type::getIntOrFloatBitWidth().
//
// Declared `static inline` like the other helpers in this header so that
// translation units that include the header without calling this function do
// not get unused-function warnings.
static inline int64_t getTypeBitWidth(Type type) {
  if (auto complexType = type.dyn_cast<ComplexType>()) {
    return 2 * complexType.getElementType().getIntOrFloatBitWidth();
  }
  return type.getIntOrFloatBitWidth();
}

} // namespace Util
} // namespace IREE
} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class IREEInput_AnyPtrOf<list<Type> types> :
string builderCall = "";
}

def IREEInput_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat]>;
def IREEInput_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat, AnyComplex]>;
def IREEInput_Tensor : TypeAlias<AnyRankedTensor>;

def IREEInput_AnyList : DialectType<
Expand Down

0 comments on commit a97c63f

Please sign in to comment.