From f578d4d84489474f940f50c4cd89ee6ce1d3975b Mon Sep 17 00:00:00 2001
From: Chris Sullivan
Date: Wed, 19 Feb 2025 21:47:33 +0000
Subject: [PATCH] Use dynamic dispatch instead of templated compile-time
 traits, but keep separation between
 1. Tensor memory access atom derived message constants
 2. Workload derived message constraints

---
 .../TensorMemoryToLLVM.cpp | 259 +++++++++---------
 1 file changed, 127 insertions(+), 132 deletions(-)

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
index 3de0dbacf038..62876f648189 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
@@ -21,34 +21,44 @@ static constexpr int narrowingFactor = 4;
 
 namespace {
 
-// Tensor memory access traits, currently only 32b is used
-template <bool IsStrided> struct TMemAccess32b {
-  static constexpr int opBitWidth = 32;
-  static constexpr int colsPerThread = 1;
-  static constexpr int rowsPerThread = 1;
-  static constexpr const char *opShape = IsStrided ? "16x32bx2" : "32x32b";
-  static constexpr bool usesSecondHalfOffset = IsStrided;
+struct TMemAccessAtom {
+  int opBitWidth;
+  int colsPerThread;
+  int rowsPerThread;
+  const char *opShape;
+  bool usesSecondHalfOffset;
 };
 
-struct TMemAccess256b {
-  static constexpr int opBitWidth = 256;
-  static constexpr int colsPerThread = 2;
-  static constexpr int rowsPerThread = 2;
-  static constexpr char opShape[] = "16x256b";
+constexpr TMemAccessAtom TMemAccess16x32bx2{32, 1, 1, "16x32bx2", true};
+constexpr TMemAccessAtom TMemAccess32x32b{32, 1, 1, "32x32b", false};
+constexpr TMemAccessAtom TMemAccess16x256b{256, 2, 2, "16x256b", false};
+
+struct TMemMessageTraits {
+  TMemAccessAtom atom;
+  bool usesSecondHalfOffset;
+  int numThreadsPerWarp;
+  int maxNumRepeats;
+  int maxCols;
+  int numRows;
+  int numCols;
+  int numRepeats;
+  int numRegs;
 };
 
-template <typename TMemAccess> struct TMemMessage {
-  using traits = TMemAccess;
-  static constexpr int opBitWidth = TMemAccess::opBitWidth;
-  static constexpr int colsPerThread = TMemAccess::colsPerThread;
-  static constexpr int rowsPerThread = TMemAccess::rowsPerThread;
-  static constexpr int numThreadsPerWarp = 32;
-  static constexpr int maxNumRepeats =
-      largestTmemLoadStore / (colsPerThread * rowsPerThread);
-  static constexpr int maxColsPerMessage = (opBitWidth / 32) * maxNumRepeats;
-  static constexpr int rowsPerMessage = numThreadsPerWarp / rowsPerThread;
-  static constexpr int colsPerMessage = maxColsPerMessage / narrowingFactor;
-};
+TMemMessageTraits getTMemMessageFromAtom(const TMemAccessAtom &atom) {
+  TMemMessageTraits m;
+  m.atom = atom;
+  m.usesSecondHalfOffset = atom.usesSecondHalfOffset;
+  m.numThreadsPerWarp = 32;
+  m.maxNumRepeats =
+      largestTmemLoadStore / (atom.colsPerThread * atom.rowsPerThread);
+  m.maxCols = (atom.opBitWidth / 32) * m.maxNumRepeats;
+  m.numRows = m.numThreadsPerWarp / atom.rowsPerThread;
+  m.numCols = m.maxCols / narrowingFactor;
+  m.numRepeats = m.numCols / (atom.opBitWidth / 32);
+  m.numRegs = atom.colsPerThread * atom.rowsPerThread * m.numRepeats;
+  return m;
+}
 
 struct TMemRuntimeInfo {
   static constexpr int numRowsPerWarp = 32;
@@ -68,6 +78,29 @@ struct TMemRuntimeInfo {
   int colsPerWarpGroup;
 };
 
+TMemMessageTraits constrainMessageWidthToWorkload(TMemMessageTraits m,
+                                                  const TMemRuntimeInfo &info) {
+  // If the workload runtime requires fewer registers than the default message
+  // width, use a message width that matches the workload
+  int maxRegsFromWorkload = info.colsPerWarpGroup;
+  if (info.unpackedb16)
+    maxRegsFromWorkload /= info.numElementsPer32B;
+  if (info.useStridedMessage)
+    maxRegsFromWorkload /= 2;
+
+  m.numRepeats = m.numCols / (m.atom.opBitWidth / 32);
+  m.numRegs = m.atom.colsPerThread * m.atom.rowsPerThread * m.numRepeats;
+  m.numRegs = std::min(m.numRegs, maxRegsFromWorkload);
+  // Invert the above formulas to calculate the effective runtime message width
+  m.numCols = (m.numRegs * (m.atom.opBitWidth / 32)) /
+              (m.atom.colsPerThread * m.atom.rowsPerThread);
+
+  // Half as many registers are needed for 16-bit packed elements,
+  // so twice as many columns are accessed per message.
+  m.numCols *= info.numElementsPer32B;
+  return m;
+}
+
 SmallVector<Value> packToI32(const SmallVector<Value> &values, Location loc,
                              ConversionPatternRewriter &rewriter) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -91,13 +124,6 @@ SmallVector<Value> packToI32(const SmallVector<Value> &values, Location loc,
   return packedValues;
 }
 
-int getNum32BRegs(bool unpackedb16, int numElementsPer32B, int numCols) {
-  int numRegPerMessage = numCols;
-  if (unpackedb16)
-    numRegPerMessage = numRegPerMessage / numElementsPer32B;
-  return numRegPerMessage;
-}
-
 TMemRuntimeInfo getTMemRuntimeInfo(Operation *op, RankedTensorType tensorType,
                                    MemDescType memType) {
   TMemRuntimeInfo info;
@@ -154,70 +180,52 @@ TMemRuntimeInfo getTMemRuntimeInfo(Operation *op, RankedTensorType tensorType,
   return info;
 }
 
-template <typename TMemMsgT>
 void calculateAddressAndEmitTmemMessage(
-    Location loc, Value baseAddress, TMemRuntimeInfo info,
+    Location loc, Value baseAddress, TMemRuntimeInfo runtime,
+    const TMemMessageTraits &defaultMessage,
     ConversionPatternRewriter &rewriter,
-    const std::function<void(Value, int, bool, int, bool, int)>
-        &createMemoryOp) {
+    const std::function<void(Value, int, bool, int, bool)> &createMemoryOp) {
   TritonLLVMOpBuilder b(loc, rewriter);
   Value warpId = rewriter.create<nvgpu::WarpIdOp>(loc);
   Value warpIdInGroup = b.urem(warpId, b.i32_val(4));
   Value warpGroupId = b.udiv(warpId, b.i32_val(4));
 
-  int numRegs = getNum32BRegs(info.unpackedb16, info.numElementsPer32B,
-                              info.colsPerWarpGroup);
-  if (info.useStridedMessage)
-    numRegs /= 2;
-
-  // If the workload runtime requires fewer registers than the default message
-  // width, use a message width that matches the workload
-  int colsPerMessage = TMemMsgT::colsPerMessage;
-  int numRepeats = colsPerMessage / (TMemMsgT::opBitWidth / 32);
-  int numRegsPerMsg =
-      TMemMsgT::colsPerThread * TMemMsgT::rowsPerThread * numRepeats;
-  numRegsPerMsg = std::min(numRegsPerMsg, numRegs);
-  // Invert the above formulas to calculate the effective runtime message width
-  colsPerMessage = (numRegsPerMsg * (TMemMsgT::opBitWidth / 32)) /
-                   (TMemMsgT::colsPerThread * TMemMsgT::rowsPerThread);
-
-  // Half as many registers are needed for 16-bit packed elements,
-  // so twice as many columns are accessed per message.
-  colsPerMessage *= info.numElementsPer32B;
-  for (int block = 0; block < info.numBlocks; block += info.numWarpGroups) {
+  auto message = constrainMessageWidthToWorkload(defaultMessage, runtime);
+  for (int block = 0; block < runtime.numBlocks;
+       block += runtime.numWarpGroups) {
     Value address = b.ptrtoint(i32_ty, baseAddress);
     Value blockId = b.add(b.i32_val(block),
-        b.udiv(warpGroupId, b.i32_val(info.numWarpGroupsPerBlock)));
+        b.udiv(warpGroupId, b.i32_val(runtime.numWarpGroupsPerBlock)));
     Value warpGroupIdInBlock =
-        b.urem(warpGroupId, b.i32_val(info.numWarpGroupsPerBlock));
+        b.urem(warpGroupId, b.i32_val(runtime.numWarpGroupsPerBlock));
     Value startColumnId =
-        b.mul(warpGroupIdInBlock, b.i32_val(info.colsPerWarpGroup));
+        b.mul(warpGroupIdInBlock, b.i32_val(runtime.colsPerWarpGroup));
     Value blockRowId =
         b.mul(warpIdInGroup, b.i32_val(TMemRuntimeInfo::numRowsPerWarp));
-    if (info.blocksInterleaved) {
+    if (runtime.blocksInterleaved) {
       Value blockIdIsOdd = b.urem(blockId, b.i32_val(2));
       Value blockIdPrevEven = b.sub(blockId, blockIdIsOdd);
      blockRowId = b.add(blockRowId, b.mul(blockIdIsOdd, b.i32_val(16)));
       startColumnId = b.add(startColumnId,
-          b.mul(blockIdPrevEven, b.i32_val(info.numColsPerBlock / 2)));
+          b.mul(blockIdPrevEven, b.i32_val(runtime.numColsPerBlock / 2)));
     } else {
-      startColumnId =
-          b.add(startColumnId, b.mul(blockId, b.i32_val(info.numColsPerBlock)));
+      startColumnId = b.add(startColumnId,
+          b.mul(blockId, b.i32_val(runtime.numColsPerBlock)));
     }
 
     // A strided message accesses twice as many columns per message,
    // thus half as many messages are required
-    int numColumns = info.useStridedMessage ? info.numColsPerBlock / 2
-                                            : info.numColsPerBlock;
-    for (int colStart = 0; colStart < numColumns; colStart += colsPerMessage) {
+    int numColumns = runtime.useStridedMessage ? runtime.numColsPerBlock / 2
+                                               : runtime.numColsPerBlock;
+    for (int colStart = 0; colStart < numColumns; colStart += message.numCols) {
       // For messages that span only 16 rows (e.g. 16x256b), multiple messages
       // are required to cover the entire set of rows per warp.
       for (int rowStart = 0; rowStart < TMemRuntimeInfo::numRowsPerWarp;
-           rowStart += TMemMsgT::rowsPerMessage) {
+           rowStart += message.numRows) {
         Value rowOffset = b.add(blockRowId, b.i32_val(rowStart));
         Value warpGroupAddress =
             b.add(address, b.shl(rowOffset, b.i32_val(16)));
@@ -225,31 +233,28 @@ void calculateAddressAndEmitTmemMessage(
         Value msgAddress = b.add(warpGroupAddress, b.i32_val(colStart));
 
         int secondHalfColOffset = 0;
-        if (info.useStridedMessage) {
+        if (runtime.useStridedMessage) {
           // Offset to half way through the set of columns for this warpgroup.
           secondHalfColOffset = numColumns;
         }
-        createMemoryOp(msgAddress, secondHalfColOffset, info.unpackedb16,
-                       numRegsPerMsg, info.useStridedMessage,
-                       TMemMsgT::opBitWidth);
+        createMemoryOp(msgAddress, secondHalfColOffset, runtime.unpackedb16,
+                       message.numRegs, runtime.useStridedMessage);
       }
     }
   }
 }
 
-template <typename TMemMsgT>
-static void
-createTensorMemoryStore(Location loc, Value address, SmallVector<Value> &srcs,
-                        int secondHalfOffset, Value pred, bool unpacked,
-                        ConversionPatternRewriter &rewriter) {
+void createTensorMemoryStore(Location loc, Value address,
+                             SmallVector<Value> &srcs, int secondHalfOffset,
+                             Value pred, bool unpacked,
+                             const TMemAccessAtom &atom,
+                             ConversionPatternRewriter &rewriter) {
   PTXBuilder ptxBuilder;
-  std::string opcode;
   std::string packedStr = unpacked ? ".unpack::16b" : "";
".unpack::16b" : ""; - unsigned numRepeats = srcs.size() / (TMemMsgT::traits::rowsPerThread * - TMemMsgT::traits::colsPerThread); - opcode = "@$0 tcgen05.st.sync.aligned." + - std::string(TMemMsgT::traits::opShape) + ".x" + - std::to_string(numRepeats) + packedStr; + unsigned numRepeats = srcs.size() / (atom.rowsPerThread * atom.colsPerThread); + std::string opcode = "@$0 tcgen05.st.sync.aligned." + + std::string(atom.opShape) + ".x" + + std::to_string(numRepeats) + packedStr; if (secondHalfOffset) opcode += ".b32 [$1], " + std::to_string(secondHalfOffset) + ", {"; else @@ -299,11 +304,9 @@ static void reorderScales(SmallVector &srcValues, int64_t k) { srcValues = std::move(reorderedValues); } -template void emitMemoryOp(bool isStrided, Fn &&emitMemoryOpFn) { - if (isStrided) - emitMemoryOpFn(std::integral_constant{}); - else - emitMemoryOpFn(std::integral_constant{}); +TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info) { + auto atom = info.useStridedMessage ? TMemAccess16x32bx2 : TMemAccess32x32b; + return getTMemMessageFromAtom(atom); } static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src, @@ -321,24 +324,22 @@ static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src, reorderScales(srcValues, dstType.getShape().back()); } - auto tmemInfo = getTMemRuntimeInfo(op, cast(src.getType()), - cast(dest.getType())); + auto info = getTMemRuntimeInfo(op, cast(src.getType()), + cast(dest.getType())); + const TMemMessageTraits message = selectTMemMessage(info); int regIdx = 0; - emitMemoryOp(tmemInfo.useStridedMessage, [&](auto isStrided) { - using MsgT = TMemMessage>; - calculateAddressAndEmitTmemMessage( - loc, tmemBase, tmemInfo, rewriter, - [&](Value startAddress, int secondHalfColOffset, bool unpackedb16, - int regsPerMsg, bool useStridedMessage, int opBitWidth) { - SmallVector srcValuesSlice(srcValues.begin() + regIdx, - srcValues.begin() + regIdx + - regsPerMsg); - regIdx += regsPerMsg; - createTensorMemoryStore(loc, startAddress, srcValuesSlice, - secondHalfColOffset, pred, unpackedb16, - rewriter); - }); - }); + calculateAddressAndEmitTmemMessage( + loc, tmemBase, info, message, rewriter, + [&](Value startAddress, int secondHalfColOffset, bool unpackedb16, + int regsPerMsg, bool useStridedMessage) { + SmallVector srcValuesSlice(srcValues.begin() + regIdx, + srcValues.begin() + regIdx + + regsPerMsg); + regIdx += regsPerMsg; + createTensorMemoryStore(loc, startAddress, srcValuesSlice, + secondHalfColOffset, pred, unpackedb16, + message.atom, rewriter); + }); createWaitOpSt(loc, rewriter); // Emit a barrier to ensure all threads have finished writing to tensor memory @@ -383,20 +384,17 @@ struct TensorMemoryAllocOpConversion } }; -template -static Value createTensorMemoryLoad(Location loc, - triton::nvidia_gpu::TMEMLoadOp op, - Value address, int secondHalfOffset, - bool unpacked, int numRegPerMessage, - ConversionPatternRewriter &rewriter) { +Value createTensorMemoryLoad(Location loc, triton::nvidia_gpu::TMEMLoadOp op, + Value address, int secondHalfOffset, bool unpacked, + int numRegPerMessage, const TMemAccessAtom &atom, + ConversionPatternRewriter &rewriter) { PTXBuilder ptxBuilder; - std::string opcode; // If the memory is unpacked we need to pack on the fly when loading. std::string packedStr = unpacked ? ".pack::16b" : ""; - unsigned numRepeats = numRegPerMessage / (TMemMsgT::traits::rowsPerThread * - TMemMsgT::traits::colsPerThread); - opcode = "tcgen05.ld.sync.aligned." 
-           ".x" + std::to_string(numRepeats) + packedStr + ".b32 {";
+  unsigned numRepeats =
+      numRegPerMessage / (atom.rowsPerThread * atom.colsPerThread);
+  std::string opcode = "tcgen05.ld.sync.aligned." + std::string(atom.opShape) +
+                       ".x" + std::to_string(numRepeats) + packedStr + ".b32 {";
 
   SmallVector<PTXBuilder::Operand *> operands;
   for (int i = 0; i < numRegPerMessage; i++) {
@@ -464,25 +462,22 @@ struct TensorMemoryLoadOpConversion
         getTypeConverter()->convertType(op.getSrc().getType().getElementType());
     auto tmemBase = adaptor.getSrc();
 
-    auto tmemInfo =
-        getTMemRuntimeInfo(op, cast<RankedTensorType>(op.getType()),
-                           cast<MemDescType>(op.getSrc().getType()));
+    auto info = getTMemRuntimeInfo(op, cast<RankedTensorType>(op.getType()),
+                                   cast<MemDescType>(op.getSrc().getType()));
+    const TMemMessageTraits message = selectTMemMessage(info);
     SmallVector<Value> resultVals;
-    emitMemoryOp(tmemInfo.useStridedMessage, [&](auto isStrided) {
-      using MsgT = TMemMessage<TMemAccess32b<isStrided>>;
-      calculateAddressAndEmitTmemMessage<MsgT>(
-          loc, tmemBase, tmemInfo, rewriter,
-          [&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
-              int regsPerMessage, bool useStridedMessage, int opBitWidth) {
-            Value packedValues = createTensorMemoryLoad<MsgT>(
-                loc, op, startAddress, secondHalfColOffset, unpackedb16,
-                regsPerMessage, rewriter);
-            auto results =
-                unpackResults(packedValues, op.getType().getElementType(),
-                              regsPerMessage, loc, rewriter);
-            resultVals.append(results.begin(), results.end());
-          });
-    });
+    calculateAddressAndEmitTmemMessage(
+        loc, tmemBase, info, message, rewriter,
+        [&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
+            int regsPerMessage, bool useStridedMessage) {
+          Value packedValues = createTensorMemoryLoad(
+              loc, op, startAddress, secondHalfColOffset, unpackedb16,
+              regsPerMessage, message.atom, rewriter);
+          auto results =
+              unpackResults(packedValues, op.getType().getElementType(),
+                            regsPerMessage, loc, rewriter);
+          resultVals.append(results.begin(), results.end());
+        });
     Type structTy = getTypeConverter()->convertType(op.getType());
    Value resultStruct = packLLElements(loc, getTypeConverter(), resultVals,
                                        rewriter, structTy);
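
Editor's note, not part of the patch: the standalone sketch below illustrates how the two halves named in the commit message compose, with atom-derived constants computed once (as in getTMemMessageFromAtom and selectTMemMessage) and a workload-derived clamp applied afterwards (as in constrainMessageWidthToWorkload). It is a minimal sketch: the value of largestTmemLoadStore (128), the 16-register workload budget, and the reduction of TMemRuntimeInfo to a single integer are assumptions made only for this illustration.

// Standalone sketch: atom-derived message constants plus a workload clamp.
// Constants and workload numbers are assumptions for illustration only.
#include <algorithm>
#include <cstdio>

namespace {

constexpr int largestTmemLoadStore = 128; // assumed value of the file constant
constexpr int narrowingFactor = 4;        // matches the hunk context above

struct TMemAccessAtom {
  int opBitWidth;
  int colsPerThread;
  int rowsPerThread;
  const char *opShape;
  bool usesSecondHalfOffset;
};

constexpr TMemAccessAtom TMemAccess16x32bx2{32, 1, 1, "16x32bx2", true};
constexpr TMemAccessAtom TMemAccess32x32b{32, 1, 1, "32x32b", false};

struct TMemMessageTraits {
  TMemAccessAtom atom;
  int numCols;
  int numRegs;
};

// Stage 1: constants derived purely from the access atom (mirrors
// getTMemMessageFromAtom in the patch, keeping only the fields used here).
TMemMessageTraits getTMemMessageFromAtom(const TMemAccessAtom &atom) {
  TMemMessageTraits m;
  m.atom = atom;
  int maxNumRepeats =
      largestTmemLoadStore / (atom.colsPerThread * atom.rowsPerThread);
  int maxCols = (atom.opBitWidth / 32) * maxNumRepeats;
  m.numCols = maxCols / narrowingFactor;
  int numRepeats = m.numCols / (atom.opBitWidth / 32);
  m.numRegs = atom.colsPerThread * atom.rowsPerThread * numRepeats;
  return m;
}

// Mirrors selectTMemMessage: pick the strided or contiguous 32b atom.
TMemMessageTraits selectTMemMessage(bool useStridedMessage) {
  return getTMemMessageFromAtom(useStridedMessage ? TMemAccess16x32bx2
                                                  : TMemAccess32x32b);
}

// Stage 2: clamp the default message to what the workload needs; a stand-in
// for constrainMessageWidthToWorkload with a hypothetical register budget.
TMemMessageTraits constrainToWorkload(TMemMessageTraits m,
                                      int maxRegsFromWorkload) {
  m.numRegs = std::min(m.numRegs, maxRegsFromWorkload);
  m.numCols = (m.numRegs * (m.atom.opBitWidth / 32)) /
              (m.atom.colsPerThread * m.atom.rowsPerThread);
  return m;
}

} // namespace

int main() {
  auto def = selectTMemMessage(/*useStridedMessage=*/false);
  auto msg = constrainToWorkload(def, /*maxRegsFromWorkload=*/16);
  // With the assumed constants the 32x32b atom yields a default message of
  // 32 regs / 32 cols; a 16-register workload narrows it to 16 regs / 16 cols.
  std::printf("default:     %d regs, %d cols (%s)\n", def.numRegs, def.numCols,
              def.atom.opShape);
  std::printf("constrained: %d regs, %d cols (%s)\n", msg.numRegs, msg.numCols,
              msg.atom.opShape);
  return 0;
}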