[Mosaic GPU] Add WGMMA to the Mosaic GPU MLIR Dialect.
The op API is still in flux so I'm leaving some of the verification code untested.

PiperOrigin-RevId: 705020066
dimitar-asenov authored and Google-ML-Automation committed Dec 11, 2024
1 parent cfdac00 commit 66f45d0
Showing 3 changed files with 353 additions and 5 deletions.
180 changes: 180 additions & 0 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
@@ -48,6 +48,8 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Diagnostics.h"
#include "tsl/platform/statusor.h"

@@ -311,6 +313,184 @@ llvm::LogicalResult AsyncStoreOp::verify() {
getSliceLengths(), getIndices().size());
}

namespace {
llvm::FailureOr<WGMMALayout> GetWgmmaLayout(mlir::Location loc,
                                            mlir::MemRefType type,
                                            absl::string_view name,
                                            SwizzlingMode swizzling_mode) {
  auto error = [loc](auto... params) {
    return emitError(loc, llvm::formatv(params...));
  };

  auto [strides, offset] = mlir::getStridesAndOffset(type);

  WGMMALayout layout = WGMMALayout::RowMajor;
  if (strides[3] == 1) {
    layout = WGMMALayout::RowMajor;
  } else if (strides[2] == 1) {
    layout = WGMMALayout::ColumnMajor;
  } else {
    return error(
        "At least one of the last two dimensions of `{0}` must have a "
        "stride of 1, but they do not: stride(dim 2)={1}, stride(dim 3)={2}",
        name, strides[2], strides[3]);
  }

  auto shape = type.getShape();
  if (layout == WGMMALayout::RowMajor && strides[2] != shape[3]) {
    return error(
        "When `{0}` has row-major layout, the stride of dimension 2 (={1}) "
        "must be equal to size of dimension 3 (={2})",
        name, strides[2], shape[3]);
  }

  if (layout == WGMMALayout::ColumnMajor && strides[3] != shape[2]) {
    return error(
        "When `{0}` has column-major layout, the stride of dimension 3 (={1}) "
        "must be equal to size of dimension 2 (={2})",
        name, strides[3], shape[2]);
  }

  if (strides[1] != shape[2] * shape[3]) {
    return error(
        "Dimension 1 of `{0}` must have a stride equal to size of dimension "
        "2 times size of dimension 3 (={1}), but has {2}.",
        name, shape[2] * shape[3], strides[1]);
  }

  return layout;
}

// This is the size of the M dimension in all wgmma instructions. It is fixed,
// unlike the K and N dimensions.
constexpr int kWgmmaSizeM = 64;
} // namespace

llvm::LogicalResult WGMMAOp::verify() {
  auto error = [this](auto... params) {
    return emitOpError(llvm::formatv(params...));
  };

  auto a_shaped_type = mlir::cast<mlir::ShapedType>(getA().getType());
  mlir::Type element_type = a_shaped_type.getElementType();
  if (element_type != getB().getType().getElementType()) {
    return error("The `a` and `b` inputs must have the same element type.");
  }

  auto b_shape = getB().getType().getShape();
  if (b_shape.size() != 4) {
    return error("The `b` input must have rank 4.");
  }

  int element_bytewidth = element_type.getIntOrFloatBitWidth() / 8;
  int kn_tile = static_cast<int>(getSwizzle()) / element_bytewidth;

  int64_t groups_k = b_shape[0];
  int64_t groups_n = b_shape[1];
  int64_t k_group_size = b_shape[2];
  int64_t n_group_size = b_shape[3];

  // It might be possible to relax that requirement, in particular to allow
  // n_group_size to be smaller than kn_tile and use padding.
  if (n_group_size != kn_tile) {
    return error(
        "The n group size ({0}) must be equal to swizzle/element_bytewidth "
        "({1}).",
        n_group_size, kn_tile);
  }
  if (k_group_size != kn_tile) {
    return error(
        "The k group size ({0}) must be equal to swizzle/element_bytewidth "
        "({1}).",
        k_group_size, kn_tile);
  }

  auto b_layout = GetWgmmaLayout(getLoc(), getB().getType(), "b", getSwizzle());
  if (failed(b_layout)) {
    return b_layout;
  }

  int groups_m = 0;
  auto a_shape = a_shaped_type.getShape();
  if (auto a_memref = dyn_cast<mlir::MemRefType>(getA().getType())) {
    if (a_shape.size() != 4) {
      return error("When `a` is a memref, it must have rank 4.");
    }

    groups_m = a_shape[0];

    if (a_shape[1] != groups_k) {
      return error(
          "When `a` is a memref, dimension 1 ({0}) must be equal to groups_k "
          "which is `b`'s dimension 0 ({1}).",
          a_shape[1], groups_k);
    }

    if (a_shape[2] != kWgmmaSizeM) {
      return error(
          "When `a` is a memref, dimension 2 ({0}) must be equal to {1}.",
          a_shape[2], kWgmmaSizeM);
    }

    if (a_shape[3] != kn_tile) {
      return error(
          "When `a` is a memref, dimension 3 ({0}) must be equal to kn_tile.",
          a_shape[3]);
    }

    auto a_layout = GetWgmmaLayout(getLoc(), a_memref, "a", getSwizzle());
    if (failed(a_layout)) {
      return a_layout;
    }
    if (*a_layout == WGMMALayout::ColumnMajor &&
        getSwizzle() != SwizzlingMode::k128ByteSwizzle) {
      // Not sure what the layout is like, since the tiles aren't square.
      return error(
          "When `a` is a memref and has a column-major layout, only a swizzle "
          "of 128 bytes is currently supported, but got {0}.",
          static_cast<int>(getSwizzle()));
    }
  } else {
    // a is a tensor in registers.
    if (!element_type.isBF16() && !element_type.isF16()) {
      return error(
          "When `a` is a tensor in registers, it must have element type bf16 "
          "or f16.");
    }
    if (a_shape.size() != 2) {
      return error("When `a` is a tensor in registers, it must have rank 2.");
    }
    if (a_shape[0] % kWgmmaSizeM) {
      return error(
          "When `a` is a tensor in registers, dimension 0 must be a multiple "
          "of {0}, but got {1}.",
          kWgmmaSizeM, a_shape[0]);
    }

    groups_m = a_shape[0] / kWgmmaSizeM;

    if (a_shape[1] != kn_tile * groups_k) {
      return error(
          "When `a` is a tensor in registers, dimension 1 must be equal to "
          "kn_tile * groups_k ({0}*{1}), but got {2}.",
          kn_tile, groups_k, a_shape[1]);
    }
  }

  auto accShape = getAccumulator().getType().getShape();
  if (accShape.size() != 2) {
    return error("The accumulator must have rank 2.");
  }
  int expected_acc_0 = groups_m * kWgmmaSizeM;
  int expected_acc_1 = groups_n * n_group_size;
  if (accShape[0] != expected_acc_0 || accShape[1] != expected_acc_1) {
    return error(
        "Incorrect accumulator shape. Expected: [{0},{1}], but got [{2},{3}].",
        expected_acc_0, expected_acc_1, accShape[0], accShape[1]);
  }

  return llvm::success();
}

void MosaicGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
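The layout inference in `GetWgmmaLayout` above can be checked with a small standalone sketch. The Python below is illustrative only (it is not part of this commit, and the function name is made up): it mirrors the rule that the layout is row-major when the innermost stride is 1, column-major when the second-to-last stride is 1, and that the remaining strides must describe densely packed tiles.

# Illustrative sketch of the stride checks performed by GetWgmmaLayout above.
# Not part of the commit; the shapes and strides are hypothetical examples.
def infer_wgmma_layout(shape, strides):
  if strides[3] == 1:
    layout = "RowMajor"
  elif strides[2] == 1:
    layout = "ColumnMajor"
  else:
    raise ValueError("one of the last two strides must be 1")
  if layout == "RowMajor" and strides[2] != shape[3]:
    raise ValueError("row-major: stride(dim 2) must equal size(dim 3)")
  if layout == "ColumnMajor" and strides[3] != shape[2]:
    raise ValueError("column-major: stride(dim 3) must equal size(dim 2)")
  if strides[1] != shape[2] * shape[3]:
    raise ValueError("stride(dim 1) must equal size(dim 2) * size(dim 3)")
  return layout

# A contiguous [4, 5, 32, 32] memref has strides [5120, 1024, 32, 1], so the
# innermost dimension is contiguous and the layout is row-major:
assert infer_wgmma_layout([4, 5, 32, 32], [5120, 1024, 32, 1]) == "RowMajor"
# Swapping the last two strides models tiles that are transposed in memory:
assert infer_wgmma_layout([4, 5, 32, 32], [5120, 1024, 1, 32]) == "ColumnMajor"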
86 changes: 81 additions & 5 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -120,16 +120,18 @@ def MosaicGPU_DimensionAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_Dimension, "
def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode",
  "What swizzling to use for a memory access.",
  [
-    I32EnumAttrCase<"kNoSwizzle", 0, "none">,
-    I32EnumAttrCase<"k32ByteSwizzle", 1, "32">,
-    I32EnumAttrCase<"k64ByteSwizzle", 2, "64">,
-    I32EnumAttrCase<"k128ByteSwizzle", 3, "128">
+    I32EnumAttrCase<"kNoSwizzle", 16, "swizzle_none">,
+    I32EnumAttrCase<"k32ByteSwizzle", 32, "swizzle_32">,
+    I32EnumAttrCase<"k64ByteSwizzle", 64, "swizzle_64">,
+    I32EnumAttrCase<"k128ByteSwizzle", 128, "swizzle_128">
  ]>{
  let cppNamespace = "::mosaic_gpu";
  let genSpecializedAttr = 0;
}

-def MosaicGPU_SwizzlingModeAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_SwizzlingMode, "swizzle">;
+def MosaicGPU_SwizzlingModeAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_SwizzlingMode, "swizzle"> {
+  let assemblyFormat = "`<` $value `>`";
+}

def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
  let parameters = (ins Variadic<I64>:$tiling);
@@ -276,4 +278,78 @@ def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
  let hasVerifier = 1;
}

def MosaicGPU_WGMMASupportedType : AnyTypeOf<[F16, BF16, F32],
  "A type supported by the WGMMA operation">;

def MosaicGPU_WGMMALayout :
  I32EnumAttr<"WGMMALayout", "The layout of the tiles of a WGMMA operation", [
    I32EnumAttrCase<"RowMajor", 0>,
    I32EnumAttrCase<"ColumnMajor", 1>
  ]> {
  let cppNamespace = "::mosaic_gpu";
  let genSpecializedAttr = 0;
}

def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
  let summary = "Multiply two matrices asynchronously using warpgroup level matrix multiply operations.";
  let description = [{
    Schedules WGMMA operations that perform the following matrix multiply and
    accumulate:

      accumulator = a * b + accumulator

    This operation supports larger inputs than the PTX-level WGMMA operation
    and will schedule as many PTX-level WGMMA operations as needed to
    accomplish the calculation. The `b` matrix, and optionally `a`, needs to be
    provided in a 4-dimensional form, where the two minor-most dimensions
    express the tile (group) size and the two major-most dimensions represent
    the total number of tiles in each direction.

    The inputs should have the following shapes:
    - If `a` is in shared memory:
      - a: [groups_m, groups_k, 64, k]
    - If `a` is in registers:
      - a: [groups_m * 64, groups_k * k]
    - b: [groups_k, groups_n, k, k]
    - accumulator: [groups_m * 64, groups_n * k]
    Where:
    - `k == swizzle/element_bytewidth` (for `kNoSwizzle`, `swizzle` is 16.)

    The `accumulator` is always in registers and `b` is always in shared memory.
    The last two dimensions of any input in shared memory may be physically
    transposed in memory. This is inferred from the strides of the provided
    memrefs. `a` and `b` must have the same element type and when `a` is in
    registers only F16 or BF16 are supported.

    The `accumulator` must be a tensor with a FragmentedLayout. The WGMMA
    operation will be executed in the async proxy and any inputs in
    registers need to be synchronized with a memory fence.

    Usually `a` is read from shared memory if it is used directly in the WGMMA
    operation. If `a` needs to be transformed before it is used in the WGMMA
    operation, it may be more convenient to read it directly from registers.
    This avoids the need to store the data and wait for a fence.
  }];

  let arguments = (ins
    TensorOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
    AnyTypeOf<[
      MemRefOf<[MosaicGPU_WGMMASupportedType]>,
      TensorOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
    MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,

    // Attributes
    DefaultValuedAttr<MosaicGPU_SwizzlingModeAttr, "SwizzlingMode::k128ByteSwizzle">:$swizzle
  );

  let assemblyFormat = [{
    `accumulator` `(` $accumulator `:` type($accumulator) `)`
    `a` `(` $a `:` type($a) `)`
    `b` `(` $b `:` type($b) `)`
    attr-dict
  }];

  let hasVerifier = 1;
}

#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
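As a quick cross-check of the shape rules stated in the op description and enforced by the verifier, here is a standalone Python sketch (a hypothetical helper, not part of the commit) that computes the expected operand shapes from the swizzle, the element byte width, and the group counts:

# Hypothetical helper mirroring the shape rules of the wgmma op description.
def wgmma_operand_shapes(swizzle_bytes, element_bytewidth,
                         groups_m, groups_k, groups_n, a_in_registers=False):
  k = swizzle_bytes // element_bytewidth  # the verifier's kn_tile
  m = 64                                  # fixed M size of one wgmma (kWgmmaSizeM)
  a = ([groups_m * m, groups_k * k] if a_in_registers
       else [groups_m, groups_k, m, k])
  b = [groups_k, groups_n, k, k]
  accumulator = [groups_m * m, groups_n * k]
  return a, b, accumulator

# bf16 (2 bytes) with a 64-byte swizzle gives k = 32; these are the shapes
# used by the tests below:
print(wgmma_operand_shapes(64, 2, groups_m=2, groups_k=4, groups_n=5))
# ([2, 4, 64, 32], [4, 5, 32, 32], [128, 160])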
92 changes: 92 additions & 0 deletions tests/mosaic/gpu_dialect_test.py
@@ -475,6 +475,98 @@ def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
    ):
      self.module.operation.verify()

  def test_wgmma_types_match(self):
    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(
          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
          ir.MemRefType.get([2, 4, 64, 32], ir.F16Type.get()),
          ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),
          name="wgmma",
      )(
          lambda accumulator, a, b: mgpu.wgmma(
              accumulator,
              a,
              b,
              swizzle=ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>"),
          )
      )

    with self.assertRaisesRegex(
        ir.MLIRError,
        "The `a` and `b` inputs must have the same element type.",
    ):
      self.module.operation.verify()

  def test_wgmma_b_rank_is_4(self):
    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(
          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
          ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
          ir.MemRefType.get([5, 32, 32], ir.BF16Type.get()),
          name="wgmma",
      )(
          lambda accumulator, a, b: mgpu.wgmma(
              accumulator,
              a,
              b,
              swizzle=ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>"),
          )
      )

    with self.assertRaisesRegex(
        ir.MLIRError,
        "The `b` input must have rank 4.",
    ):
      self.module.operation.verify()

  def test_wgmma_b_shape_dim_3(self):
    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(
          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
          ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
          ir.MemRefType.get([4, 5, 32, 16], ir.BF16Type.get()),
          name="wgmma",
      )(
          lambda accumulator, a, b: mgpu.wgmma(
              accumulator,
              a,
              b,
              swizzle=ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>"),
          )
      )

    with self.assertRaisesRegex(
        ir.MLIRError,
        r"The n group size \(16\) must be equal to swizzle/element_bytewidth "
        r"\(32\)",
    ):
      self.module.operation.verify()

  def test_wgmma_b_shape_dim_2(self):
    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(
          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
          ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
          ir.MemRefType.get([4, 5, 64, 32], ir.BF16Type.get()),
          name="wgmma",
      )(
          lambda accumulator, a, b: mgpu.wgmma(
              accumulator,
              a,
              b,
              swizzle=ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>"),
          )
      )

    with self.assertRaisesRegex(
        ir.MLIRError,
        r"The k group size \(64\) must be equal to swizzle/element_bytewidth "
        r"\(32\)",
    ):
      self.module.operation.verify()

  # TODO(b/381371456): Add tests for the other WGMMA inputs.


class DialectLoweringTest(DialectTest):

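For illustration, here is a shape combination that should pass the verifier's checks, written in the same style as the failing-case tests above and using the new `#mosaic_gpu.swizzle<...>` assembly format. This hypothetical test method is not part of the commit; it assumes the same `ir`, `func`, and `mgpu` helpers as the tests in the file:

  # Hypothetical passing case (not in the commit): bf16 operands, 64-byte
  # swizzle, so kn_tile = 64 / 2 = 32 and the shapes below are consistent.
  def test_wgmma_valid_shapes(self):
    with ir.InsertionPoint(self.module.body):
      func.FuncOp.from_py_func(
          ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),  # accumulator
          ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),    # a
          ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),    # b
          name="wgmma",
      )(
          lambda accumulator, a, b: mgpu.wgmma(
              accumulator,
              a,
              b,
              swizzle=ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>"),
          )
      )

    self.assertTrue(self.module.operation.verify())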
