forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FRONTEND] add experimental descriptor load mapping to TMA (triton-la…
…ng#3739) This adds an experimental escape hatch that gives explicit access to TMA operations through this front end. It is meant for experiments only, as it breaks portability. It will allow exercising TMA operations in the short term and will be removed after that. Users should not rely on this feature in their production code.
- Loading branch information
1 parent
62706e8
commit 98c97cf
Showing
14 changed files
with
186 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "mlir/Transforms/Passes.h" | ||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" | ||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" | ||
|
||
#include <memory> | ||
|
||
#define GEN_PASS_CLASSES | ||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
using namespace mlir; | ||
using namespace triton; | ||
using namespace triton::gpu; | ||
using namespace triton::nvidia_gpu; | ||
|
||
/// Rewrite pattern that expands an ExperimentalDescriptorLoadOp into an
/// explicit sequence of ops:
///   1. LocalAllocOp for a mutable buffer with the result tensor's shape and
///      element type, using a SharedEncodingAttr.
///   2. LocalAllocOp for a single-element i64 buffer used as the barrier,
///      followed by InitBarrierOp on it.
///   3. AsyncTMACopyGlobalToLocalOp from the descriptor pointer / indices into
///      the data buffer, signalling the barrier.
///   4. WaitBarrierOp on the barrier with phase 0.
///   5. LocalLoadOp from the data buffer, which replaces the original op.
class TMALoadLowering : public OpRewritePattern<ExperimentalDescriptorLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Replaces `op` with the alloc / barrier / async-copy / wait / local-load
  /// sequence described on the class. Always returns success().
  LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto tensorType = op.getResult().getType();
    MLIRContext *ctx = tensorType.getContext();

    // Data buffer: same shape/element type as the loaded tensor; order and
    // CTA layout are taken from the result tensor's encoding.
    auto order = getOrder(tensorType.getEncoding());
    auto ctaLayout = getCTALayout(tensorType.getEncoding());
    auto encoding = SharedEncodingAttr::get(ctx, 1, 1, 1, order, ctaLayout);
    MemDescType memDescType =
        MemDescType::get(tensorType.getShape(), tensorType.getElementType(),
                         encoding, /*mutableMemory=*/true);
    Value alloc = rewriter.create<LocalAllocOp>(loc, memDescType, Value());

    // Barrier buffer: a rank-1, single-element i64 allocation. It needs its
    // own rank-1 encoding; reusing the data buffer's encoding would attach an
    // order / CTA layout whose rank is that of the loaded tensor to a rank-1
    // memory descriptor.
    auto barrierCTALayout =
        CTALayoutAttr::get(ctx, /*CTAsPerCGA=*/{1}, /*CTASplitNum=*/{1},
                           /*CTAOrder=*/{0});
    auto barrierEncoding =
        SharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCTALayout);
    MemDescType barrierMemDescType =
        MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding,
                         /*mutableMemory=*/true);
    Value barrierAlloc =
        rewriter.create<LocalAllocOp>(loc, barrierMemDescType, Value());
    rewriter.create<InitBarrierOp>(loc, barrierAlloc, 1);

    // Issue the asynchronous copy, then wait on the barrier (phase 0) before
    // reading the data buffer back into a tensor.
    rewriter.create<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(
        loc, op.getDescPtr(), op.getIndices(), barrierAlloc, alloc);
    Value phase = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
    rewriter.create<WaitBarrierOp>(loc, barrierAlloc, phase);
    rewriter.replaceOpWithNewOp<LocalLoadOp>(op, op.getType(), alloc);
    return success();
  }
};
|
||
class TritonNvidiaGPUTMALoweringPass | ||
: public TritonNvidiaGPUTMALoweringPassBase< | ||
TritonNvidiaGPUTMALoweringPass> { | ||
public: | ||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
ModuleOp m = getOperation(); | ||
|
||
mlir::RewritePatternSet patterns(context); | ||
patterns.add<TMALoadLowering>(context); | ||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) | ||
signalPassFailure(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<Pass> mlir::createTritonNvidiaGPUTMALoweringPass() { | ||
return std::make_unique<TritonNvidiaGPUTMALoweringPass>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters