Use dynamic dispatch instead of templated compile time traits, but keep separation between

1. Tensor memory access atom derived message constants
2. Workload derived message constraints
csullivan committed Feb 19, 2025
1 parent 96b2b71 commit f578d4d
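
In outline, the change replaces the template<bool IsStrided> trait structs with plain runtime values while keeping (1) atom-derived message constants and (2) workload-derived message constraints as separate steps. A minimal standalone C++ sketch of that pattern follows; the names mirror the diff below, but the largestTmemLoadStore value, the simplified constrainToWorkload signature, and the main() driver are illustrative assumptions, not code from this commit.

// Sketch of the dispatch pattern adopted in this commit (simplified).
#include <algorithm>
#include <cstdio>

static constexpr int largestTmemLoadStore = 128; // assumed value, not from the diff
static constexpr int narrowingFactor = 4;

// 1. Message constants derived from the tensor memory access atom.
struct TMemAccessAtom {
  int opBitWidth;
  int colsPerThread;
  int rowsPerThread;
  const char *opShape;
  bool usesSecondHalfOffset;
};

constexpr TMemAccessAtom TMemAccess32x32b{32, 1, 1, "32x32b", false};
constexpr TMemAccessAtom TMemAccess16x32bx2{32, 1, 1, "16x32bx2", true};

struct TMemMessageTraits {
  TMemAccessAtom atom;
  int numCols;
  int numRegs;
};

// Derive the default message shape from the atom at runtime; before this
// commit the same constants lived in a template<bool IsStrided> trait struct.
TMemMessageTraits getTMemMessageFromAtom(const TMemAccessAtom &atom) {
  TMemMessageTraits m;
  m.atom = atom;
  int maxNumRepeats =
      largestTmemLoadStore / (atom.colsPerThread * atom.rowsPerThread);
  int maxCols = (atom.opBitWidth / 32) * maxNumRepeats;
  m.numCols = maxCols / narrowingFactor;
  int numRepeats = m.numCols / (atom.opBitWidth / 32);
  m.numRegs = atom.colsPerThread * atom.rowsPerThread * numRepeats;
  return m;
}

// 2. Constraints derived from the workload, applied to the default message.
TMemMessageTraits constrainToWorkload(TMemMessageTraits m, int maxRegsFromWorkload) {
  m.numRegs = std::min(m.numRegs, maxRegsFromWorkload);
  m.numCols = (m.numRegs * (m.atom.opBitWidth / 32)) /
              (m.atom.colsPerThread * m.atom.rowsPerThread);
  return m;
}

int main() {
  // Runtime selection of the atom replaces the old compile-time branch on IsStrided.
  bool useStridedMessage = false;
  const TMemAccessAtom &atom =
      useStridedMessage ? TMemAccess16x32bx2 : TMemAccess32x32b;
  TMemMessageTraits msg =
      constrainToWorkload(getTMemMessageFromAtom(atom), /*maxRegsFromWorkload=*/8);
  std::printf("%s: %d regs, %d cols per message\n", msg.atom.opShape,
              msg.numRegs, msg.numCols);
  return 0;
}
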
Showing 1 changed file with 127 additions and 132 deletions.
259 changes: 127 additions & 132 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
@@ -21,34 +21,44 @@ static constexpr int narrowingFactor = 4;

namespace {

// Tensor memory access traits, currently only 32b is used
template <bool IsStrided> struct TMemAccess32b {
static constexpr int opBitWidth = 32;
static constexpr int colsPerThread = 1;
static constexpr int rowsPerThread = 1;
static constexpr const char *opShape = IsStrided ? "16x32bx2" : "32x32b";
static constexpr bool usesSecondHalfOffset = IsStrided;
struct TMemAccessAtom {
int opBitWidth;
int colsPerThread;
int rowsPerThread;
const char *opShape;
bool usesSecondHalfOffset;
};

struct TMemAccess256b {
static constexpr int opBitWidth = 256;
static constexpr int colsPerThread = 2;
static constexpr int rowsPerThread = 2;
static constexpr char opShape[] = "16x256b";
constexpr TMemAccessAtom TMemAccess16x32bx2{32, 1, 1, "16x32bx2", true};
constexpr TMemAccessAtom TMemAccess32x32b{32, 1, 1, "32x32b", false};
constexpr TMemAccessAtom TMemAccess16x256b{256, 2, 2, "16x256b", false};

struct TMemMessageTraits {
TMemAccessAtom atom;
bool usesSecondHalfOffset;
int numThreadsPerWarp;
int maxNumRepeats;
int maxCols;
int numRows;
int numCols;
int numRepeats;
int numRegs;
};

template <typename TMemAccess> struct TMemMessage {
using traits = TMemAccess;
static constexpr int opBitWidth = TMemAccess::opBitWidth;
static constexpr int colsPerThread = TMemAccess::colsPerThread;
static constexpr int rowsPerThread = TMemAccess::rowsPerThread;
static constexpr int numThreadsPerWarp = 32;
static constexpr int maxNumRepeats =
largestTmemLoadStore / (colsPerThread * rowsPerThread);
static constexpr int maxColsPerMessage = (opBitWidth / 32) * maxNumRepeats;
static constexpr int rowsPerMessage = numThreadsPerWarp / rowsPerThread;
static constexpr int colsPerMessage = maxColsPerMessage / narrowingFactor;
};
TMemMessageTraits getTMemMessageFromAtom(const TMemAccessAtom &atom) {
TMemMessageTraits m;
m.atom = atom;
m.usesSecondHalfOffset = atom.usesSecondHalfOffset;
m.numThreadsPerWarp = 32;
m.maxNumRepeats =
largestTmemLoadStore / (atom.colsPerThread * atom.rowsPerThread);
m.maxCols = (atom.opBitWidth / 32) * m.maxNumRepeats;
m.numRows = m.numThreadsPerWarp / atom.rowsPerThread;
m.numCols = m.maxCols / narrowingFactor;
m.numRepeats = m.numCols / (atom.opBitWidth / 32);
m.numRegs = atom.colsPerThread * atom.rowsPerThread * m.numRepeats;
return m;
}

struct TMemRuntimeInfo {
static constexpr int numRowsPerWarp = 32;
@@ -68,6 +78,29 @@ struct TMemRuntimeInfo {
int colsPerWarpGroup;
};

TMemMessageTraits constrainMessageWidthToWorkload(TMemMessageTraits m,
const TMemRuntimeInfo &info) {
// If the workload runtime requires fewer registers than the default message
// width, use a message width that matches the workload
int maxRegsFromWorkload = info.colsPerWarpGroup;
if (info.unpackedb16)
maxRegsFromWorkload /= info.numElementsPer32B;
if (info.useStridedMessage)
maxRegsFromWorkload /= 2;

m.numRepeats = m.numCols / (m.atom.opBitWidth / 32);
m.numRegs = m.atom.colsPerThread * m.atom.rowsPerThread * m.numRepeats;
m.numRegs = std::min(m.numRegs, maxRegsFromWorkload);
// Invert the above formulas to calculate the effective runtime message width
m.numCols = (m.numRegs * (m.atom.opBitWidth / 32)) /
(m.atom.colsPerThread * m.atom.rowsPerThread);

// Half as many registers are needed for 16-bit packed elements,
// so twice as many columns are accessed per message.
m.numCols *= info.numElementsPer32B;
return m;
}
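
As a worked example of the constraint above (illustrative numbers, not taken from this commit): suppose the default 32x32b message derived from the atom has numCols = 32 and numRegs = 32, and the workload has colsPerWarpGroup = 16 with unpackedb16 set, numElementsPer32B = 2, and no striding. Then maxRegsFromWorkload = 16 / 2 = 8, numRegs is clamped from 32 down to 8, inverting the formula gives numCols = (8 * (32 / 32)) / (1 * 1) = 8, and the final scaling by numElementsPer32B widens that to 16 columns, so each emitted message moves 8 registers while stepping over 16 tensor memory columns.
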

SmallVector<Value> packToI32(const SmallVector<Value> &values, Location loc,
ConversionPatternRewriter &rewriter) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -91,13 +124,6 @@ SmallVector<Value> packToI32(const SmallVector<Value> &values, Location loc,
return packedValues;
}

int getNum32BRegs(bool unpackedb16, int numElementsPer32B, int numCols) {
int numRegPerMessage = numCols;
if (unpackedb16)
numRegPerMessage = numRegPerMessage / numElementsPer32B;
return numRegPerMessage;
}

TMemRuntimeInfo getTMemRuntimeInfo(Operation *op, RankedTensorType tensorType,
MemDescType memType) {
TMemRuntimeInfo info;
@@ -154,102 +180,81 @@ TMemRuntimeInfo getTMemRuntimeInfo(Operation *op, RankedTensorType tensorType,
return info;
}

template <typename TMemMsgT>
void calculateAddressAndEmitTmemMessage(
Location loc, Value baseAddress, TMemRuntimeInfo info,
Location loc, Value baseAddress, TMemRuntimeInfo runtime,
const TMemMessageTraits &defaultMessage,
ConversionPatternRewriter &rewriter,
const std::function<void(Value, int, bool, int, bool, int)>
&createMemoryOp) {
const std::function<void(Value, int, bool, int, bool)> &createMemoryOp) {

TritonLLVMOpBuilder b(loc, rewriter);
Value warpId = rewriter.create<nvgpu::WarpIdOp>(loc);
Value warpIdInGroup = b.urem(warpId, b.i32_val(4));
Value warpGroupId = b.udiv(warpId, b.i32_val(4));

int numRegs = getNum32BRegs(info.unpackedb16, info.numElementsPer32B,
info.colsPerWarpGroup);
if (info.useStridedMessage)
numRegs /= 2;

// If the workload runtime requires fewer registers than the default message
// width, use a message width that matches the workload
int colsPerMessage = TMemMsgT::colsPerMessage;
int numRepeats = colsPerMessage / (TMemMsgT::opBitWidth / 32);
int numRegsPerMsg =
TMemMsgT::colsPerThread * TMemMsgT::rowsPerThread * numRepeats;
numRegsPerMsg = std::min(numRegsPerMsg, numRegs);
// Invert the above formulas to calculate the effective runtime message width
colsPerMessage = (numRegsPerMsg * (TMemMsgT::opBitWidth / 32)) /
(TMemMsgT::colsPerThread * TMemMsgT::rowsPerThread);

// Half as many registers are needed for 16-bit packed elements,
// so twice as many columns are accessed per message.
colsPerMessage *= info.numElementsPer32B;
for (int block = 0; block < info.numBlocks; block += info.numWarpGroups) {
auto message = constrainMessageWidthToWorkload(defaultMessage, runtime);
for (int block = 0; block < runtime.numBlocks;
block += runtime.numWarpGroups) {
Value address = b.ptrtoint(i32_ty, baseAddress);
Value blockId =
b.add(b.i32_val(block),
b.udiv(warpGroupId, b.i32_val(info.numWarpGroupsPerBlock)));
b.udiv(warpGroupId, b.i32_val(runtime.numWarpGroupsPerBlock)));
Value warpGroupIdInBlock =
b.urem(warpGroupId, b.i32_val(info.numWarpGroupsPerBlock));
b.urem(warpGroupId, b.i32_val(runtime.numWarpGroupsPerBlock));
Value startColumnId =
b.mul(warpGroupIdInBlock, b.i32_val(info.colsPerWarpGroup));
b.mul(warpGroupIdInBlock, b.i32_val(runtime.colsPerWarpGroup));
Value blockRowId =
b.mul(warpIdInGroup, b.i32_val(TMemRuntimeInfo::numRowsPerWarp));

if (info.blocksInterleaved) {
if (runtime.blocksInterleaved) {
Value blockIdIsOdd = b.urem(blockId, b.i32_val(2));
Value blockIdPrevEven = b.sub(blockId, blockIdIsOdd);
blockRowId = b.add(blockRowId, b.mul(blockIdIsOdd, b.i32_val(16)));
startColumnId =
b.add(startColumnId,
b.mul(blockIdPrevEven, b.i32_val(info.numColsPerBlock / 2)));
b.mul(blockIdPrevEven, b.i32_val(runtime.numColsPerBlock / 2)));
} else {
startColumnId =
b.add(startColumnId, b.mul(blockId, b.i32_val(info.numColsPerBlock)));
startColumnId = b.add(startColumnId,
b.mul(blockId, b.i32_val(runtime.numColsPerBlock)));
}

// A strided message accesses twice as many columns per message,
// thus half as many messages are required
int numColumns = info.useStridedMessage ? info.numColsPerBlock / 2
: info.numColsPerBlock;
for (int colStart = 0; colStart < numColumns; colStart += colsPerMessage) {
int numColumns = runtime.useStridedMessage ? runtime.numColsPerBlock / 2
: runtime.numColsPerBlock;
for (int colStart = 0; colStart < numColumns; colStart += message.numCols) {
// For messages that span only 16 rows (e.g. 16x256b), multiple messages
// are required to cover the entire set of rows per warp.
for (int rowStart = 0; rowStart < TMemRuntimeInfo::numRowsPerWarp;
rowStart += TMemMsgT::rowsPerMessage) {
rowStart += message.numRows) {
Value rowOffset = b.add(blockRowId, b.i32_val(rowStart));
Value warpGroupAddress =
b.add(address, b.shl(rowOffset, b.i32_val(16)));
warpGroupAddress = b.add(warpGroupAddress, startColumnId);

Value msgAddress = b.add(warpGroupAddress, b.i32_val(colStart));
int secondHalfColOffset = 0;
if (info.useStridedMessage) {
if (runtime.useStridedMessage) {
// Offset to half way through the set of columns for this warpgroup.
secondHalfColOffset = numColumns;
}
createMemoryOp(msgAddress, secondHalfColOffset, info.unpackedb16,
numRegsPerMsg, info.useStridedMessage,
TMemMsgT::opBitWidth);
createMemoryOp(msgAddress, secondHalfColOffset, runtime.unpackedb16,
message.numRegs, runtime.useStridedMessage);
}
}
}
}
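
For the strided path (again with illustrative numbers, not taken from this commit): when useStridedMessage is set, each 16x32bx2 message accesses two half-ranges of columns at once, so the colStart loop above only walks numColumns = numColsPerBlock / 2 columns and passes secondHalfColOffset = numColumns to the memory op. With numColsPerBlock = 64 that is 32 loop columns and a second-half offset of 32, which createTensorMemoryStore below encodes as the immediate following the address operand of the tcgen05 instruction.
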

template <typename TMemMsgT>
static void
createTensorMemoryStore(Location loc, Value address, SmallVector<Value> &srcs,
int secondHalfOffset, Value pred, bool unpacked,
ConversionPatternRewriter &rewriter) {
void createTensorMemoryStore(Location loc, Value address,
SmallVector<Value> &srcs, int secondHalfOffset,
Value pred, bool unpacked,
const TMemAccessAtom &atom,
ConversionPatternRewriter &rewriter) {
PTXBuilder ptxBuilder;
std::string opcode;
std::string packedStr = unpacked ? ".unpack::16b" : "";
unsigned numRepeats = srcs.size() / (TMemMsgT::traits::rowsPerThread *
TMemMsgT::traits::colsPerThread);
opcode = "@$0 tcgen05.st.sync.aligned." +
std::string(TMemMsgT::traits::opShape) + ".x" +
std::to_string(numRepeats) + packedStr;
unsigned numRepeats = srcs.size() / (atom.rowsPerThread * atom.colsPerThread);
std::string opcode = "@$0 tcgen05.st.sync.aligned." +
std::string(atom.opShape) + ".x" +
std::to_string(numRepeats) + packedStr;
if (secondHalfOffset)
opcode += ".b32 [$1], " + std::to_string(secondHalfOffset) + ", {";
else
Expand Down Expand Up @@ -299,11 +304,9 @@ static void reorderScales(SmallVector<Value> &srcValues, int64_t k) {
srcValues = std::move(reorderedValues);
}

template <typename Fn> void emitMemoryOp(bool isStrided, Fn &&emitMemoryOpFn) {
if (isStrided)
emitMemoryOpFn(std::integral_constant<bool, true>{});
else
emitMemoryOpFn(std::integral_constant<bool, false>{});
TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info) {
auto atom = info.useStridedMessage ? TMemAccess16x32bx2 : TMemAccess32x32b;
return getTMemMessageFromAtom(atom);
}

static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src,
@@ -321,24 +324,22 @@ static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src,
reorderScales(srcValues, dstType.getShape().back());
}

auto tmemInfo = getTMemRuntimeInfo(op, cast<RankedTensorType>(src.getType()),
cast<MemDescType>(dest.getType()));
auto info = getTMemRuntimeInfo(op, cast<RankedTensorType>(src.getType()),
cast<MemDescType>(dest.getType()));
const TMemMessageTraits message = selectTMemMessage(info);
int regIdx = 0;
emitMemoryOp(tmemInfo.useStridedMessage, [&](auto isStrided) {
using MsgT = TMemMessage<TMemAccess32b<decltype(isStrided)::value>>;
calculateAddressAndEmitTmemMessage<MsgT>(
loc, tmemBase, tmemInfo, rewriter,
[&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
int regsPerMsg, bool useStridedMessage, int opBitWidth) {
SmallVector<Value> srcValuesSlice(srcValues.begin() + regIdx,
srcValues.begin() + regIdx +
regsPerMsg);
regIdx += regsPerMsg;
createTensorMemoryStore<MsgT>(loc, startAddress, srcValuesSlice,
secondHalfColOffset, pred, unpackedb16,
rewriter);
});
});
calculateAddressAndEmitTmemMessage(
loc, tmemBase, info, message, rewriter,
[&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
int regsPerMsg, bool useStridedMessage) {
SmallVector<Value> srcValuesSlice(srcValues.begin() + regIdx,
srcValues.begin() + regIdx +
regsPerMsg);
regIdx += regsPerMsg;
createTensorMemoryStore(loc, startAddress, srcValuesSlice,
secondHalfColOffset, pred, unpackedb16,
message.atom, rewriter);
});
createWaitOpSt(loc, rewriter);

// Emit a barrier to ensure all threads have finished writing to tensor memory
@@ -383,20 +384,17 @@ struct TensorMemoryAllocOpConversion
}
};

template <typename TMemMsgT>
static Value createTensorMemoryLoad(Location loc,
triton::nvidia_gpu::TMEMLoadOp op,
Value address, int secondHalfOffset,
bool unpacked, int numRegPerMessage,
ConversionPatternRewriter &rewriter) {
Value createTensorMemoryLoad(Location loc, triton::nvidia_gpu::TMEMLoadOp op,
Value address, int secondHalfOffset, bool unpacked,
int numRegPerMessage, const TMemAccessAtom &atom,
ConversionPatternRewriter &rewriter) {
PTXBuilder ptxBuilder;
std::string opcode;
// If the memory is unpacked we need to pack on the fly when loading.
std::string packedStr = unpacked ? ".pack::16b" : "";
unsigned numRepeats = numRegPerMessage / (TMemMsgT::traits::rowsPerThread *
TMemMsgT::traits::colsPerThread);
opcode = "tcgen05.ld.sync.aligned." + std::string(TMemMsgT::traits::opShape) +
".x" + std::to_string(numRepeats) + packedStr + ".b32 {";
unsigned numRepeats =
numRegPerMessage / (atom.rowsPerThread * atom.colsPerThread);
std::string opcode = "tcgen05.ld.sync.aligned." + std::string(atom.opShape) +
".x" + std::to_string(numRepeats) + packedStr + ".b32 {";

SmallVector<PTXInstr::Operand *> operands;
for (int i = 0; i < numRegPerMessage; i++) {
@@ -464,25 +462,22 @@ struct TensorMemoryLoadOpConversion
getTypeConverter()->convertType(op.getSrc().getType().getElementType());
auto tmemBase = adaptor.getSrc();

auto tmemInfo =
getTMemRuntimeInfo(op, cast<RankedTensorType>(op.getType()),
cast<MemDescType>(op.getSrc().getType()));
auto info = getTMemRuntimeInfo(op, cast<RankedTensorType>(op.getType()),
cast<MemDescType>(op.getSrc().getType()));
const TMemMessageTraits message = selectTMemMessage(info);
SmallVector<Value> resultVals;
emitMemoryOp(tmemInfo.useStridedMessage, [&](auto isStrided) {
using MsgT = TMemMessage<TMemAccess32b<decltype(isStrided)::value>>;
calculateAddressAndEmitTmemMessage<MsgT>(
loc, tmemBase, tmemInfo, rewriter,
[&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
int regsPerMessage, bool useStridedMessage, int opBitWidth) {
Value packedValues = createTensorMemoryLoad<MsgT>(
loc, op, startAddress, secondHalfColOffset, unpackedb16,
regsPerMessage, rewriter);
auto results =
unpackResults(packedValues, op.getType().getElementType(),
regsPerMessage, loc, rewriter);
resultVals.append(results.begin(), results.end());
});
});
calculateAddressAndEmitTmemMessage(
loc, tmemBase, info, message, rewriter,
[&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
int regsPerMessage, bool useStridedMessage) {
Value packedValues = createTensorMemoryLoad(
loc, op, startAddress, secondHalfColOffset, unpackedb16,
regsPerMessage, message.atom, rewriter);
auto results =
unpackResults(packedValues, op.getType().getElementType(),
regsPerMessage, loc, rewriter);
resultVals.append(results.begin(), results.end());
});
Type structTy = getTypeConverter()->convertType(op.getType());
Value resultStruct =
packLLElements(loc, getTypeConverter(), resultVals, rewriter, structTy);
