Skip to content

Commit

Permalink
[Mosaic TPU] Add support for the interleaved pack format to tpu.unpac…
Browse files Browse the repository at this point in the history
…k_subelements

PiperOrigin-RevId: 707142562
  • Loading branch information
apaszke authored and Google-ML-Automation committed Dec 17, 2024
1 parent 36b12d5 commit 4911a39
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
3 changes: 2 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
let arguments = (ins
AnyVectorOfNonZeroRank:$source,
I32Attr:$index
I32Attr:$index,
TPU_PackFormatEnum:$pack_format
);
let results = (outs AnyVectorOfNonZeroRank:$output);
let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }];
Expand Down
12 changes: 8 additions & 4 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,8 @@ FailureOr<xla::Array<Value>> ext_op_rule_impl(RewriteContext &ctx,
int64_t vreg_part = *(input_vreg_idxs.end() - 2) % packing;
*(input_vreg_idxs.end() - 2) /= packing;
*v = builder.create<UnpackSubelementsOp>(
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part,
tpu::PackFormat::kCompressed);
});
} else {
if (layout_in.tiling() != layout_out.tiling()) {
Expand All @@ -890,7 +891,8 @@ FailureOr<xla::Array<Value>> ext_op_rule_impl(RewriteContext &ctx,
input_vreg_idxs.back() /= packing;
const int64_t vreg_part = idxs.back() % packing;
*v = builder.create<UnpackSubelementsOp>(
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part,
tpu::PackFormat::kCompressed);
});
}
return output_vregs;
Expand Down Expand Up @@ -6265,7 +6267,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
src_idx[src_idx.size() - 1] /= packing;
for (int i = 0; i < packing; ++i) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part));
loc, vreg_x32, vregs(src_idx), vreg_part,
tpu::PackFormat::kCompressed));
if (src_idx[src_idx.size() - 2] <
vregs.dim(vregs.num_dimensions() - 2) - 1) {
++src_idx[src_idx.size() - 2];
Expand Down Expand Up @@ -6345,7 +6348,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
*(src_idx.end() - 1) /= packing;
for (int i = 0; i < packing; ++i) {
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_x32, vregs(src_idx), vreg_part));
loc, vreg_x32, vregs(src_idx), vreg_part,
tpu::PackFormat::kCompressed));
if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2) - 1) {
++*(src_idx.end() - 2);
} // The rest is padding, so just pick any of the input parts (but not
Expand Down

0 comments on commit 4911a39

Please sign in to comment.