From 66f45d039f86253313ae1352bcee2b45c5744e57 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Wed, 11 Dec 2024 01:46:48 -0800
Subject: [PATCH] [Mosaic GPU] Add WGMMA to the Mosaic GPU MLIR Dialect.

The op API is still in flux, so I'm leaving some of the verification code
untested.

PiperOrigin-RevId: 705020066
---
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 180 ++++++++++++++++++++++++
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.td |  86 ++++++++++-
 tests/mosaic/gpu_dialect_test.py        |  92 ++++++++++++
 3 files changed, 353 insertions(+), 5 deletions(-)

diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
index c86450fbdf0c..b21f56327457 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
@@ -48,6 +48,8 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/include/mlir/IR/BuiltinTypes.h"
 #include "mlir/include/mlir/IR/Diagnostics.h"
 #include "tsl/platform/statusor.h"
 
@@ -311,6 +313,184 @@ llvm::LogicalResult AsyncStoreOp::verify() {
       getSliceLengths(), getIndices().size());
 }
 
+namespace {
+llvm::FailureOr<WGMMALayout> GetWgmmaLayout(mlir::Location loc,
+                                            mlir::MemRefType type,
+                                            absl::string_view name,
+                                            SwizzlingMode swizzling_mode) {
+  auto error = [loc](auto... params) {
+    return emitError(loc, llvm::formatv(params...));
+  };
+
+  auto [strides, offset] = mlir::getStridesAndOffset(type);
+
+  WGMMALayout layout = WGMMALayout::RowMajor;
+  if (strides[3] == 1) {
+    layout = WGMMALayout::RowMajor;
+  } else if (strides[2] == 1) {
+    layout = WGMMALayout::ColumnMajor;
+  } else {
+    return error(
+        "At least one of the last two dimensions of `{0}` must have a "
+        "stride of 1, but they do not: stride(dim 2)={1}, stride(dim 3)={2}",
+        name, strides[2], strides[3]);
+  }
+
+  auto shape = type.getShape();
+  if (layout == WGMMALayout::RowMajor && strides[2] != shape[3]) {
+    return error(
+        "When `{0}` has row-major layout, the stride of dimension 2 (={1}) "
+        "must be equal to size of dimension 3 (={2})",
+        name, strides[2], shape[3]);
+  }
+
+  if (layout == WGMMALayout::ColumnMajor && strides[3] != shape[2]) {
+    return error(
+        "When `{0}` has column-major layout, the stride of dimension 3 (={1}) "
+        "must be equal to size of dimension 2 (={2})",
+        name, strides[3], shape[2]);
+  }
+
+  if (strides[1] != shape[2] * shape[3]) {
+    return error(
+        "Dimension 1 of `{0}` must have a stride equal to size of dimension "
+        "2 times size of dimension 3 (={1}), but has {2}.",
+        name, shape[2] * shape[3], strides[1]);
+  }
+
+  return layout;
+}
+
+// This is the size of the M dimension in all wgmma instructions. It is fixed,
+// unlike the K and N dimensions.
+constexpr int kWgmmaSizeM = 64;
+}  // namespace
+
+llvm::LogicalResult WGMMAOp::verify() {
+  auto error = [this](auto... params) {
+    return emitOpError(llvm::formatv(params...));
+  };
+
+  auto a_shaped_type = mlir::cast<mlir::ShapedType>(getA().getType());
+  mlir::Type element_type = a_shaped_type.getElementType();
+  if (element_type != getB().getType().getElementType()) {
+    return error("The `a` and `b` inputs must have the same element type.");
+  }
+
+  auto b_shape = getB().getType().getShape();
+  if (b_shape.size() != 4) {
+    return error("The `b` input must have rank 4.");
+  }
+
+  int element_bytewidth = element_type.getIntOrFloatBitWidth() / 8;
+  int kn_tile = static_cast<int>(getSwizzle()) / element_bytewidth;
+
+  int64_t groups_k = b_shape[0];
+  int64_t groups_n = b_shape[1];
+  int64_t k_group_size = b_shape[2];
+  int64_t n_group_size = b_shape[3];
+
+  // It might be possible to relax this requirement, in particular to allow
+  // n_group_size to be smaller than kn_tile and use padding.
+  if (n_group_size != kn_tile) {
+    return error(
+        "The n group size ({0}) must be equal to swizzle/element_bytewidth "
+        "({1}).",
+        n_group_size, kn_tile);
+  }
+  if (k_group_size != kn_tile) {
+    return error(
+        "The k group size ({0}) must be equal to swizzle/element_bytewidth "
+        "({1}).",
+        k_group_size, kn_tile);
+  }
+
+  auto b_layout = GetWgmmaLayout(getLoc(), getB().getType(), "b", getSwizzle());
+  if (failed(b_layout)) {
+    return b_layout;
+  }
+
+  int groups_m = 0;
+  auto a_shape = a_shaped_type.getShape();
+  if (auto a_memref = mlir::dyn_cast<mlir::MemRefType>(getA().getType())) {
+    if (a_shape.size() != 4) {
+      return error("When `a` is a memref, it must have rank 4.");
+    }
+
+    groups_m = a_shape[0];
+
+    if (a_shape[1] != groups_k) {
+      return error(
+          "When `a` is a memref, dimension 1 ({0}) must be equal to groups_k "
+          "which is `b`'s dimension 0 ({1}).",
+          a_shape[1], groups_k);
+    }
+
+    if (a_shape[2] != kWgmmaSizeM) {
+      return error(
+          "When `a` is a memref, dimension 2 ({0}) must be equal to {1}.",
+          a_shape[2], kWgmmaSizeM);
+    }
+
+    if (a_shape[3] != kn_tile) {
+      return error(
+          "When `a` is a memref, dimension 3 ({0}) must be equal to kn_tile.",
+          a_shape[3]);
+    }
+
+    auto a_layout = GetWgmmaLayout(getLoc(), a_memref, "a", getSwizzle());
+    if (failed(a_layout)) {
+      return a_layout;
+    }
+    if (*a_layout == WGMMALayout::ColumnMajor &&
+        getSwizzle() != SwizzlingMode::k128ByteSwizzle) {
+      // Not sure what the layout is like, since the tiles aren't square.
+      return error(
+          "When `a` is a memref and has a column-major layout, only a swizzle "
+          "of 128 bytes is currently supported, but got {0}.",
+          static_cast<int>(getSwizzle()));
+    }
+  } else {
+    // `a` is a tensor in registers.
+    if (!element_type.isBF16() && !element_type.isF16()) {
+      return error(
+          "When `a` is a tensor in registers, it must have element type bf16 "
+          "or f16.");
+    }
+    if (a_shape.size() != 2) {
+      return error("When `a` is a tensor in registers, it must have rank 2.");
+    }
+    if (a_shape[0] % kWgmmaSizeM) {
+      return error(
+          "When `a` is a tensor in registers, dimension 0 must be a multiple "
+          "of {0}, but got {1}.",
+          kWgmmaSizeM, a_shape[0]);
+    }
+
+    groups_m = a_shape[0] / kWgmmaSizeM;
+
+    if (a_shape[1] != kn_tile * groups_k) {
+      return error(
+          "When `a` is a tensor in registers, dimension 1 must be equal to "
+          "kn_tile * groups_k ({0}*{1}), but got {2}.",
+          kn_tile, groups_k, a_shape[1]);
+    }
+  }
+
+  auto acc_shape = getAccumulator().getType().getShape();
+  if (acc_shape.size() != 2) {
+    return error("The accumulator must have rank 2.");
+  }
+  int expected_acc_0 = groups_m * kWgmmaSizeM;
+  int expected_acc_1 = groups_n * n_group_size;
+  if (acc_shape[0] != expected_acc_0 || acc_shape[1] != expected_acc_1) {
+    return error(
+        "Incorrect accumulator shape. Expected: [{0},{1}], but got [{2},{3}].",
+        expected_acc_0, expected_acc_1, acc_shape[0], acc_shape[1]);
+  }
+
+  return llvm::success();
+}
+
 void MosaicGPUDialect::initialize() {
   addTypes<
 #define GET_TYPEDEF_LIST
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 4129dcd1b345..e0eac97beb97 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -120,16 +120,18 @@ def MosaicGPU_DimensionAttr : EnumAttr
-      I32EnumAttrCase<"k32ByteSwizzle", 1, "32">,
-      I32EnumAttrCase<"k64ByteSwizzle", 2, "64">,
-      I32EnumAttrCase<"k128ByteSwizzle", 3, "128">
+      I32EnumAttrCase<"kNoSwizzle", 16, "swizzle_none">,
+      I32EnumAttrCase<"k32ByteSwizzle", 32, "swizzle_32">,
+      I32EnumAttrCase<"k64ByteSwizzle", 64, "swizzle_64">,
+      I32EnumAttrCase<"k128ByteSwizzle", 128, "swizzle_128">
 ]>{
   let cppNamespace = "::mosaic_gpu";
   let genSpecializedAttr = 0;
 }
 
-def MosaicGPU_SwizzlingModeAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_SwizzlingMode, "swizzle">;
+def MosaicGPU_SwizzlingModeAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_SwizzlingMode, "swizzle"> {
+  let assemblyFormat = "`<` $value `>`";
+}
 
 def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
   let parameters = (ins Variadic:$tiling);
@@ -276,4 +278,78 @@ def MosaicGPU_AsyncStoreOp : Op;
 
+def MosaicGPU_WGMMALayout :
+    I32EnumAttr<"WGMMALayout", "The layout of the tiles of a WGMMA operation", [
+      I32EnumAttrCase<"RowMajor", 0>,
+      I32EnumAttrCase<"ColumnMajor", 1>
+    ]> {
+  let cppNamespace = "::mosaic_gpu";
+  let genSpecializedAttr = 0;
+}
+
+def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma"> {
+  let summary = "Multiply two matrices asynchronously using warpgroup-level matrix multiply operations.";
+  let description = [{
+    Schedules WGMMA operations that perform the following matrix multiply and
+    accumulate:
+
+      accumulator = a * b + accumulator
+
+    This operation supports larger inputs than the PTX-level WGMMA operation
+    and will schedule as many PTX-level WGMMA operations as needed to
+    accomplish the calculation. The `b` matrix, and optionally `a`, must be
+    provided in 4-dimensional form, where the two minor-most dimensions
+    express the tile (group) size and the two major-most dimensions represent
+    the total number of tiles in each direction.
+
+    The inputs should have the following shapes:
+      - If `a` is in shared memory:
+        - a: [groups_m, groups_k, 64, k]
+      - If `a` is in registers:
+        - a: [groups_m * 64, groups_k * k]
+      - b: [groups_k, groups_n, k, k]
+      - accumulator: [groups_m * 64, groups_n * k]
+    Where:
+      - `k == swizzle/element_bytewidth` (for `kNoSwizzle`, `swizzle` is 16.)
+
+    The `accumulator` is always in registers and `b` is always in shared memory.
+    The last two dimensions of any input in shared memory may be physically
+    transposed in memory. This is inferred from the strides of the provided
+    memrefs. `a` and `b` must have the same element type, and when `a` is in
+    registers only F16 or BF16 are supported.
+
+    The `accumulator` must be a tensor with a FragmentedLayout. The WGMMA
+    operation will be executed in the async proxy and any inputs in
+    registers need to be synchronized with a memory fence.
+
+    Usually `a` is read from shared memory if it is used directly in the WGMMA
+    operation. If `a` needs to be transformed before it is used in the WGMMA
+    operation, it may be more convenient to read it directly from registers.
+    This avoids the need to store the data and wait for a fence.
+  }];
+
+  let arguments = (ins
+    TensorOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
+    AnyTypeOf<[
+        MemRefOf<[MosaicGPU_WGMMASupportedType]>,
+        TensorOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
+    MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,
+
+    // Attributes
+    DefaultValuedAttr<MosaicGPU_SwizzlingModeAttr,
+        "SwizzlingMode::k128ByteSwizzle">:$swizzle
+  );
+
+  let assemblyFormat = [{
+    `accumulator` `(` $accumulator `:` type($accumulator) `)`
+    `a` `(` $a `:` type($a) `)`
+    `b` `(` $b `:` type($b) `)`
+    attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
 #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py
index 3edddaad9d12..2c9de0058a67 100644
--- a/tests/mosaic/gpu_dialect_test.py
+++ b/tests/mosaic/gpu_dialect_test.py
@@ -475,6 +475,98 @@ def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
     ):
       self.module.operation.verify()
 
+  def test_wgmma_types_match(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
+          ir.MemRefType.get([2, 4, 64, 32], ir.F16Type.get()),
+          ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),
+          name="wgmma",
+      )(
+          lambda accumulator, a, b: mgpu.wgmma(
+              accumulator,
+              a,
+              b,
+              swizzle=ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>"),
+          )
+      )
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        "The `a` and `b` inputs must have the same element type.",
+    ):
+      self.module.operation.verify()
+
+  def test_wgmma_b_rank_is_4(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
+          ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
+          ir.MemRefType.get([5, 32, 32], ir.BF16Type.get()),
+          name="wgmma",
+      )(
+          lambda accumulator, a, b: mgpu.wgmma(
+              accumulator,
+              a,
+              b,
+              swizzle=ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>"),
+          )
+      )
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        "The `b` input must have rank 4.",
+    ):
+      self.module.operation.verify()
+
+  def test_wgmma_b_shape_dim_3(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
+          ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
+          ir.MemRefType.get([4, 5, 32, 16], ir.BF16Type.get()),
+          name="wgmma",
+      )(
+          lambda accumulator, a, b: mgpu.wgmma(
+              accumulator,
+              a,
+              b,
+              swizzle=ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>"),
+          )
+      )
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        r"The n group size \(16\) must be equal to swizzle/element_bytewidth "
+        r"\(32\)",
+    ):
+      self.module.operation.verify()
+
+  def test_wgmma_b_shape_dim_2(self):
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func(
+          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
+          ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
+          ir.MemRefType.get([4, 5, 64, 32], ir.BF16Type.get()),
+          name="wgmma",
+      )(
+          lambda accumulator, a, b: mgpu.wgmma(
+              accumulator,
+              a,
+              b,
+              swizzle=ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>"),
+          )
+      )
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        r"The k group size \(64\) must be equal to swizzle/element_bytewidth "
+        r"\(32\)",
+    ):
+      self.module.operation.verify()
+
+  # TODO(b/381371456): Add tests for the other WGMMA inputs.
+
 
 class DialectLoweringTest(DialectTest):
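
The shape rules described in the `wgmma` op documentation and enforced by
`WGMMAOp::verify` reduce to a small amount of arithmetic. The sketch below is
illustrative only and is not part of the patch: the helper name `wgmma_shapes`
and the example values are made up, but the formulas follow the description
above (`kn_tile == swizzle / element_bytewidth`, and the M tile size is fixed
at 64).

def wgmma_shapes(groups_m, groups_n, groups_k, swizzle, element_bytewidth):
  """Shapes the WGMMA verifier expects, per the op description (sketch)."""
  wgmma_size_m = 64  # Fixed M size of a single PTX-level wgmma instruction.
  kn_tile = swizzle // element_bytewidth
  return {
      # `a` in shared memory: tiled into (64, kn_tile) groups.
      "a_smem": (groups_m, groups_k, wgmma_size_m, kn_tile),
      # `a` in registers: a plain 2D tensor.
      "a_regs": (groups_m * wgmma_size_m, groups_k * kn_tile),
      # `b` is always in shared memory, tiled into (kn_tile, kn_tile) groups.
      "b_smem": (groups_k, groups_n, kn_tile, kn_tile),
      # The accumulator is always a 2D tensor in registers.
      "acc": (groups_m * wgmma_size_m, groups_n * kn_tile),
  }

# Example matching the bf16 tests above: a 64-byte swizzle with 2-byte
# elements gives kn_tile == 32, so a `b` of shape (4, 5, 32, 32) pairs with
# an accumulator of shape (2 * 64, 5 * 32) == (128, 160).
shapes = wgmma_shapes(groups_m=2, groups_n=5, groups_k=4,
                      swizzle=64, element_bytewidth=2)
assert shapes["b_smem"] == (4, 5, 32, 32)
assert shapes["a_regs"] == (128, 128)
assert shapes["acc"] == (128, 160)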
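
Similarly, the row-major/column-major inference in `GetWgmmaLayout` only looks
at the strides of the last two memref dimensions. The following is a minimal
sketch of that logic, assuming a rank-4 memref described by `shape` and
`strides`; the function `infer_wgmma_tile_layout` is hypothetical and merely
mirrors the checks performed by the verifier above.

def infer_wgmma_tile_layout(shape, strides):
  """Mirrors the stride checks in GetWgmmaLayout (illustrative sketch)."""
  if strides[3] == 1:
    layout = "row_major"
  elif strides[2] == 1:
    layout = "column_major"
  else:
    raise ValueError("one of the last two dimensions must have stride 1")
  if layout == "row_major" and strides[2] != shape[3]:
    raise ValueError("row-major tile is not contiguous")
  if layout == "column_major" and strides[3] != shape[2]:
    raise ValueError("column-major tile is not contiguous")
  if strides[1] != shape[2] * shape[3]:
    raise ValueError("tiles must be packed back to back along dimension 1")
  return layout

# A contiguous [4, 5, 32, 32] memref has strides [5120, 1024, 32, 1] and is
# inferred as row-major; swapping the last two strides makes it column-major.
assert infer_wgmma_tile_layout([4, 5, 32, 32], [5120, 1024, 32, 1]) == "row_major"
assert infer_wgmma_tile_layout([4, 5, 32, 32], [5120, 1024, 1, 32]) == "column_major"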