[Blackwell] Support narrower TMEM messages and shapes #5945
Conversation
I'm curious in general what the expected performance difference is. Is the only difference a reduction in the number of instructions issued? Do you know why the larger messages exist if they are usually not better?
That still feels like a big gap. Is it for a persistent loop kernel where the descriptor set is pulled outside the loop?
Mostly style-related comments.
It would be great to remove all the templates; it feels a bit unnecessary, as there is no performance need for this to be done at compile time, and it makes the code more complicated IMO.
…message widths to alleviate register pressure, as well as additional shapes
* Adds tcgen05.st/ld..16x256b shaped codegen support. With subsequent work this can pair with downstream stmatrix ops for lower latency epilogues

Force-pushed from b9eac6e to f578d4d
LGTM. Could you answer those questions when you get a chance? #5945 (comment)
…ep separation between:
1. Tensor memory access atom derived message constants
2. Workload derived message constraints

Also remove the hard-coded narrowing factor. Instead, pick it so that the largest message size which avoids thread register saturation is used.

Force-pushed from f578d4d to 777d917
This is a good point and I spent some time comparing latencies. In general, as long as we don't induce spilling, the wider message will provide better latency if all other things are equal. Ideally, the message width is chosen in order to latency-match subsequent stores through the hierarchy so that separate epilogue subtiles are pipelined together from TMEM->SMEM->GMEM. For now that's out of scope in Triton -- though I want to bring it up again soon :).

That said, the perf impact of using narrow messages even for small K is quite low -- 2-3%. Nonetheless, your question compelled me to instead remove the hard-coded narrowing factor and replace it with narrowing that only occurs when a single message would require half of the available per-thread registers [1]. This means the existing runtimes will be unchanged except for the cases in which 256 thread registers are needed, and only in that case will the message size be narrowed. As a result, I removed the changes to the tests and added a new lit test that reflects a case where we expect narrowing to occur (e.g. 128x256).
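For illustration, here is a rough Python sketch of that selection rule (the actual lowering is C++ in TensorMemoryToLLVM.cpp; the 256-register budget, the helper name, and the one-register-per-repeat assumption are mine, not taken from the PR):

```python
# Hypothetical sketch of the narrowing heuristic: pick the widest tcgen05
# message whose per-thread register cost stays under half of the (assumed)
# 256-register-per-thread budget. Names are illustrative only.

MAX_REGS_PER_THREAD = 256  # assumed per-thread register budget


def pick_message_width(total_regs_needed: int, max_width: int = 128) -> int:
    """Return the repeat count (.xN) to use for each tcgen05.ld/st message."""
    width = min(max_width, total_regs_needed)
    # Narrow only when a single message would consume at least half of the
    # per-thread register budget (the condition described above).
    while width > 1 and width >= MAX_REGS_PER_THREAD // 2:
        width //= 2
    return width


# A 32x32b load needing 128 regs/thread is narrowed to .x64 messages;
# a 64-reg load is left as-is, so existing runtimes are unchanged.
assert pick_message_width(128) == 64
assert pick_message_width(64) == 64
```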
That was not for a persistent kernel, just the block-scaled tutorial kernel. It's on my perf backlog now to evaluate, and I plan to prioritize it based on estimated perf impact relative to other investigations.
LGTM
Refactors tensor memory message emission to support narrower message widths to alleviate register pressure.
For example, I found that when using on-device TMA descriptors, the additional registers they require were sufficient to push the already very tight register budget over the edge for fp4 kernels where BLOCK_N=256 is optimal. For BLOCK_N=256, the epilogue currently emits two tcgen05.ld..32x32b.x128 messages, each of which requires 128 registers, and then spills these to thread-local memory in order to cast and pack them into fewer output registers. The result is that the runtime with on-device TMA descriptor creation is about 40-50% slower than when using host TMA descriptors.
With this PR we emit messages each with a smaller width (fewer repeats), with very little performance degradation. The same number of registers are required by the epilogue, but the ptxas backend is better able to interleave these loads with the casting and packing instructions, which alleviates the pressure and effectively avoids any spills. With this change, the fp4 runtime with on-device TMA descriptors is only ~10% slower than the host-side equivalent.
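As a rough back-of-the-envelope illustration of the register math above (assuming one 32-bit register per thread per 32x32b repeat; these numbers are a sketch, not output from the actual lowering):

```python
def regs_per_message(repeats: int) -> int:
    # Assumption: each .xN repeat of a tcgen05.ld .32x32b message returns one
    # 32-bit register per thread.
    return repeats


# BLOCK_N = 256 epilogue as described above: two .32x32b.x128 loads, each of
# which pins 128 live registers per thread on its own.
wide = [regs_per_message(128)] * 2    # [128, 128]

# Narrowed emission: same total register count, but four .x64 loads that
# ptxas can interleave with the cast/pack instructions between them.
narrow = [regs_per_message(64)] * 4   # [64, 64, 64, 64]

assert sum(wide) == sum(narrow) == 256
```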
The PR also makes the message emission generic over the instruction shape and adds a new specialization to support tcgen05.st/ld..16x256b. With subsequent work this can pair with downstream stmatrix ops for lower-latency epilogues.

I updated the corresponding lit tests to match the narrowed tensor memory loads and stores.
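For a sense of what "generic over the instruction shape" can mean, here is a minimal Python sketch that models the message atom as data; the real implementation is the C++ in TensorMemoryToLLVM.cpp, and all names, fields, and the per-warp register model here are assumptions for illustration:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TMemMessageAtom:
    name: str          # PTX shape suffix, e.g. "32x32b" or "16x256b"
    rows: int          # TMEM lanes touched per repeat
    bits_per_row: int  # bits moved per lane per repeat


ATOMS = {
    "32x32b": TMemMessageAtom("32x32b", rows=32, bits_per_row=32),
    "16x256b": TMemMessageAtom("16x256b", rows=16, bits_per_row=256),
}


def regs_per_thread(atom: TMemMessageAtom, repeats: int, warp_size: int = 32) -> int:
    # 32-bit registers per thread for one message, assuming the data is spread
    # evenly across a 32-thread warp (illustrative model only).
    total_bits = atom.rows * atom.bits_per_row * repeats
    return total_bits // (warp_size * 32)


# Under this model a 32x32b.x64 message and a 16x256b.x16 message both cost
# 64 registers per thread, so the same narrowing logic applies to either shape.
assert regs_per_thread(ATOMS["32x32b"], 64) == 64
assert regs_per_thread(ATOMS["16x256b"], 16) == 64
```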
cc @pawelszczerbuk @ThomasRaoux