Skip to content

Commit

Permalink
Cleanup: Use upstream TransformInterpreterPassBase (iree-org#13633)
Browse files Browse the repository at this point in the history
Transform dialect interpreter passes must still be defined in IREE, but
they can use the upstream
`mlir::transform::TransformInterpreterPassBase` implementation.
  • Loading branch information
matthias-springer authored May 16, 2023
1 parent 9bf6fb1 commit ebde1b1
Show file tree
Hide file tree
Showing 17 changed files with 41 additions and 660 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:VectorDialect",
# IR
"@llvm-project//mlir:Analysis",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ iree_cc_library(
MLIRTensorDialect
MLIRTensorTransforms
MLIRTransformDialect
MLIRTransformDialectTransforms
MLIRVectorDialect
MLIRVectorTransformOps
MLIRVectorTransforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterPassBase.h"
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
#include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h"
#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
Expand Down Expand Up @@ -40,6 +39,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
Expand All @@ -54,7 +54,7 @@ namespace {
/// This needs to be its own pass because the registration mechanism and ops
/// available are different than for other interpreters.
class TransformDialectInterpreterPass
: public transform::iree_dialects::TransformInterpreterPassBase<
: public mlir::transform::TransformInterpreterPassBase<
TransformDialectInterpreterPass,
iree_compiler::TransformDialectInterpreterBase> {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ module {

// -----

// expected-error @below {{transform dialect interpreter failed}}
module {
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
Expand Down Expand Up @@ -107,7 +106,6 @@ module {

// -----

// expected-error @below {{transform dialect interpreter failed}}
module {
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: iree-opt %s --split-input-file --iree-transform-dialect-interpreter --verify-diagnostics

module {
// expected-error @above {{transform dialect interpreter failed}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{match registry not available}}
Expand All @@ -12,7 +11,6 @@ module {
// -----

module {
// expected-error @above {{transform dialect interpreter failed}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
transform.iree.register_match_callbacks
Expand All @@ -24,7 +22,6 @@ module {
// -----

module {
// expected-error @above {{transform dialect interpreter failed}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
transform.iree.register_match_callbacks
Expand All @@ -47,7 +44,6 @@ module {
// -----

module attributes {test.iree_transform_do_not_match} {
// expected-error @above {{transform dialect interpreter failed}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
transform.iree.register_match_callbacks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ builtin.module {
// Check that we reject constructs that try to apply create_async_groups
// on non-func op.

// expected-error@below {{transform dialect interpreter failed}}
builtin.module {
func.func @copies_to_asyncs_invalid_op_input(%a: memref<1024x1024xf32>) {
// expected-note@below {{when applied to this op}}
Expand Down
11 changes: 9 additions & 2 deletions compiler/src/iree/compiler/Codegen/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -393,17 +393,24 @@ def TransformDialectInterpreter :
"Optional filename containing a transform dialect specification to "
"apply. If left empty, the IR is assumed to contain one top-level "
"transform dialect operation somewhere in the module.">,
Option<"transformLibraryFileName",
"transform-library-file-name",
"std::string",
/*default=*/"\"\"",
"If non-empty, the name of the file containing definitions of "
"external symbols referenced in the transform script. "
"These definitions will be used to replace declarations.">,
Option<"debugPayloadRootTag", "debug-payload-root-tag", "std::string",
/*default=*/"\"\"",
"Select the operation with 'transform.iree_tag' attribute having "
"Select the operation with 'transform.target_tag' attribute having "
"the given value as payload IR root. This allows user control on "
"what operation to transform in debug mode, without requiring "
"intimate knowledge of the IREE nested pass pipeline.\\n"
"If empty (normal operation mode), select the pass anchor "
"operation in the IREE pipeline, as the payload IR root.">,
Option<"debugTransformRootTag", "debug-transform-root-tag", "std::string",
/*default=*/"\"\"",
"Select the operation with 'transform.iree_tag' attribute having "
"Select the operation with 'transform.target_tag' attribute having "
"the given value as container IR for top-level transform ops. This "
"allows user control on what transformation to apply in debug "
"mode, without requiring intimate knowledge of the IREE nested "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ iree_compiler_cc_library(
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
"//llvm-external-projects/iree-dialects:IREELinalgExtUtils",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialectPasses",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:Analysis",
Expand Down Expand Up @@ -121,6 +120,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:TilingInterface",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ iree_cc_library(
IREELinalgExtTransforms
IREELinalgExtUtils
IREELinalgTransformDialect
IREELinalgTransformDialectPasses
LLVMSupport
MLIRAffineDialect
MLIRAnalysis
Expand All @@ -102,6 +101,7 @@ iree_cc_library(
MLIRTilingInterface
MLIRTosaDialect
MLIRTransformDialect
MLIRTransformDialectTransforms
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::Flow::Conversion::TensorToFlow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterPassBase.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
Expand All @@ -18,6 +17,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"

Expand All @@ -30,8 +30,8 @@ namespace Flow {
/// Interpreter pass that applies transform dialect ops for dispatch region
/// formation. This needs to be its own pass because the registration mechanism
/// and ops available are different than for other interpreters.
struct DispatchWithTransformDialect
: public transform::iree_dialects::TransformInterpreterPassBase<
class DispatchWithTransformDialect
: public mlir::transform::TransformInterpreterPassBase<
DispatchWithTransformDialect, DispatchWithTransformDialectBase> {
void getDependentDialects(DialectRegistry &registry) const override {
// clang-format off
Expand All @@ -49,6 +49,7 @@ struct DispatchWithTransformDialect
// clang-format on
}

public:
DispatchWithTransformDialect(StringRef transformFileName,
StringRef debugPayloadRootTag = StringRef(),
StringRef debugTransformRootTag = StringRef()) {
Expand Down
11 changes: 9 additions & 2 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,24 @@ def DispatchWithTransformDialect :
"Optional filename containing a transform dialect specification to "
"apply. If left empty, the IR is assumed to contain one top-level "
"transform dialect operation somewhere in the module.">,
Option<"transformLibraryFileName",
"transform-library-file-name",
"std::string",
/*default=*/"\"\"",
"If non-empty, the name of the file containing definitions of "
"external symbols referenced in the transform script. "
"These definitions will be used to replace declarations.">,
Option<"debugPayloadRootTag", "debug-payload-root-tag", "std::string",
/*default=*/"\"\"",
"Select the operation with 'transform.iree_tag' attribute having "
"Select the operation with 'transform.target_tag' attribute having "
"the given value as payload IR root. This allows user control on "
"what operation to transform in debug mode, without requiring "
"intimate knowledge of the IREE nested pass pipeline.\\n"
"If empty (normal operation mode), select the pass anchor "
"operation in the IREE pipeline, as the payload IR root.">,
Option<"debugTransformRootTag", "debug-transform-root-tag", "std::string",
/*default=*/"\"\"",
"Select the operation with 'transform.iree_tag' attribute having "
"Select the operation with 'transform.target_tag' attribute having "
"the given value as container IR for top-level transform ops. This "
"allows user control on what transformation to apply in debug "
"mode, without requiring intimate knowledge of the IREE nested "
Expand Down
1 change: 1 addition & 0 deletions llvm-external-projects/iree-dialects/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ cc_library(
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TensorUtils",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
add_mlir_library(IREELinalgTransformDialectPasses
ExpertExpansion.cpp
TransformInterpreter.cpp
TransformInterpreterPassBase.cpp

DEPENDS
mlir-headers
Expand All @@ -19,6 +18,7 @@ add_mlir_library(IREELinalgTransformDialectPasses
MLIRMemRefToLLVM
MLIRPass
MLIRTensorDialect
MLIRTransformDialectTransforms
MLIRTransforms
MLIRVectorDialect
MLIRVectorToLLVM
Expand Down
Loading

0 comments on commit ebde1b1

Please sign in to comment.