Skip to content

Commit

Permalink
porting linalg transformations to new mlir (tiling and microkernel) p…
Browse files Browse the repository at this point in the history
  • Loading branch information
gkestor authored and pthomadakis committed Oct 22, 2023
1 parent a9d649e commit 4e2928d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 16 deletions.
8 changes: 4 additions & 4 deletions frontends/comet_dsl/comet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,10 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
// optPM.addPass(mlir::tensorAlgebra::createOptDenseTransposePass());
}

// if (OptMatmulTiling)
// {
// optPM.addPass(mlir::tensorAlgebra::createLinAlgMatmulTilingPass());
// }
if (OptMatmulTiling)
{
optPM.addPass(mlir::comet::createLinAlgMatmulTilingPass());
}

// if (OptCallToMatMulMicroKernel)
// {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TensorAlgebra/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ add_llvm_library(COMETTensorAlgebraDialect
IR/TADialect.cpp

Transforms/Transforms.cpp
# Transforms/LinalgTransforms.cpp
Transforms/LinalgTransforms.cpp
Transforms/TCtoTTGT.cpp
Transforms/Passes.cpp

Expand Down
57 changes: 46 additions & 11 deletions lib/Dialect/TensorAlgebra/Transforms/LinalgTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
// #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -45,9 +44,9 @@ using namespace mlir::arith;
using namespace mlir::tensorAlgebra;

// *********** For debug purpose *********//
// #ifndef DEBUG_MODE_LINALGTRANSFORMS
// #define DEBUG_MODE_LINALGTRANSFORMS
// #endif
#ifndef DEBUG_MODE_LINALGTRANSFORMS
#define DEBUG_MODE_LINALGTRANSFORMS
#endif

#ifdef DEBUG_MODE_LINALGTRANSFORMS
#define comet_debug() llvm::errs() << __FILE__ << " " << __LINE__ << " "
Expand All @@ -58,7 +57,12 @@ using namespace mlir::tensorAlgebra;
llvm::errs() << __FILE__ << " " << __LINE__ << " "; \
n.dump()
#else
#define comet_debug() if(true){}else llvm::errs()
#define comet_debug() \
if (true) \
{ \
} \
else \
llvm::errs()
#define comet_pdump(n)
#define comet_vdump(n)
#endif
Expand All @@ -72,10 +76,43 @@ namespace
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinAlgMatmulTilingPass)
void runOnOperation() override
{
func::FuncOp func = getOperation();
MLIRContext *ctx = func.getContext();

RewritePatternSet patterns(&getContext());
//===----------------------------------------------------------------------===//
// BLIS HASWELL
//===----------------------------------------------------------------------===//
// #define BLIS_DGEMM_UKERNEL bli_dgemm_asm_8x6
// #define BLIS_DEFAULT_MC_D 72
// #define BLIS_DEFAULT_KC_D 256
// #define BLIS_DEFAULT_NC_D 4080
// #define BLIS_DEFAULT_MR_D 8
// #define BLIS_DEFAULT_NR_D 6

ArrayRef<int64_t> tileInterchange_L2;


// Tile the root operation.
LinalgTilingOptions tilingOptions;
tilingOptions = tilingOptions
// .setInterchange(SmallVector<unsigned>(
// tileInterchange.begin(), tileInterchange.end()))
.setInterchange({1, 2, 0})
//.setTileSizes(tileSizes)
.setTileSizes({72, 4080, 256})
.setLoopType(LinalgTilingLoopType::Loops);

// TODO: Propagate RewriterBase everywhere.
IRRewriter rewriter(b);
FailureOr<TiledLinalgOp> tiledRootOp =
tileLinalgOp(rewriter, rootOp, tilingOptions);

// Exit if tiling the root operation fails.
// if (failed(tiledRootOp))
// return failure();

// func::FuncOp func = getOperation();
// MLIRContext *ctx = func.getContext();

// RewritePatternSet patterns(&getContext());

// Add the matmul tiling patterns to the list.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -107,7 +144,7 @@ namespace
// LinalgTransformationFilter(Identifier::get("L2__with_tiling__", ctx),
// Identifier::get("__micro_kernel__", ctx)));

(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
//(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
};
} // end anonymous namespace
Expand Down Expand Up @@ -467,7 +504,6 @@ namespace
// };
// } // end anonymous namespace


/// Create a pass to optimize LinAlg Matmul Op with tiling
std::unique_ptr<mlir::Pass> mlir::comet::createLinAlgMatmulTilingPass()
{
Expand All @@ -489,4 +525,3 @@ std::unique_ptr<mlir::Pass> mlir::comet::createLinAlgMatmulMicroKernelPass()
// comet_debug() << "LinAlgTransforms createOptDenseTransposePass\n";
// return std::make_unique<OptDenseTransposePass>(tile_size, seperate_tiles);
// }

0 comments on commit 4e2928d

Please sign in to comment.