Skip to content

Commit

Permalink
Fix for OSS issue 4060
Browse files — browse the repository at this point in the history
  • Loading branch information
Karthik Manivannan committed Jul 8, 2024
1 parent 0430b9e commit 6b235b0
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 4 deletions.
16 changes: 14 additions & 2 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,20 @@ void init_triton_llvm(py::module &&m) {

m.def(
"to_module",
[](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) {
return mlir::translateModuleToLLVMIR(mod, ctx);
[](mlir::ModuleOp &mod, llvm::LLVMContext &ctx, const std::string triple,
std::string proc, std::string features) {
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(triple, error);
llvm::TargetOptions opt;
std::unique_ptr<llvm::TargetMachine> machine{
target->createTargetMachine(triple, proc, features, opt,
llvm::Reloc::PIC_, std::nullopt,
llvm::CodeGenOptLevel::None)};
auto dl = machine->createDataLayout();
mod->setAttr(mlir::LLVM::LLVMDialect::getDataLayoutAttrName(),
mlir::StringAttr::get(mod.getContext(),
dl.getStringRepresentation()));
return mlir::translateModuleToLLVMIR(mod, ctx, triple);
},
py::keep_alive<0, 2>());

Expand Down
22 changes: 22 additions & 0 deletions python/test/unit/runtime/test_i64_print.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import triton
import triton.language as tl
import torch


def test_i64_printf(capfd):
    """Regression test for OSS issue 4060: printing an int64 value from
    inside a Triton kernel must print the full 64-bit value, not a
    truncated one.

    Requires a CUDA device; ``capfd`` captures device-side printf output.
    """

    @triton.jit
    def ndscore_kernel(ptr):
        # Load, print, and write back value + 1 so we can also verify the
        # store round-trips through the i64 pointer correctly.
        value = tl.load(ptr)
        print("value in kernel", value)
        tl.store(ptr, value + 1)

    ptr = torch.tensor(42, dtype=torch.int64).cuda()
    print("value before kernel", ptr.item())
    # Launch on a single program instance; the handle is not needed.
    ndscore_kernel[(1, )](ptr)
    print("value after kernel", ptr.item())
    captured = capfd.readouterr()
    assert "value in kernel: 42" in captured.out
    assert "value before kernel 42" in captured.out
    assert "value after kernel 43" in captured.out
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def make_llir(src, metadata, options):
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
llvm.init_targets()
context = llvm.context()
llvm_mod = llvm.to_module(mod, context)
llvm_mod = llvm.to_module(mod, context, amd.TARGET_TRIPLE, options.arch, '')

# Set various control constants on the LLVM module so that device
# libraries can resolve references to them.
Expand Down
16 changes: 15 additions & 1 deletion third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,21 @@ def make_llir(src, metadata, options, capability):
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
llvm.init_targets()
context = llvm.context()
llvm_mod = llvm.to_module(mod, context)
proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
ptx_version = options.ptx_version
if ptx_version is None:
_, cuda_version = _path_to_binary("ptxas")
ptx_version = ptx_get_version(cuda_version)

# PTX 8.3 is the max version supported by llvm 3a83162168.
#
# To check if a newer PTX version is supported, increase this value
# and run a test. If it's not supported, LLVM will print a warning
# like "+ptx8.4 is not a recognized feature for this target".
llvm_ptx_version = min(83, ptx_version)
triple = 'nvptx64-nvidia-cuda'
features = f'+ptx{llvm_ptx_version}'
llvm_mod = llvm.to_module(mod, context, triple, proc, features)
nvidia.set_nvvm_reflect_ftz(llvm_mod)

# Set maxnreg on all kernels, if it was provided.
Expand Down

0 comments on commit 6b235b0

Please sign in to comment.