diff --git a/tests/filecheck/transforms/memref-to-dsd.mlir b/tests/filecheck/transforms/memref-to-dsd.mlir
index 410066c436..2c38de193c 100644
--- a/tests/filecheck/transforms/memref-to-dsd.mlir
+++ b/tests/filecheck/transforms/memref-to-dsd.mlir
@@ -120,6 +120,18 @@ builtin.module {
 // CHECK-NEXT: %31 = memref.load %b[%13] : memref<510xf32>
 // CHECK-NEXT: "test.op"(%31) : (f32) -> ()
 
+%39 = memref.alloc() {"alignment" = 64 : i64} : memref<3x64xf32>
+%40 = "memref.subview"(%39, %0) <{"operandSegmentSizes" = array, "static_offsets" = array, "static_sizes" = array, "static_strides" = array}> : (memref<3x64xf32>, index) -> memref<32xf32, strided<[1], offset: ?>>
+
+// CHECK-NEXT: %32 = "csl.zeros"() : () -> memref<3x64xf32>
+// CHECK-NEXT: %33 = arith.constant 3 : i16
+// CHECK-NEXT: %34 = arith.constant 64 : i16
+// CHECK-NEXT: %35 = "csl.get_mem_dsd"(%32, %33, %34) : (memref<3x64xf32>, i16, i16) -> !csl
+// CHECK-NEXT: %36 = arith.constant 32 : i16
+// CHECK-NEXT: %37 = "csl.get_mem_dsd"(%32, %36) <{"tensor_access" = affine_map<(d0) -> (2, d0)>}> : (memref<3x64xf32>, i16) -> !csl
+// CHECK-NEXT: %38 = arith.index_cast %0 : index to si16
+// CHECK-NEXT: %39 = "csl.increment_dsd_offset"(%37, %38) <{"elem_type" = f32}> : (!csl, si16) -> !csl
+
 }) {sym_name = "program"} : () -> ()
 }
 // CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py
index 3aa43841e6..d260dd6594 100644
--- a/xdsl/transforms/lower_csl_stencil.py
+++ b/xdsl/transforms/lower_csl_stencil.py
@@ -202,24 +202,25 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
         # ensure we send only core data
         assert isa(op.accumulator.type, memref.MemRefType[Attribute])
         assert isa(op.field.type, memref.MemRefType[Attribute])
+        # the accumulator might have additional dims when used for holding prefetched data
+        send_buf_shape = op.accumulator.type.get_shape()[
+            -len(op.field.type.get_shape()) :
+        ]
         send_buf = memref.SubviewOp.get(
             op.field,
             [
                 (d - s) // 2  # symmetric offset
-                for s, d in zip(
-                    op.accumulator.type.get_shape(), op.field.type.get_shape()
-                )
+                for s, d in zip(send_buf_shape, op.field.type.get_shape(), strict=True)
             ],
-            op.accumulator.type.get_shape(),
-            len(op.accumulator.type.get_shape()) * [1],
-            op.accumulator.type,
+            send_buf_shape,
+            len(send_buf_shape) * [1],
+            memref.MemRefType(op.field.type.get_element_type(), send_buf_shape),
         )
 
         # add api call
         num_chunks = arith.ConstantOp(IntegerAttr(op.num_chunks.value, i16))
         chunk_ref = csl.AddressOfFnOp(chunk_fn)
         done_ref = csl.AddressOfFnOp(done_fn)
-        # send_buf = memref.Subview.get(op.field, [], op.accumulator.type.get_shape(), )
         api_call = csl.MemberCallOp(
             "communicate",
             None,
diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py
index a14b06e2b3..958ffd5e34 100644
--- a/xdsl/transforms/memref_to_dsd.py
+++ b/xdsl/transforms/memref_to_dsd.py
@@ -1,3 +1,4 @@
+import collections
 from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import cast
@@ -5,6 +6,8 @@
 from xdsl.context import MLContext
 from xdsl.dialects import arith, builtin, csl, memref
 from xdsl.dialects.builtin import (
+    AffineMapAttr,
+    AnyMemRefType,
     ArrayAttr,
     Float16Type,
     Float32Type,
@@ -19,6 +22,7 @@
     UnrealizedConversionCastOp,
 )
 from xdsl.ir import Attribute, Operation, OpResult, SSAValue
+from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineExpr, AffineMap
 from xdsl.passes import ModulePass
 from xdsl.pattern_rewriter import (
     GreedyRewritePatternApplier,
@@ -119,7 +123,47 @@ class LowerSubviewOpPass(RewritePattern):
 
     @op_type_rewrite_pattern
     def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
-        assert isa(op.source.type, MemRefType[Attribute])
+        assert isa(op.source.type, AnyMemRefType)
+        assert isa(op.result.type, AnyMemRefType)
+
+        if len(op.result.type.get_shape()) == 1 and len(op.source.type.get_shape()) > 1:
+            # 1d subview onto a nd memref
+            sizes = op.static_sizes.get_values()
+            counter_sizes = collections.Counter(sizes)
+            counter_sizes.pop(1, None)
+            assert (
+                len(counter_sizes) == 1
+            ), "1d access into nd memref must specify one size > 1"
+            size, size_count = counter_sizes.most_common()[0]
+            size = cast(int, size)
+
+            assert (
+                size_count == 1
+            ), "1d access into nd memref can only specify one size > 1, which can occur only once"
+            assert all(
+                stride == 1 for stride in op.static_strides.get_values()
+            ), "All strides must equal 1"
+
+            amap: list[AffineExpr] = [
+                AffineConstantExpr(
+                    cast(int, o) if o != memref.SubviewOp.DYNAMIC_INDEX else 0
+                )
+                for o in op.static_offsets.get_values()
+            ]
+            amap[sizes.index(size)] += AffineDimExpr(0)
+
+            size_op = arith.ConstantOp.from_int_and_width(size, 16)
+            dsd_op = csl.GetMemDsdOp(
+                operands=[op.source, [size_op]],
+                properties={
+                    "tensor_access": AffineMapAttr(AffineMap(1, 0, tuple(amap)))
+                },
+                result_types=[csl.DsdType(csl.DsdKind.mem1d_dsd)],
+            )
+            offset_ops = self._update_offsets(op, dsd_op) if op.offsets else []
+            rewriter.replace_matched_op([size_op, dsd_op, *offset_ops])
+            return
+
         assert len(op.static_sizes) == 1, "not implemented"
         assert len(op.static_offsets) == 1, "not implemented"
         assert len(op.static_strides) == 1, "not implemented"
@@ -219,7 +263,7 @@ def _update_offsets(
 
         static_offsets = cast(Sequence[int], subview.static_offsets.get_values())
 
-        if static_offsets[0] == memref.SubviewOp.DYNAMIC_INDEX:
+        if subview.offsets:
             ops.append(cast_op := arith.IndexCastOp(subview.offsets[0], csl.i16_value))
             ops.append(
                 csl.IncrementDsdOffsetOp.build(