Skip to content

Commit

Permalink
[XLA:Mosaic] Create apply layout pass with ctx instead of config list.
Browse files Browse the repository at this point in the history
This cl removes the funcOp from RewriteContext of apply-vector-layout-pass (since only one function is using it) and uses context to create the pass instead of a long list of arguments. We will need to add more args (target's bank counts) to create apply-vector-layout.

PiperOrigin-RevId: 655329321
  • Loading branch information
bythew3i authored and jax authors committed Jul 23, 2024
1 parent 101e5fe commit 7e2107b
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 76 deletions.
6 changes: 3 additions & 3 deletions jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,9 @@ MlirTpuValueArray mlirTpuDisassemble(MlirTpuInsertionPoint insertion_point,
MlirLogicalResult mlirTpuApplyLayoutOp(int hardware_generation,
MlirOperation op,
MlirTpuI64TargetTuple target_shape) {
auto f = unwrap(op)->getParentOfType<mlir::func::FuncOp>();
CHECK(f != nullptr);
mlir::tpu::RewriteContext ctx{f, hardware_generation, unwrap(target_shape)};
mlir::tpu::ApplyVectorLayoutContext ctx{
.hardware_generation = hardware_generation,
.target_shape = unwrap(target_shape)};
return wrap(mlir::tpu::applyLayoutOp(ctx, *unwrap(op)));
}

Expand Down
2 changes: 1 addition & 1 deletion jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO
"::mlir::vector::VectorDialect",
"::mlir::tpu::TPUDialect",
];
let constructor = "::mlir::tpu::createApplyVectorLayoutPass(-1)";
let constructor = "::mlir::tpu::createApplyVectorLayoutPass()";
let options = [
// If hardware_generation is not set, the default value of -1 will crash on
// runOnOperation.
Expand Down
14 changes: 11 additions & 3 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef JAXLIB_MOSAIC_DIALECT_TPU_DIALECT_H_
#define JAXLIB_MOSAIC_DIALECT_TPU_DIALECT_H_

#include <array>
#include <cstdint>
#include <memory>
#include <utility>
Expand Down Expand Up @@ -54,6 +55,15 @@ struct TpuTilingFlags {
bool use_x4_large_second_minor = false;
};

struct ApplyVectorLayoutContext {
// TODO(tlongeri): target_shape should be determined from hardware_generation
int hardware_generation = -1;
std::array<int64_t, 2> target_shape = {8, 128};
// mxu_shape = {contracting_size, non_contracting_size}
std::array<int64_t, 2> mxu_shape = {128, 128};
int64_t max_sublanes_in_scratch = 0;
};

std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);

std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
Expand All @@ -66,9 +76,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
int lane_count = 128, int sublane_count = 8);

std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
int hardware_generation = -1, int lane_count = 128, int sublane_count = 8,
int mxu_contracting_size = 128, int mxu_noncontracting_size = 128,
int max_sublanes_in_scratch = 0);
const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{});

std::unique_ptr<OperationPass<func::FuncOp>>
createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
Expand Down
107 changes: 52 additions & 55 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
Expand Down Expand Up @@ -118,6 +119,8 @@ namespace mlir::tpu {
// TODO(jevinjiang): need to update it based on the generation.
static constexpr int kMinBoundToRotateWithScratch = 27;

using RewriteContext = ApplyVectorLayoutContext;

LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block);
namespace {

Expand Down Expand Up @@ -501,18 +504,16 @@ FailureOr<std::array<int64_t, 2>> getMemRefTiling(
}

// Hoist a vector constant as an additional argument of the function.
FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
DenseElementsAttr value) {
MLIRContext *mlir_ctx = ctx.func.getContext();
Block &entry_block = ctx.func.getBody().front();
MLIRContext *mlir_ctx = func.getContext();
Block &entry_block = func.getBody().front();
auto value_ty = cast<VectorType>(value.getType());
if (value_ty.getElementType().getIntOrFloatBitWidth() != 32) {
return ctx.func.emitOpError(
"Not implemented: Only 32-bit constants supported");
return func.emitOpError("Not implemented: Only 32-bit constants supported");
}
if (ctx.func->getAttr("scratch_operands")) {
return ctx.func.emitOpError(
"Not implemented: function has scratch_operands");
if (func->getAttr("scratch_operands")) {
return func.emitOpError("Not implemented: function has scratch_operands");
}
// We can omit tpu_tiling_flags here since we invoke inferMemref only for
// constant operands which are kernel parameters that will have their layouts
Expand All @@ -522,46 +523,42 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
inferMemref(
MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
ctx.hardware_generation, /*tpu_tiling_flags=*/{}));
const BlockArgument argument =
entry_block.insertArgument(entry_block.getNumArguments() - 1, arg_type,
UnknownLoc::get(ctx.getMLIRContext()));
const FunctionType func_ty = ctx.func.getFunctionType();
const BlockArgument argument = entry_block.insertArgument(
entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx));
const FunctionType func_ty = func.getFunctionType();
// Adjust the function type.
SmallVector<Type> new_arg_tys(func_ty.getInputs());
new_arg_tys.insert(new_arg_tys.begin() + (new_arg_tys.size() - 1), arg_type);
const auto new_func_ty =
FunctionType::get(mlir_ctx, new_arg_tys, func_ty.getResults());
ctx.func.setFunctionType(new_func_ty);
func.setFunctionType(new_func_ty);
// Adjust the constants attribute.
if (auto prev_cst = ctx.func->getAttrOfType<ArrayAttr>("vector_constants")) {
if (auto prev_cst = func->getAttrOfType<ArrayAttr>("vector_constants")) {
SmallVector<Attribute> vector_constants(prev_cst.getValue());
vector_constants.push_back(value);
ctx.func->setAttr("vector_constants",
ArrayAttr::get(ctx.func.getContext(), vector_constants));
func->setAttr("vector_constants",
ArrayAttr::get(func.getContext(), vector_constants));
} else {
ctx.func->setAttr("vector_constants",
ArrayAttr::get(ctx.func.getContext(), value));
func->setAttr("vector_constants", ArrayAttr::get(func.getContext(), value));
}
// Adjust window params for the extra operand.
if (auto window_params =
ctx.func->getAttrOfType<ArrayAttr>("window_params")) {
if (auto window_params = func->getAttrOfType<ArrayAttr>("window_params")) {
const auto iteration_bounds =
ctx.func->getAttrOfType<DenseI64ArrayAttr>("iteration_bounds");
func->getAttrOfType<DenseI64ArrayAttr>("iteration_bounds");
TPU_ASSERT_LOC(UnknownLoc::get(mlir_ctx), iteration_bounds);
const int64_t iteration_rank = iteration_bounds.getSize();
const SmallVector<AffineExpr> zeros(
iteration_rank, getAffineConstantExpr(0, ctx.func.getContext()));
iteration_rank, getAffineConstantExpr(0, func.getContext()));
const auto transform_indices =
AffineMap::get(iteration_rank, 0, zeros, ctx.func.getContext());
AffineMap::get(iteration_rank, 0, zeros, func.getContext());
const auto new_param = DictionaryAttr::get(
ctx.func.getContext(),
NamedAttribute(
StringAttr::get(ctx.func.getContext(), "transform_indices"),
AffineMapAttr::get(transform_indices)));
func.getContext(),
NamedAttribute(StringAttr::get(func.getContext(), "transform_indices"),
AffineMapAttr::get(transform_indices)));
SmallVector<Attribute> window_params_values(window_params.getValue());
window_params_values.insert(window_params_values.end() - 1, new_param);
ctx.func->setAttr("window_params", ArrayAttr::get(ctx.func.getContext(),
window_params_values));
func->setAttr("window_params",
ArrayAttr::get(func.getContext(), window_params_values));
}
return argument;
}
Expand Down Expand Up @@ -2666,9 +2663,14 @@ LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op,
for (int64_t i = 0; i < ctx.target_shape[0]; ++i) { // Broadcast
dyn_ix_val.append(segment_indices);
}
auto func_op = op.getParentOfType<func::FuncOp>();
if (!func_op) {
return op.emitOpError("Expected a function op");
}
FAILUREOR_ASSIGN_OR_RETURN(
const BlockArgument dyn_ix_ref,
appendConstant(ctx, DenseIntElementsAttr::get(dyn_ix_ty, dyn_ix_val)));
appendConstant(ctx, func_op,
DenseIntElementsAttr::get(dyn_ix_ty, dyn_ix_val)));
auto all_sublanes = builder.getAttr<DenseBoolArrayAttr>(
SmallVector<bool>(ctx.target_shape[1], true));
auto dyn_ix = builder.create<tpu::LoadOp>(
Expand Down Expand Up @@ -3037,8 +3039,12 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
return op.emitOpError(
"Not implemented: Only 32-bit non-splat constants are supported");
}
auto func_op = op.getParentOfType<func::FuncOp>();
if (!func_op) {
return op.emitOpError("Expected a function op");
}
FAILUREOR_ASSIGN_OR_RETURN(const BlockArgument ref,
appendConstant(ctx, value));
appendConstant(ctx, func_op, value));
auto load_op = builder.create<vector::LoadOp>(
vty, ref,
SmallVector<Value>(vty.getRank(), IdxConst(0, builder, op.getLoc())));
Expand Down Expand Up @@ -5564,43 +5570,34 @@ LogicalResult applyLayoutFunc(RewriteContext &ctx, func::FuncOp f) {

struct ApplyVectorLayoutPass
: public impl::ApplyVectorLayoutPassBase<ApplyVectorLayoutPass> {
ApplyVectorLayoutPass(int hardware_generation_, int lane_count_,
int sublane_count_, int mxu_contracting_size_,
int mxu_noncontracting_size_,
int max_sublanes_in_scratch_) {
hardware_generation = hardware_generation_;
sublane_count = sublane_count_;
lane_count = lane_count_;
mxu_contracting_size = mxu_contracting_size_;
mxu_noncontracting_size = mxu_noncontracting_size_;
max_sublanes_in_scratch = max_sublanes_in_scratch_;
ApplyVectorLayoutPass(const RewriteContext &ctx) {
hardware_generation = ctx.hardware_generation;
sublane_count = ctx.target_shape[0];
lane_count = ctx.target_shape[1];
mxu_contracting_size = ctx.mxu_shape[0];
mxu_noncontracting_size = ctx.mxu_shape[1];
max_sublanes_in_scratch = ctx.max_sublanes_in_scratch;
}
void runOnOperation() override {
// Fail if hardware_generation has not been set from the default value.
if (hardware_generation < 0) {
signalPassFailure();
return;
}
func::FuncOp func = getOperation();
RewriteContext ctx{func,
hardware_generation,
{sublane_count, lane_count},
{mxu_contracting_size, mxu_noncontracting_size},
max_sublanes_in_scratch};
if (failed(applyLayoutFunc(ctx, func))) {
RewriteContext ctx{
.hardware_generation = hardware_generation,
.target_shape = {sublane_count, lane_count},
.mxu_shape = {mxu_contracting_size, mxu_noncontracting_size},
.max_sublanes_in_scratch = max_sublanes_in_scratch};
if (failed(applyLayoutFunc(ctx, getOperation()))) {
signalPassFailure();
return;
}
}
};

std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
int hardware_generation, int lane_count, int sublane_count,
int mxu_contracting_size, int mxu_noncontracting_size,
int max_sublanes_in_scratch) {
return std::make_unique<ApplyVectorLayoutPass>(
hardware_generation, lane_count, sublane_count, mxu_contracting_size,
mxu_noncontracting_size, max_sublanes_in_scratch);
const RewriteContext &ctx) {
return std::make_unique<ApplyVectorLayoutPass>(ctx);
}

} // namespace mlir::tpu
15 changes: 1 addition & 14 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
#include <array>
#include <cstdint>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
Expand All @@ -15,17 +13,6 @@

namespace mlir::tpu {

struct RewriteContext {
func::FuncOp func;
// TODO(tlongeri): target_shape should be determined from hardware_generation
const int hardware_generation;
const std::array<int64_t, 2> target_shape = {8, 128};
const std::array<int64_t, 2> mxu_shape = {128, 128};
const int max_sublanes_in_scratch = 0;

MLIRContext *getMLIRContext() { return func.getContext(); }
};

// TODO(tlongeri): Remove default values for use_implicit_shape.
RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
const VectorLayout &layout,
Expand All @@ -50,7 +37,7 @@ FailureOr<xla::Array<Value>> disassemble(OpBuilder &builder,
// and
// have a valid layout (Layout1D or Layout2D)
// - All non-vector operands must have NoLayout.
LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op);
LogicalResult applyLayoutOp(ApplyVectorLayoutContext &ctx, Operation &op);

// Changes the layout of a vector value.
//
Expand Down

0 comments on commit 7e2107b

Please sign in to comment.