Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* Update LLVM to llvm/llvm-project@33da608
* Cherry-picked llvm/llvm-project@aa90948
* Updated HLO to tensorflow/mlir-hlo@65eb2d4
  (includes CMake fixes after integrating llvm/llvm-project@33da608)
* Fix transform-dialect-related tests using iree-org#13635.
* Fix `TransformState::getPayloadOps` usage (iree-org#13621).
* Fix tests after upstream changes that hoist more
  `vector.transfer_read` ops.

Closes iree-org#13635 
Closes iree-org#13621
Should address iree-org#13419 (pending confirmation)

---------

Co-authored-by: Alex Zinenko <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
Co-authored-by: Hanhan Wang <[email protected]>
  • Loading branch information
4 people authored May 24, 2023
1 parent d103c86 commit 2787dd5
Show file tree
Hide file tree
Showing 57 changed files with 267 additions and 214 deletions.
2 changes: 1 addition & 1 deletion build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, repo_map: Dict[str, str]):
"@llvm-project//mlir:AllPassesAndDialects": ["MLIRAllDialects"],
"@llvm-project//mlir:CommonFolders": [""],
"@llvm-project//mlir:DialectUtils": [""],
"@llvm-project//mlir:GPUDialect": ["MLIRGPUOps"],
"@llvm-project//mlir:GPUDialect": ["MLIRGPUDialect"],
"@llvm-project//mlir:GPUTransforms": ["MLIRGPUTransforms"],
"@llvm-project//mlir:LinalgStructuredOpsIncGen": [
"MLIRLinalgStructuredOpsIncGenLib"
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ iree_cc_library(
MLIRBufferizationTransformOps
MLIRBufferizationTransforms
MLIRFuncDialect
MLIRGPUOps
MLIRGPUDialect
MLIRGPUTransformOps
MLIRIR
MLIRLLVMDialect
Expand Down Expand Up @@ -132,7 +132,7 @@ iree_cc_library(
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRFuncDialect
MLIRGPUOps
MLIRGPUDialect
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
Expand Down Expand Up @@ -209,7 +209,7 @@ iree_cc_library(
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRFuncDialect
MLIRGPUOps
MLIRGPUDialect
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ iree_cc_library(
MLIRArithDialect
MLIRBufferizationDialect
MLIRFuncDialect
MLIRGPUOps
MLIRGPUDialect
MLIRGPUTransformOps
MLIRGPUTransforms
MLIRIR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ iree_cc_library(
MLIRArithUtils
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRGPUOps
MLIRGPUDialect
MLIRIR
MLIRLinalgDialect
MLIRLinalgTransformOps
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1003,10 +1003,10 @@ struct EmptyTensorLoweringPattern : public OpRewritePattern<tensor::EmptyOp> {

DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
if (payload.size() != 1 ||
auto payload = state.getPayloadOps(getTarget());
if (!llvm::hasSingleElement(payload) ||
!isa<ModuleOp, HAL::ExecutableOp, HAL::ExecutableVariantOp>(
payload.front())) {
*payload.begin())) {
return mlir::emitDefiniteFailure(
state.getTopLevel(),
"requires exactly a single HAL::ExecutableOp or "
Expand All @@ -1031,7 +1031,7 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(
memCpyFn = gpuComprehensiveBufferizeCopyFn;
}

Operation *target = payload.front();
Operation *target = *payload.begin();
Location loc = target->getLoc();
ErrorCheckingTrackingListener listener(state, *this);
// 1. Rewrite tensor.empty to tensor.alloc, without the pass baggage.
Expand Down Expand Up @@ -1073,7 +1073,7 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(

// Early exit if test_analysis_only is set.
if (getTestAnalysisOnly()) {
results.set(getOperation()->getOpResult(0), payload.front());
results.set(getOperation()->getOpResult(0), {*payload.begin()});
return listener.check(loc);
}

Expand All @@ -1093,7 +1093,7 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(
if (res.wasInterrupted())
return listener.check(loc, emitDefaultDefiniteFailure(target));

results.set(getOperation()->getOpResult(0), payload.front());
results.set(getOperation()->getOpResult(0), {*payload.begin()});
return listener.check(loc);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp :
to different workgroups.
}];

let arguments = (ins PDL_Operation:$for_all_op);
let arguments = (ins TransformHandleTypeInterface:$for_all_op);
let results = (outs);
let assemblyFormat = [{
attr-dict $for_all_op `:` functional-type($for_all_op, results)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ transform.sequence failures(propagate) {

%_, %more_parallel_fill, %parallel_reduction, %combiner_op =
transform.structured.split_reduction %reduction { split_factor = 2, insert_split_dimension = 1 }
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)

// Step 1. Map to a single block by tiling with size 1 and fusing.
%fusion_root_1, %fusion_group_1 = transform.iree.take_first %maybe_trailing_0, %combiner_op
: (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
%grid_loop, %outer_tiled = transform.structured.tile_to_forall_op %fusion_root_1 tile_sizes [1]
( mapping = [#gpu.block<x>] )
: (!pdl.operation) -> (!pdl.operation, !pdl.operation)

%func = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
transform.iree.apply_patterns %func { bubble_expand } : (!pdl.operation) -> ()
Expand All @@ -27,11 +29,11 @@ transform.sequence failures(propagate) {
transform.sequence %arg0 : !pdl.operation -> !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation
failures(propagate) {
^bb1(%arg1: !pdl.operation):
%fused_22 = transform.structured.fuse_into_containing_op %fusion_group_1 into %grid_loop
%parallel_reduction_22 = transform.structured.fuse_into_containing_op %parallel_reduction into %grid_loop
%more_parallel_fill_22 = transform.structured.fuse_into_containing_op %more_parallel_fill into %grid_loop
%original_fill_22 = transform.structured.fuse_into_containing_op %original_fill into %grid_loop
%maybe_leading_22 = transform.structured.fuse_into_containing_op %maybe_leading into %grid_loop
%fused_22 = transform.structured.fuse_into_containing_op %fusion_group_1 into %grid_loop : (!pdl.operation, !pdl.operation) -> !pdl.operation
%parallel_reduction_22 = transform.structured.fuse_into_containing_op %parallel_reduction into %grid_loop : (!pdl.operation, !pdl.operation) -> !pdl.operation
%more_parallel_fill_22 = transform.structured.fuse_into_containing_op %more_parallel_fill into %grid_loop : (!pdl.operation, !pdl.operation) -> !pdl.operation
%original_fill_22 = transform.structured.fuse_into_containing_op %original_fill into %grid_loop : (!pdl.operation, !pdl.operation) -> !pdl.operation
%maybe_leading_22 = transform.structured.fuse_into_containing_op %maybe_leading into %grid_loop : (!pdl.operation, !pdl.operation) -> !pdl.operation

transform.yield %fused_22, %parallel_reduction_22, %more_parallel_fill_22, %original_fill_22, %maybe_leading_22
: !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation
Expand All @@ -44,14 +46,17 @@ transform.sequence failures(propagate) {
%block_loop_22, %fusion_root_22_tiled =
transform.structured.tile_to_forall_op %outer_tiled
tile_sizes [1] ( mapping = [#gpu.thread<z>] )
transform.structured.fuse_into_containing_op %fusion_group_22_full into %block_loop_22
: (!pdl.operation) -> (!pdl.operation, !pdl.operation)
transform.structured.fuse_into_containing_op %fusion_group_22_full into %block_loop_22 : (!pdl.operation, !pdl.operation) -> !pdl.operation


%fusion_group_21 = transform.merge_handles %maybe_leading_2, %more_parallel_fill_2
: !pdl.operation
%block_loop_21, %fusion_root_21_tiled =
transform.structured.tile_to_forall_op %parallel_reduction_2
tile_sizes [1, 1] ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
transform.structured.fuse_into_containing_op %fusion_group_21 into %block_loop_21
: (!pdl.operation) -> (!pdl.operation, !pdl.operation)
transform.structured.fuse_into_containing_op %fusion_group_21 into %block_loop_21 : (!pdl.operation, !pdl.operation) -> !pdl.operation

// Step 3. Rank-reduce.
// ===========================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func.func @ukernel_generic_non_tensor_memref_outs(

func.func @ukernel_generic_err_tensor_outs(
%out0: tensor<?xf32>, %out1 : memref<?x?xf32>) {
// expected-error @+1 {{expected the number of results (0) to be equal to the number of output tensors (1)}}
// expected-error @+1 {{expected the number of tensor results (0) to be equal to the number of output tensors (1)}}
iree_codegen.ukernel.generic "foo"
outs(%out0, %out1 : tensor<?xf32>, memref<?x?xf32>)
}
Expand All @@ -132,7 +132,7 @@ func.func @ukernel_generic_mixed_tensor_memref(
func.func @ukernel_generic_err_memref_outs(
%out0: tensor<?xf32>, %out1 : memref<?x?xf32>)
-> (tensor<?xf32>, tensor<?x?xf32>){
// expected-error @+1 {{expected the number of results (2) to be equal to the number of output tensors (1)}}
// expected-error @+1 {{expected the number of tensor results (2) to be equal to the number of output tensors (1)}}
%0:2 = iree_codegen.ukernel.generic "foo"
outs(%out0, %out1 : tensor<?xf32>, memref<?x?xf32>) -> tensor<?xf32>, tensor<?x?xf32>
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ iree_cc_library(
"ProcessorOpInterfaces.cpp"
DEPS
::ProcessorOpInterfaceGen
MLIRGPUOps
MLIRGPUDialect
MLIRIR
iree::compiler::Dialect::HAL::IR
PUBLIC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ transform.sequence failures(propagate) {
%forall, %matmul =
transform.structured.tile_to_forall_op %original_matmul num_threads [32]
( mapping = [#gpu.block<x>] )
: (!pdl.operation) -> (!pdl.operation, !pdl.operation)

// Late canonicalizations to cleanup and pass the checks.
// Needs to occur on the whole variant to perform cse on the workgroup_count region
Expand Down Expand Up @@ -112,7 +113,7 @@ hal.executable private @matmul_static_dispatch_0 {
transform.sequence failures(propagate) {
^bb1(%variant_op: !pdl.operation):
%1 = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
%forall_op, %tiled_op = transform.structured.tile_to_forall_op %1 num_threads [] tile_sizes [1, 1, 1](mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>])
%forall_op, %tiled_op = transform.structured.tile_to_forall_op %1 num_threads [] tile_sizes [1, 1, 1](mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>]): (!pdl.operation) -> (!pdl.operation, !pdl.operation)
transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_op : (!pdl.operation) -> ()
}

Expand Down Expand Up @@ -162,6 +163,6 @@ hal.executable private @matmul_static_dispatch_0 {
transform.sequence failures(propagate) {
^bb1(%variant_op: !pdl.operation):
%1 = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
%forall_op, %tiled_op = transform.structured.tile_to_forall_op %1 num_threads [] tile_sizes [5, 3](mapping = [#gpu.block<z>, #gpu.block<x>])
%forall_op, %tiled_op = transform.structured.tile_to_forall_op %1 num_threads [] tile_sizes [5, 3](mapping = [#gpu.block<z>, #gpu.block<x>]) : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_op : (!pdl.operation) -> ()
}
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ iree_cc_library(
MLIRFuncDialect
MLIRFuncToLLVM
MLIRFuncTransforms
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransformOps
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ iree_cc_library(
MLIRArithDialect
MLIRBufferizationDialect
MLIRFuncDialect
MLIRGPUOps
MLIRGPUDialect
MLIRGPUTransformOps
MLIRIR
MLIRLinalgTransforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,8 @@ def LayoutAnalysisAndDistributionOp :

}];

let arguments = (ins PDL_Operation:$target);
let results = (outs Variadic<PDL_Operation>:$result);
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs Variadic<TransformHandleTypeInterface>:$result);

let assemblyFormat = [{ $target attr-dict `:` functional-type(operands, results)}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ iree_cc_library(
MLIRAffineDialect
MLIRArithDialect
MLIRFuncDialect
MLIRGPUOps
MLIRGPUDialect
MLIRIR
MLIRMathDialect
MLIRMemRefDialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ transform.sequence failures(propagate) {
%forall_grid, %tiled_attention =
transform.structured.tile_to_forall_op %attention tile_sizes [1, 128]
( mapping = [#gpu.block<x>, #gpu.block<y>] )
: (!pdl.operation) -> (!pdl.operation, !pdl.operation)

// Tile and decompose attention
// ==========================================
Expand All @@ -53,7 +54,7 @@ transform.sequence failures(propagate) {
// ==========================================
%func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation
transform.iree.apply_patterns %func { rank_reducing_linalg, rank_reducing_vector } : (!pdl.operation) -> ()
%func_3 = transform.structured.vectorize %func
%func_3 = transform.structured.vectorize %func : (!pdl.operation) -> !pdl.operation

// Bufferization
// ==========================================
Expand Down Expand Up @@ -108,13 +109,13 @@ transform.sequence failures(propagate) {
// CHECK-DAG: %[[D4:.+]] = affine.apply #[[MAP]]()[%[[WORKGROUP_ID_Y]]]
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[D3]][%[[WORKGROUP_ID_X]], %[[D4]], 0] [1, 128, 64] [1, 1, 1]
// CHECK-SAME: : memref<192x1024x64xf32> to memref<1x128x64xf32, strided<[65536, 64, 1], offset: ?>>
// CHECK: %[[D7:.+]] = vector.transfer_read %[[D0]][%[[WORKGROUP_ID_X]], %[[D4]], %[[C0]]], %[[CST_2]]
// CHECK-SAME: {in_bounds = [true, true]} : memref<192x1024x64xf32>, vector<128x64xf32>
// CHECK: %[[D5:.+]] = vector.transfer_read %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]], %[[CST_2]] {in_bounds
// CHECK-SAME: = [true, true]} : memref<1x128x64xf32, strided<[65536, 64, 1], offset: ?>>, vector<128x64xf32>
// CHECK: %[[D6:.+]]:3 = scf.for %[[ARG0:.+]] = %[[C0]] to %[[C1024]] step %[[C128]]
// CHECK-SAME: iter_args(%[[ARG1:.+]] = %[[CST]], %[[ARG2:.+]] = %[[CST_0]], %[[ARG3:.+]] = %[[D5]]) -> (vector<128xf32>,
// CHECK-SAME: vector<128xf32>, vector<128x64xf32>) {
// CHECK: %[[D7:.+]] = vector.transfer_read %[[D0]][%[[WORKGROUP_ID_X]], %[[D4]], %[[C0]]], %[[CST_2]]
// CHECK-SAME: {in_bounds = [true, true]} : memref<192x1024x64xf32>, vector<128x64xf32>
// CHECK: %[[D8:.+]] = vector.transfer_read %[[D1]][%[[WORKGROUP_ID_X]], %[[ARG0]], %[[C0]]], %[[CST_2]]
// CHECK-SAME: {in_bounds = [true, true]} : memref<192x1024x64xf32>, vector<128x64xf32>
// CHECK: %[[D9:.+]] = vector.contract {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iterator_types
Expand Down
Loading

0 comments on commit 2787dd5

Please sign in to comment.