Commit
[FRONTEND] add experimental descriptor load mapping to TMA (triton-lang#3739)

This adds an experimental escape hatch to access TMA operations explicitly through the front end. It is meant for experiments only, as it breaks portability; it allows exercising TMA operations in the short term and will be removed after that.

Users should not rely on this feature in their production code.
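
For orientation, the kernel-side call this enables looks like the sketch below; it mirrors the unit test added in python/test/unit/hopper/test_experimental_tma.py (which also shows the host-side descriptor setup) and is illustrative only, not a stable API:

import triton
import triton.language as tl

@triton.jit
def kernel(Z, desc, SIZE: tl.constexpr):
    off = tl.arange(0, SIZE)
    # Load a [SIZE] block of Z's element type through the TMA descriptor,
    # starting at offset 0 in the described dimension.
    x = tl._experimental_descriptor_load(desc, [0], [SIZE], Z.dtype)
    tl.store(Z + off, x)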
ThomasRaoux authored Apr 25, 2024
1 parent 62706e8 commit 98c97cf
Showing 14 changed files with 186 additions and 0 deletions.
32 changes: 32 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -1090,4 +1090,36 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
  let hasVerifier = 1;
}


def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
    MemoryEffects<[MemRead<GlobalMemory>]>]> {
  let summary = "Load from descriptor";
  let description = [{
    This operation will be lowered to an NVIDIA TMA load operation on targets that support it.
    `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
    The destination tensor type and shape must match the descriptor, otherwise the result is undefined.

    This is an escape hatch and is only there for testing/experimenting.
    This op will be removed in the future.
  }];
  let arguments = (
    ins
    TT_PtrType:$desc_ptr,
    Variadic<I32>:$indices,
    DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
    DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
  );

  let results = (outs TT_Tensor:$result);

  let assemblyFormat = [{
    $desc_ptr `[` $indices `]`
    oilist(
      `cacheModifier` `=` $cache |
      `evictionPolicy` `=` $evict
    )
    attr-dict `:` qualified(type($desc_ptr)) `->` type($result)
  }];
}

#endif // Triton_OPS
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h
@@ -50,6 +50,8 @@ std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass(
std::unique_ptr<Pass>
createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);

std::unique_ptr<Pass> createTritonNvidiaGPUTMALoweringPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
15 changes: 15 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td
@@ -59,4 +59,19 @@ def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> {
  ];
}


def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::ModuleOp"> {
  let summary = "lower to TMA load/store operations";

  let description = [{
    Lower Triton experimental descriptor load to TMA load/store operations in TritonNvidiaGPUDialect.
  }];

  let constructor = "mlir::createTritonNvidiaGPUTMALoweringPass()";

  let dependentDialects = [
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
  ];
}

#endif
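
For reference, this pass is exposed to Python as add_tma_lowering (registered in third_party/nvidia/triton_nvidia.cc below) and inserted into the Hopper TTGIR pipeline in the NVIDIA backend's compiler.py. A minimal standalone sketch of driving it, assuming the same pass-manager helpers compiler.py already uses:

from triton._C.libtriton import ir, nvidia

def run_tma_lowering(mod):
    # mod is a TTGIR module; the pass rewrites tt.experimental_descriptor_load
    # into TritonNvidiaGPU TMA operations.
    pm = ir.pass_manager(mod.context)
    nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
    pm.run(mod)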
1 change: 1 addition & 0 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -549,6 +549,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::PrintOp>, GenericOpPattern<triton::AssertOp>,
GenericOpPattern<triton::AtomicCASOp>,
GenericOpPattern<triton::AtomicRMWOp>, GenericOpPattern<ReturnOp>,
GenericOpPattern<triton::ExperimentalDescriptorLoadOp>,
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
context);
}
1 change: 1 addition & 0 deletions lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_triton_library(TritonNvidiaGPUTransforms
  FenceInsertion.cpp
  PlanCTA.cpp
  TMALowering.cpp

  DEPENDS
  TritonNvidiaGPUTransformsIncGen
69 changes: 69 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
@@ -0,0 +1,69 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#include <memory>

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

namespace {

using namespace mlir;
using namespace triton;
using namespace triton::gpu;
using namespace triton::nvidia_gpu;

class TMALoadLowering : public OpRewritePattern<ExperimentalDescriptorLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op,
                                PatternRewriter &rewriter) const override {
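    // Lower the descriptor load into: a shared-memory allocation plus mbarrier
    // init, an async TMA copy from global to shared memory, a barrier wait,
    // and finally an ordinary local load out of shared memory.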
    auto loc = op.getLoc();
    auto tensorType = op.getResult().getType();
    auto order = getOrder(tensorType.getEncoding());
    auto ctaLayout = getCTALayout(tensorType.getEncoding());
    auto encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1,
                                            order, ctaLayout);
    MemDescType memDescType =
        MemDescType::get(tensorType.getShape(), tensorType.getElementType(),
                         encoding, /*mutableMemory=*/true);
    Value alloc = rewriter.create<LocalAllocOp>(loc, memDescType, Value());
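    // TMA completion is signaled through an mbarrier allocated in shared
    // memory and initialized with an expected arrival count of 1.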
    MemDescType barrierMemDescType = MemDescType::get(
        {1}, rewriter.getI64Type(), encoding, /*mutableMemory=*/true);
    Value barrierAlloc =
        rewriter.create<LocalAllocOp>(loc, barrierMemDescType, Value());
    rewriter.create<InitBarrierOp>(loc, barrierAlloc, 1);

    rewriter.create<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(
        loc, op.getDescPtr(), op.getIndices(), barrierAlloc, alloc);
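    // Block until the copy lands (phase 0 of the mbarrier), then replace the
    // original op with a plain load from the shared-memory buffer.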
    Value phase = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
    rewriter.create<WaitBarrierOp>(loc, barrierAlloc, phase);
    rewriter.replaceOpWithNewOp<LocalLoadOp>(op, op.getType(), alloc);
    return success();
  }
};

class TritonNvidiaGPUTMALoweringPass
    : public TritonNvidiaGPUTMALoweringPassBase<
          TritonNvidiaGPUTMALoweringPass> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    mlir::RewritePatternSet patterns(context);
    patterns.add<TMALoadLowering>(context);
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createTritonNvidiaGPUTMALoweringPass() {
  return std::make_unique<TritonNvidiaGPUTMALoweringPass>();
}
8 changes: 8 additions & 0 deletions python/src/ir.cc
@@ -1244,6 +1244,14 @@ void init_triton_ir(py::module &&m) {
self.create<StoreOp>(ptrs, val, mask, cacheModifier,
evictionPolicy);
})
.def("create_descriptor_load",
[](TritonOpBuilder &self, Value &desc_ptr,
std::vector<Value> &indices, Type type,
CacheModifier cacheModifier,
EvictionPolicy evictionPolicy) -> Value {
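// The result tensor type is supplied by the caller and must match the
// descriptor, per the tt.experimental_descriptor_load op definition.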
return self.create<ExperimentalDescriptorLoadOp>(
type, desc_ptr, indices, cacheModifier, evictionPolicy);
})
.def("create_reshape",
[](TritonOpBuilder &self, Value &arg, std::vector<int64_t> &shape,
bool allowReorder) -> Value {
24 changes: 24 additions & 0 deletions python/test/unit/hopper/test_experimental_tma.py
@@ -4,6 +4,7 @@
import tempfile

import triton
import triton.language as tl


def test_descriptor_load_ttgir():
@@ -45,3 +46,26 @@ def test_descriptor_load_ttgir():
    z_tri = torch.empty_like(x)
    kernel[(1, 1, 1)](z_tri, desc)
    assert torch.equal(x, z_tri)


def test_experimental_descriptor_load():
    if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 9:
        pytest.skip("Test requires Hopper target.")
    device = "cuda"
    SIZE = 128

    @triton.jit
    def kernel(Z, desc, SIZE: tl.constexpr):
        off_desc = 0
        off = tl.arange(0, SIZE)
        x = tl._experimental_descriptor_load(desc, [off_desc], [SIZE], Z.dtype)
        tl.store(Z + off, x)

    x = torch.randn(SIZE, dtype=torch.float32, device=device)
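    # A TMA descriptor (CUtensorMap) is 128 bytes; SIZE happens to be 128, so
    # it doubles as the byte count here. The 4 is the element size in bytes
    # (float32).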
    desc = np.empty(SIZE, dtype=np.int8)
    triton.runtime.driver.active.utils.fill_1d_tma_descriptor(x.data_ptr(), SIZE, 4, desc)
    desc = torch.tensor(desc, device=device)
    z_tri = torch.empty_like(x)
    kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4)
    assert torch.equal(x, z_tri)
2 changes: 2 additions & 0 deletions python/triton/language/__init__.py
@@ -28,6 +28,7 @@
    TRITON_MAX_TENSOR_NUMEL,
    _experimental_join,
    _experimental_split,
    _experimental_descriptor_load,
    advance,
    arange,
    associative_scan,
@@ -122,6 +123,7 @@
    "TRITON_MAX_TENSOR_NUMEL",
    "_experimental_join",
    "_experimental_split",
    "_experimental_descriptor_load",
    "abs",
    "advance",
    "arange",
12 changes: 12 additions & 0 deletions python/triton/language/core.py
@@ -1590,6 +1590,18 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
volatile, _builder)


@builtin
def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None):
    """
    Experimental feature to access TMA descriptor loads. This is an escape hatch to easily exercise TTGIR operations.
    It will be removed in the future and shouldn't be used in production code.

    Loads a tensor of data based on the descriptor and offsets.
    """
    type = block_type(dtype.element_ty, shape)
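    # Per the op definition, this block type must match the TMA descriptor,
    # otherwise the result is undefined.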
    return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder)


@_tensor_member_fn
@builtin
def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None):
9 changes: 9 additions & 0 deletions python/triton/language/semantic.py
@@ -1036,6 +1036,15 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor],
    return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)


def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type,
                    builder: ir.builder) -> tl.tensor:
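    # Offsets are lowered to i32 values, matching the op's Variadic<I32> indices.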
    offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
    x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder),
                                       _str_to_load_cache_modifier(cache_modifier),
                                       _str_to_eviction_policy(eviction_policy))
    return tl.tensor(x, type)


def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
    # Store by a block pointer: `pointer_type<block_type<>>`
    # Block pointers can not have the `mask` argument
8 changes: 8 additions & 0 deletions test/Triton/ops.mlir
@@ -236,3 +236,11 @@ tt.func @histogram(%0: tensor<512xi32>) {
  %1 = tt.histogram %0 : tensor<512xi32> -> tensor<16xi32>
  tt.return
}

// CHECK-LABEL: experimental_descriptor_load
tt.func @experimental_descriptor_load(%0: !tt.ptr<i8>) {
  // CHECK: tt.experimental_descriptor_load %{{.+}}[%{{.+}}] : !tt.ptr<i8> -> tensor<128xf32>
  %c0_i32 = arith.constant 0 : i32
  %1 = tt.experimental_descriptor_load %0[%c0_i32] : !tt.ptr<i8> -> tensor<128xf32>
  tt.return
}
1 change: 1 addition & 0 deletions third_party/nvidia/backend/compiler.py
@@ -181,6 +181,7 @@ def make_ttgir(mod, metadata, opt, capability):
    passes.common.add_symbol_dce(pm)
    if capability // 10 >= 9:
        nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
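        # Lower experimental descriptor loads to TMA operations (Hopper and newer).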
        nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
    passes.common.add_canonicalizer(pm)
    pm.run(mod)
    metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
2 changes: 2 additions & 0 deletions third_party/nvidia/triton_nvidia.cc
@@ -32,6 +32,8 @@ void init_triton_nvidia_passes_ttnvgpuir(py::module &&m) {
mlir::triton::nvidia_gpu::ClusterInfo *);
ADD_PASS_WRAPPER_0("add_fence_insertion",
mlir::createTritonNvidiaGPUFenceInsertionPass);
ADD_PASS_WRAPPER_0("add_tma_lowering",
mlir::createTritonNvidiaGPUTMALoweringPass);
ADD_PASS_WRAPPER_0("add_nvgpu_to_llvm",
mlir::triton::createConvertNVGPUToLLVMPass);
}
