Skip to content

Commit

Permalink
[XLA:Mosaic] Support memref shapecast.
Browse files Browse the repository at this point in the history
This cl supports memref shapecast:
1. if tile is (1, 128), we support shapecast on any dim.
2. if shapecast on sublane dim, we only support tile aligned shape.
3. if shapecast on non-tiling dim, we support any shapecast.
4. all other cases would be considered as invalid memref shapecast.

PiperOrigin-RevId: 651924552
  • Loading branch information
bythew3i authored and jax authors committed Jul 13, 2024
1 parent f3c1cbc commit aa16485
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 19 deletions.
3 changes: 2 additions & 1 deletion jaxlib/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@ cc_library(
"dialect/tpu/layout.cc",
"dialect/tpu/tpu_dialect.cc",
"dialect/tpu/tpu_ops.cc",
"dialect/tpu/util.h",
"dialect/tpu/util.cc",
] + glob([
"dialect/tpu/transforms/*.cc",
]),
hdrs = [
"dialect/tpu/layout.h",
"dialect/tpu/tpu_dialect.h",
"dialect/tpu/util.h",
] + glob([
"dialect/tpu/transforms/*.h",
]),
Expand Down
10 changes: 10 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,16 @@ def TPU_MemRefSqueezeOp : TPU_Op<"memref_squeeze", [Pure]> {
let hasCanonicalizeMethod = 1;
}

def TPU_MemRefReshapeOp : TPU_Op<"memref_reshape", [Pure]> {
let arguments = (ins AnyMemRef:$input);
let results = (outs AnyMemRef:$result);
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type($result)
}];
let hasVerifier = 1;
let hasCanonicalizeMethod = 1;
}

def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> {
let arguments = (ins AnyMemRef:$input);
let results = (outs AnyMemRef:$result);
Expand Down
92 changes: 92 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/IRMapping.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"

namespace mlir {
namespace tpu {
Expand Down Expand Up @@ -183,6 +186,95 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op,
return success();
}

LogicalResult MemRefReshapeOp::verify() {
auto src_ty = getMemRefType(getInput());
auto tgt_ty = getType();
if (tgt_ty.getMemorySpace() != nullptr &&
tgt_ty.getMemorySpace() != src_ty.getMemorySpace()) {
return emitOpError("Memory spaces do not match.");
}
if (src_ty.getShape().size() < 2 || tgt_ty.getShape().size() < 2) {
return emitError("Not implemented: 1d memref reshape.");
}
if (tgt_ty.getElementType() != src_ty.getElementType()) {
return emitOpError("Element types don't match.");
}
auto src_elements_num = ShapedType::getNumElements(src_ty.getShape());
auto tgt_elements_num = ShapedType::getNumElements(tgt_ty.getShape());
if (src_elements_num != tgt_elements_num) {
return emitOpError(
"Number of elements doesn't match between input and output memref "
"type.");
}
// Source and target attributes may be different before propagation is done by
// the canonicalizer, so we allow this when attributes are "unset" in the
// target type.
auto tgt_layout = dyn_cast<tpu::TiledLayoutAttr>(tgt_ty.getLayout());
if (!tgt_layout) {
return success();
}
auto src_layout = dyn_cast<tpu::TiledLayoutAttr>(src_ty.getLayout());
if (!src_layout || src_layout.getTiles().empty()) {
return emitOpError("Expected a tiled layout for the input memref.");
}
if (src_layout.getTiles() != tgt_layout.getTiles()) {
return emitOpError(
"Expected the same tiling for the input and output memref.");
}
auto tile = src_layout.getTiles().front().dimensions();
if (tile.size() != 2) {
return emitOpError("Not implemented: memref reshape with 1D tiling.");
}
SmallVector<int64_t> src_tile_strides(src_layout.getTileStrides());
if (ComputeTileStrides(src_ty, tile) != src_tile_strides) {
return emitOpError("Not implemented: reshape on a non-contiguous memref.");
}
auto src_tiled_shape = src_ty.getShape().take_back(2);
auto tgt_tiled_shape = tgt_ty.getShape().take_back(2);
bool is_src_align_tile_2nd_minor = src_tiled_shape[0] % tile[0] == 0;
bool is_src_align_tile_minor = src_tiled_shape[1] % tile[1] == 0;
bool is_tgt_align_tile_2nd_minor = tgt_tiled_shape[0] % tile[0] == 0;
bool is_tgt_align_tile_minor = tgt_tiled_shape[1] % tile[1] == 0;
if (tile[0] == 1 && is_src_align_tile_minor && is_tgt_align_tile_minor) {
// When the tiling is (1, ?) and the source and target shapes are aligned
// to the tile, we support reshape on any dims.
} else if (tgt_tiled_shape[1] != src_tiled_shape[1]) {
return emitError("Expected the minormost dimension to be unchanged");
} else if (tgt_tiled_shape[0] != src_tiled_shape[0]) {
if (!is_src_align_tile_2nd_minor || !is_tgt_align_tile_2nd_minor) {
return emitError(
"Expected the 2nd minor dimension is aligned to the tile");
}
}
return success();
}

LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op,
PatternRewriter &rewriter) {
auto src_ty = op.getInput().getType();
auto dst_ty = op.getType();
auto erase_layout_op = op.getInput().getDefiningOp<tpu::EraseLayoutOp>();
if (!erase_layout_op) {
return failure();
}
auto layout_ref = erase_layout_op.getOperand();
auto layout_ty = layout_ref.getType();
auto layout =
dyn_cast<tpu::TiledLayoutAttr>(layout_ty.getLayout());
CHECK(!layout.getTiles().empty());
auto tile = layout.getTiles().front().dimensions();
auto new_tile_strides = ComputeTileStrides(dst_ty, tile);
auto new_layout = tpu::TiledLayoutAttr::get(
src_ty.getContext(), layout.getTiles(), new_tile_strides);
auto new_result_ty =
MemRefType::get(dst_ty.getShape(), dst_ty.getElementType(), new_layout,
layout_ty.getMemorySpace());
auto reshape =
rewriter.create<MemRefReshapeOp>(op.getLoc(), new_result_ty, layout_ref);
rewriter.replaceOpWithNewOp<EraseLayoutOp>(op, op.getType(), reshape);
return success();
}

template <typename Op>
LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
VectorType vector_ty) {
Expand Down
19 changes: 1 addition & 18 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/layout.h"
Expand All @@ -34,22 +33,6 @@ namespace mlir::tpu {
#define GEN_PASS_DEF_INFERMEMREFLAYOUTPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"

SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
int64_t leading_tile_rows) {
SmallVector<int64_t> tile_strides(memref_ty.getRank());
int64_t stride = 1;
for (int i = memref_ty.getRank() - 1; i >= 0; --i) {
tile_strides[i] = stride;
if (i == memref_ty.getRank() - 1) {
stride *= llvm::divideCeil(memref_ty.getShape()[i], 128);
} else if (i == memref_ty.getRank() - 2) {
stride *= llvm::divideCeil(memref_ty.getShape()[i], leading_tile_rows);
} else {
stride *= memref_ty.getShape()[i];
}
}
return tile_strides;
}

// Returns the number of 128-element groups in a tile.
//
Expand Down Expand Up @@ -151,7 +134,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
}
tiles.push_back(xla::Tile({32 / bitwidth, 1}));
}
auto tile_strides = ComputeTileStrides(memref_ty, leading_tile_rows);
auto tile_strides = ComputeTileStrides(memref_ty, {leading_tile_rows, 128});
return TiledLayoutAttr::get(memref_ty.getContext(), tiles, tile_strides);
}
return emitError(UnknownLoc::get(memref_ty.getContext()),
Expand Down
42 changes: 42 additions & 0 deletions jaxlib/mosaic/dialect/tpu/util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/mosaic/dialect/tpu/util.h"

#include <cstdint>

#include "llvm/Support/MathExtras.h"
#include "absl/types/span.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/Support/LLVM.h"

namespace mlir::tpu {
SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
absl::Span<const int64_t> tiling) {
SmallVector<int64_t> tile_strides(memref_ty.getRank());
int64_t stride = 1;
for (int64_t i = 0; i < memref_ty.getRank(); ++i) {
int64_t idx = memref_ty.getRank() - 1 - i;
int64_t tiling_idx = tiling.size() - 1 - i;
tile_strides[idx] = stride;
if (tiling_idx >= 0) {
stride *= llvm::divideCeil(memref_ty.getShape()[idx], tiling[tiling_idx]);
} else {
stride *= memref_ty.getShape()[idx];
}
}
return tile_strides;
}
} // namespace mlir::tpu
2 changes: 2 additions & 0 deletions jaxlib/mosaic/dialect/tpu/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ std::string shapeToString(const T &shape) {
return os.str();
}

SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
absl::Span<const int64_t> tiling);
} // namespace mlir::tpu

#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_

0 comments on commit aa16485

Please sign in to comment.