From 4911a396b2fdf9ffaf5b6c13f31059f7954f1ae1 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 17 Dec 2024 09:57:39 -0800 Subject: [PATCH] [Mosaic TPU] Add support for the interleaved pack format to tpu.unpack_subelements PiperOrigin-RevId: 707142562 --- jaxlib/mosaic/dialect/tpu/tpu.td | 3 ++- .../dialect/tpu/transforms/apply_vector_layout.cc | 12 ++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index ed5c2f2da263..6df5fe6ac75d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -363,7 +363,8 @@ def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { let arguments = (ins AnyVectorOfNonZeroRank:$source, - I32Attr:$index + I32Attr:$index, + TPU_PackFormatEnum:$pack_format ); let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index a3a37e7d7da8..f797d2f39591 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -874,7 +874,8 @@ FailureOr> ext_op_rule_impl(RewriteContext &ctx, int64_t vreg_part = *(input_vreg_idxs.end() - 2) % packing; *(input_vreg_idxs.end() - 2) /= packing; *v = builder.create( - op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part, + tpu::PackFormat::kCompressed); }); } else { if (layout_in.tiling() != layout_out.tiling()) { @@ -890,7 +891,8 @@ FailureOr> ext_op_rule_impl(RewriteContext &ctx, input_vreg_idxs.back() /= packing; const int64_t vreg_part = idxs.back() % packing; *v = builder.create( - op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part, + tpu::PackFormat::kCompressed); }); } return output_vregs; @@ -6265,7 +6267,8 @@ FailureOr>> changeTiling( src_idx[src_idx.size() - 1] /= packing; for (int i = 0; i < packing; ++i) { parts.push_back(builder.create( - loc, vreg_x32, vregs(src_idx), vreg_part)); + loc, vreg_x32, vregs(src_idx), vreg_part, + tpu::PackFormat::kCompressed)); if (src_idx[src_idx.size() - 2] < vregs.dim(vregs.num_dimensions() - 2) - 1) { ++src_idx[src_idx.size() - 2]; @@ -6345,7 +6348,8 @@ FailureOr>> changeTiling( *(src_idx.end() - 1) /= packing; for (int i = 0; i < packing; ++i) { parts.push_back(builder.create( - loc, vreg_x32, vregs(src_idx), vreg_part)); + loc, vreg_x32, vregs(src_idx), vreg_part, + tpu::PackFormat::kCompressed)); if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2) - 1) { ++*(src_idx.end() - 2); } // The rest is padding, so just pick any of the input parts (but not