Skip to content

Commit

Permalink
[BACKEND] A general interface for initializing destination operands i…
Browse files Browse the repository at this point in the history
…n load/store operations (triton-lang#1427)
  • Loading branch information
Jokeren authored Mar 28, 2023
1 parent fe76b12 commit adc4d25
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 35 deletions.
9 changes: 8 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,12 @@ struct PTXBuilder {

// Create a new operand which is written to, that is, the constraint starts
// with "=", e.g. "=r".
Operand *newOperand(StringRef constraint);
// If the operand will be used in predicated execution,
// users may want to initialize it before use.
// Otherwise if the register is only used in the true branch or the false
// branch but not both, the register is undefined and ptxas can perform
// aggressive optimizations that may lead to incorrect results.
Operand *newOperand(StringRef constraint, bool init = false);

// Create a constant integer operand.
Operand *newConstantOperand(int64_t v);
Expand All @@ -171,6 +176,8 @@ struct PTXBuilder {
return argArchive.back().get();
}

void initOperand(Operand *opr);

// Make the operands in argArchive follow the provided \param order.
void reorderArgArchive(ArrayRef<Operand *> order) {
assert(order.size() == argArchive.size());
Expand Down
65 changes: 33 additions & 32 deletions lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ struct LoadOpConversion
}

// vectorized iteration through all the pointer/mask/other elements
const int valueElemNbits =
const int valueElemNBits =
std::max(8u, valueElemTy.getIntOrFloatBitWidth());
const int numVecs = numElems / vec;

Expand All @@ -117,11 +117,11 @@ struct LoadOpConversion
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;

const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
const size_t totalWidth = valueElemNbits * vec;
const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
const size_t totalWidth = valueElemNBits * vec;
const size_t width = std::min(totalWidth, maxWordWidth);
const size_t nWords = std::max<size_t>(1, totalWidth / width);
const size_t wordNElems = width / valueElemNbits;
const size_t wordNElems = width / valueElemNBits;
const size_t movWidth = width < 16 ? 16 : width;
assert(wordNElems * nWords * numVecs == numElems);

Expand All @@ -138,18 +138,12 @@ struct LoadOpConversion
const std::string writeConstraint =
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");

PTXInstr &init =
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
PTXInstr::Operand *zero = ptxBuilder.newConstantOperand(0);

// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
auto *opr = ptxBuilder.newOperand(writeConstraint,
/*init=*/true); // =r operations
dstsOpr->listAppend(opr);
// Initialize the destination register, otherwise the register will
// be undefined if the predicate is false.
init(opr, zero);
}

auto *addrOpr =
Expand Down Expand Up @@ -186,7 +180,7 @@ struct LoadOpConversion
PTXInstr &mov =
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));

size_t size = width / valueElemNbits;
size_t size = width / valueElemNBits;

auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
Value v = undef(vecTy);
Expand All @@ -201,8 +195,8 @@ struct LoadOpConversion
PTXInstr::Operand *opr{};

if (otherIsSplatConstInt) {
for (size_t s = 0; s < 32; s += valueElemNbits)
splatVal |= splatVal << valueElemNbits;
for (size_t s = 0; s < 32; s += valueElemNBits)
splatVal |= splatVal << valueElemNBits;
opr = ptxBuilder.newConstantOperand(splatVal);
} else
opr = ptxBuilder.newOperand(v, readConstraint);
Expand Down Expand Up @@ -233,10 +227,10 @@ struct LoadOpConversion
curr = ret;
}
curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy,
width / valueElemNbits));
width / valueElemNBits));
rets.push_back(curr);
}
int tmp = width / valueElemNbits;
int tmp = width / valueElemNBits;
for (size_t ii = 0; ii < vec; ++ii) {
Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
Expand Down Expand Up @@ -312,18 +306,18 @@ struct StoreOpConversion

const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNbits = dtsize * 8;
const size_t valueElemNBits = dtsize * 8;

const int numVecs = elemsPerThread / vec;
for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
// TODO: optimization when ptr is AddPtr with constant offset
size_t in_off = 0;

const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
const size_t totalWidth = valueElemNbits * vec;
const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
const size_t totalWidth = valueElemNBits * vec;
const size_t width = std::min(totalWidth, maxWordWidth);
const size_t nWords = std::max<size_t>(1, totalWidth / width);
const size_t wordNElems = width / valueElemNbits;
const size_t wordNElems = width / valueElemNBits;
assert(wordNElems * nWords * numVecs == elemsPerThread);

// TODO(Superjomn) Add cache policy fields to StoreOp.
Expand Down Expand Up @@ -414,6 +408,7 @@ struct AtomicCASOpConversion
Type valueElemTy =
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
: op.getResult().getType();
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(0));
PTXBuilder ptxBuilderMemfence;
Expand All @@ -424,13 +419,12 @@ struct AtomicCASOpConversion

Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));

Value casPtr = ptrElements[0];
Value casCmp = cmpElements[0];
Value casVal = valElements[0];

PTXBuilder ptxBuilderAtomicCAS;
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r");
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r", /*init=*/true);
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
Expand All @@ -441,7 +435,7 @@ struct AtomicCASOpConversion
barrier();

PTXBuilder ptxBuilderStore;
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "l");
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o("b32");
Expand Down Expand Up @@ -498,7 +492,7 @@ struct AtomicRMWOpConversion
Type valueElemTy =
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
: op.getResult().getType();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
// vec = 1, numElements = 1 for scalar
auto vec = getVectorSize(ptr);
Expand Down Expand Up @@ -529,16 +523,16 @@ struct AtomicRMWOpConversion
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
std::string sTy;
PTXBuilder ptxBuilderAtomicRMW;
std::string tyId = valueElemNbits * vec == 64
std::string tyId = valueElemNBits * vec == 64
? "l"
: (valueElemNbits * vec == 32 ? "r" : "h");
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId);
: (valueElemNBits * vec == 32 ? "r" : "h");
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true);
auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);

auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
auto sBits = std::to_string(valueElemNbits);
auto sBits = std::to_string(valueElemNBits);
switch (atomicRmwAttr) {
case RMWOp::AND:
sTy = "b" + sBits;
Expand All @@ -554,9 +548,9 @@ struct AtomicRMWOpConversion
break;
case RMWOp::FADD:
rmwOp = "add";
rmwOp += (valueElemNbits == 16 ? ".noftz" : "");
rmwOp += (valueElemNBits == 16 ? ".noftz" : "");
sTy = "f" + sBits;
sTy += (vec == 2 && valueElemNbits == 16) ? "x2" : "";
sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : "";
break;
case RMWOp::MAX:
sTy = "s" + sBits;
Expand Down Expand Up @@ -598,7 +592,14 @@ struct AtomicRMWOpConversion
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
store(old, atomPtr);
// Only threads with rmwMask = True store the result
PTXBuilder ptxBuilderStore;
auto &storeShared =
ptxBuilderStore.create<>("st")->shared().o("b" + sBits);
auto *ptrOpr = ptxBuilderStore.newAddrOperand(atomPtr, "r");
auto *valOpr = ptxBuilderStore.newOperand(old, tyId);
storeShared(ptrOpr, valOpr).predicate(rmwMask);
ptxBuilderStore.launch(rewriter, loc, void_ty(ctx));
barrier();
Value ret = load(atomPtr);
barrier();
Expand Down
26 changes: 24 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/PTXAsmFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,34 @@ PTXBuilder::newOperand(mlir::Value value, StringRef constraint,
return opr;
}

PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint) {
// Emits a `mov.u<N> opr, 0;` instruction into this builder so that the output
// register `opr` holds a defined value before any later instruction writes it.
//
// \param opr an output operand whose constraint has the form "=<letter>",
//            e.g. "=r". The letter selects the width of the emitted mov:
//              'c', 'h' -> 16 bits
//              'r'      -> 32 bits
//              'l'      -> 64 bits
//            Any other letter is treated as unreachable.
void PTXBuilder::initOperand(Operand *opr) {
  // newOperand(constraint, init) already guarantees this shape, but this
  // method can also be called directly, so check before indexing [1].
  assert(opr && opr->constraint.size() >= 2 &&
         "expected an output constraint such as \"=r\"");

  // Derive the register width from the constraint letter. There is no 8-bit
  // case: narrow values use the 16-bit 'c'/'h' classes because PTX does not
  // support an 8-bit mov.
  int numBits = 0;
  switch (opr->constraint[1]) {
  case 'c':
  case 'h':
    numBits = 16;
    break;
  case 'r':
    numBits = 32;
    break;
  case 'l':
    numBits = 64;
    break;
  default:
    llvm_unreachable(("Unknown constraint: " + opr->constraint).c_str());
  }

  // The zero operand is created before the instruction, in the same order as
  // the original implementation.
  auto *zero = newConstantOperand(0);
  auto &init = create<>("mov")->o("u" + std::to_string(numBits));
  init(opr, zero);
}

// Creates an output operand for this builder: allocates it through the
// zero-argument newOperand(), gives it the next index from oprCounter, and
// records `constraint` on it.
//
// \param constraint output constraint string; must start with '='.
// \param init       when true, initOperand() is called on the new operand
//                   before it is returned.
// \returns the newly created operand.
PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint, bool init) {
// Constraint should be something like "=r"
// NOTE(review): two asserts appear back to back here. This text is a rendered
// diff, so the first looks like the pre-change line and the second like its
// replacement -- confirm against the real file. The second is the stricter
// check: it additionally requires exactly two characters.
assert(!constraint.empty() && constraint[0] == '=');
assert(constraint.size() == 2 && constraint[0] == '=');
auto *opr = newOperand();
// Post-increment: this operand takes the current counter value.
opr->idx = oprCounter++;
opr->constraint = constraint;
if (init) {
// Delegate the initialization of the fresh operand to initOperand().
initOperand(opr);
}
return opr;
}

Expand Down
11 changes: 11 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,17 @@ def kernel(X, Z):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


def test_atomic_rmw_predicate(device="cuda"):
    """Check a scalar ``tl.atomic_max`` that sits behind a runtime predicate.

    4096 programs are launched, but only those whose program id is below 64
    perform the atomic. Starting from zero, the single int32 cell must
    therefore end up holding 63.
    """
    @triton.jit
    def kernel(X):
        # Only program ids 0..63 take part in the atomic max.
        pid = tl.program_id(0)
        if pid < 64:
            tl.atomic_max(X, pid)

    result = torch.zeros((1,), device=device, dtype=torch.int32)
    kernel[(4096,)](result)
    assert result.item() == 63


@pytest.mark.parametrize("shape, axis",
[(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
def test_tensor_atomic_rmw(shape, axis, device="cuda"):
Expand Down

0 comments on commit adc4d25

Please sign in to comment.