Skip to content

Commit

Permalink
[mlir][Linalg] NFC: Combine elementwise fusion test passes.
Browse files Browse the repository at this point in the history
There are a few different test passes that check elementwise fusion in
Linalg. Consolidate them to a single pass controlled by different pass
options (in keeping with how `TestLinalgTransforms` exists).
  • Loading branch information
Mahesh Ravishankar committed Feb 8, 2022
1 parent 5ebbcfa commit 2abd7f1
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 85 deletions.
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#binary2Dpointwise = {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s

// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s

func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
Expand Down
141 changes: 63 additions & 78 deletions mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ static bool setFusedOpOperandLimit(const OpResult &producer,
namespace {
struct TestLinalgElementwiseFusion
: public PassWrapper<TestLinalgElementwiseFusion, OperationPass<FuncOp>> {
TestLinalgElementwiseFusion() = default;
TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
tensor::TensorDialect>();
Expand All @@ -58,101 +61,83 @@ struct TestLinalgElementwiseFusion
return "Test Linalg element wise operation fusion patterns";
}

void runOnOperation() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getOperation();
RewritePatternSet fusionPatterns(context);
Option<bool> fuseGenericOps{
*this, "fuse-generic-ops",
llvm::cl::desc("Test fusion of generic operations."),
llvm::cl::init(false)};

linalg::populateElementwiseOpsFusionPatterns(
fusionPatterns,
linalg::LinalgElementwiseFusionOptions()
.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
Option<bool> controlFuseByExpansion{
*this, "control-fusion-by-expansion",
llvm::cl::desc(
"Test controlling fusion of reshape with generic op by expansion"),
llvm::cl::init(false)};

(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
}
};

struct TestLinalgControlFuseByExpansion
: public PassWrapper<TestLinalgControlFuseByExpansion,
OperationPass<FuncOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
}
StringRef getArgument() const final {
return "test-linalg-control-fusion-by-expansion";
}
StringRef getDescription() const final {
return "Test controlling of fusion of elementwise ops with reshape by "
"expansion";
}
Option<bool> pushExpandingReshape{
*this, "push-expanding-reshape",
llvm::cl::desc("Test linalg expand_shape -> generic "
"to generic -> expand_shape pattern"),
llvm::cl::init(false)};

void runOnOperation() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getOperation();
RewritePatternSet fusionPatterns(context);

linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
[](const OpResult &producer, OpOperand &consumer) {
if (auto collapseOp =
producer.getDefiningOp<tensor::CollapseShapeOp>()) {
if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
return false;

if (fuseGenericOps) {
RewritePatternSet fusionPatterns(context);
linalg::populateElementwiseOpsFusionPatterns(
fusionPatterns,
linalg::LinalgElementwiseFusionOptions()
.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));

(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
return;
}

if (controlFuseByExpansion) {
RewritePatternSet fusionPatterns(context);

linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
[](const OpResult &producer, OpOperand &consumer) {
if (auto collapseOp =
producer.getDefiningOp<tensor::CollapseShapeOp>()) {
if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
return false;
}
}
}
if (auto expandOp =
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
if (expandOp->hasOneUse()) {
OpOperand &use = *expandOp->getUses().begin();
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
if (linalgOp && linalgOp.isOutputTensor(&use))
return true;
if (auto expandOp =
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
if (expandOp->hasOneUse()) {
OpOperand &use = *expandOp->getUses().begin();
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
if (linalgOp && linalgOp.isOutputTensor(&use))
return true;
}
}
}
return linalg::skipUnitDimReshape(producer, consumer);
};

linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
controlReshapeFusionFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
return linalg::skipUnitDimReshape(producer, consumer);
};

linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
controlReshapeFusionFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
return;
}

if (pushExpandingReshape) {
RewritePatternSet patterns(context);
linalg::populatePushReshapeOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
}
};

struct TestPushExpandingReshape
: public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
}
StringRef getArgument() const final { return "test-linalg-push-reshape"; }
StringRef getDescription() const final {
return "Test Linalg reshape push patterns";
}

void runOnOperation() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getOperation();
RewritePatternSet patterns(context);
linalg::populatePushReshapeOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
};
} // namespace

namespace test {
/// Registers the consolidated Linalg elementwise-fusion test pass
/// (invoked as -test-linalg-elementwise-fusion-patterns) with the global
/// pass registry so mlir-opt can run it.
void registerTestLinalgElementwiseFusion() {
  PassRegistration<TestLinalgElementwiseFusion>();
}

/// Registers the standalone reshape-fusion-control test pass
/// (invoked as -test-linalg-control-fusion-by-expansion) with the global
/// pass registry.
void registerTestLinalgControlFuseByExpansion() {
  PassRegistration<TestLinalgControlFuseByExpansion>();
}

/// Registers the standalone reshape-push test pass
/// (invoked as -test-linalg-push-reshape) with the global pass registry.
void registerTestPushExpandingReshape() {
  PassRegistration<TestPushExpandingReshape>();
}
} // namespace test

} // namespace mlir
4 changes: 0 additions & 4 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,8 @@ void registerTestGenericIRVisitorsPass();
void registerTestGenericIRVisitorsInterruptPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgControlFuseByExpansion();
void registerTestLinalgDistribution();
void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms();
void registerTestLinalgTensorFusionTransforms();
void registerTestLinalgTiledLoopFusionTransforms();
Expand Down Expand Up @@ -172,10 +170,8 @@ void registerTestPasses() {
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestLinalgCodegenStrategy();
mlir::test::registerTestLinalgControlFuseByExpansion();
mlir::test::registerTestLinalgDistribution();
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestPushExpandingReshape();
mlir::test::registerTestLinalgFusionTransforms();
mlir::test::registerTestLinalgTensorFusionTransforms();
mlir::test::registerTestLinalgTiledLoopFusionTransforms();
Expand Down

0 comments on commit 2abd7f1

Please sign in to comment.