
[Blackwell] Support narrower TMEM messages and shapes #5945

Merged (2 commits) on Feb 21, 2025

Conversation

csullivan
Collaborator

Refactors tensor memory message emission to support narrower message widths to alleviate register pressure.

For example, I found that when using on-device TMA descriptors, the additional registers they require were enough to push the already very tight register budget over the edge for fp4 kernels where BLOCK_N=256 is optimal. For BLOCK_N = 256, the epilogue currently emits two tcgen05.ld..32x32b.x128 messages, each of which requires 128 registers; these are then spilled to thread-local memory in order to cast and pack them into fewer output registers. The result is that the on-device TMA descriptor creation runtime is about 40-50% slower than when using host TMA descriptors.

With this PR we emit messages each with a smaller width (fewer repeats), with very little performance degradation. The same total number of registers is required by the epilogue, but the ptxas backend is better able to interleave these loads with the casting and packing instructions, which alleviates the pressure and effectively avoids any spills. With this change, the fp4 runtime with on-device TMA descriptors is only ~10% slower than the host-side equivalent.
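To make the register arithmetic concrete, here is a rough back-of-the-envelope sketch (plain Python, not Triton code) using the per-message register counts quoted above; the function name and the assumption that each repeat of a 32x32b load produces one 32-bit register per thread are illustrative, taken from the numbers in this description rather than from the actual implementation:

```python
def regs_per_message(repeats: int) -> int:
    # For the 32x32b atom, each repeat ("x" count) of a tcgen05.ld
    # yields one 32-bit result register per thread, so the register
    # cost of a single message equals its repeat count.
    return repeats

# BLOCK_N = 256 epilogue: two .x128 messages cost 2 * 128 = 256
# registers per thread, delivered in two huge chunks that force spills.
wide_total = 2 * regs_per_message(128)

# Narrowed alternative: eight .x32 messages cost the same 256 registers
# in total, but ptxas can interleave casts/packs between the smaller
# loads, so the live-register peak stays manageable and spills go away.
narrow_total = 8 * regs_per_message(32)

assert wide_total == narrow_total == 256
```

The point is that narrowing does not reduce the total register demand; it only changes how that demand is spread across instructions, which is what gives the scheduler room to work.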

The PR also makes the message emission generic over the instruction shape and adds a new specialization to support tcgen05.st/ld..16x256b. With subsequent work this can pair with downstream stmatrix ops for lower latency epilogues.

I updated the corresponding lit tests to match the narrowed tensor memory loads and stores.

cc @pawelszczerbuk @ThomasRaoux

@csullivan csullivan requested a review from ptillet as a code owner February 17, 2025 23:03
@csullivan csullivan changed the title [Blackwell] Support narrower TMEM messages and 16x256b shapes [Blackwell] Support narrower TMEM messages and shapes Feb 17, 2025
@ThomasRaoux
Copy link
Collaborator

With this PR we emit messages each with a smaller width (fewer repeats), with very little performance degradation

I'm curious in general what is the expected performance difference? Is the only difference to reduce the number of instructions issued? Do you know why the larger messages exist if they are usually not better?

the fp4 runtime with on device TMA descriptors is only ~10% slower than the host side equivalent.

That still feels like a big gap. Is it for a persistent-loop kernel where the descriptor setup is pulled outside the loop?

Collaborator

@ThomasRaoux ThomasRaoux left a comment


Mostly style-related comments.
It would be great to remove all the templates; they feel a bit unnecessary, as there is no performance need for this to be done at compile time, and they make the code more complicated IMO.

message widths to alleviate register pressure, as well as additional shapes
* Adds tcgen05.st/ld..16x256b shaped codegen support. With subsequent work
  this can pair with downstream stmatrix ops for lower latency epilogues
@csullivan csullivan force-pushed the 2025-02-14/tmem_access_lowering branch from b9eac6e to f578d4d on February 19, 2025 21:58
Collaborator

@ThomasRaoux ThomasRaoux left a comment


LGTM, could you answer those questions when you get a chance: #5945 (comment)?

…ep separation between

1. Tensor memory access atom derived message constants
2. Workload derived message constraints

Also remove the hard-coded narrowing factor. Instead, pick the
largest message size that avoids thread register saturation.
@csullivan csullivan force-pushed the 2025-02-14/tmem_access_lowering branch from f578d4d to 777d917 on February 21, 2025 10:02
@csullivan
Collaborator Author

@ThomasRaoux,

With this PR we emit messages each with a smaller width (fewer repeats), with very little performance degradation

I'm curious in general what is the expected performance difference? Is the only difference to reduce the number of instructions issued? Do you know why the larger messages exist if they are usually not better?

This is a good point and I spent some time comparing latencies. In general, as long as we don't induce spilling, the wider message will provide better latency, all other things being equal. Ideally, the message width is chosen to latency-match subsequent stores through the memory hierarchy so that separate epilogue subtiles are pipelined together from TMEM->SMEM->GMEM. For now that's out of scope in Triton -- though I want to bring it up again soon :). That said, the perf impact of using narrow messages even for small K is quite low -- 2-3%.

Nonetheless, your question prompted me to remove the hard-coded narrowing factor and replace it with narrowing that occurs only when a single message would require half of the available per-thread registers [1]. This means existing runtimes will be unchanged except in the cases where 256 thread registers are needed, and only then will the message size be narrowed. As a result, I removed the changes to the existing tests and added a new lit test that reflects a case where we expect narrowing to occur (e.g. 128x256).
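A minimal sketch of that policy, assuming a 256-register per-thread budget and power-of-two repeat counts (the function name, the budget constant, and the halving strategy are all illustrative here, not the actual Triton code):

```python
REG_BUDGET = 256  # assumed per-thread register budget

def pick_message_width(full_width: int) -> int:
    """Halve the message repeat count while a single message would
    consume at least half of the per-thread register budget (for the
    32x32b atom, one repeat costs one register per thread)."""
    width = full_width
    while width > 1 and width >= REG_BUDGET // 2:
        width //= 2
    return width

# Small messages are left untouched, so existing codegen is unchanged.
assert pick_message_width(32) == 32
# A full .x128 message (128 regs = half the budget) triggers narrowing.
assert pick_message_width(128) == 64
```

The upshot is that narrowing is opt-in by construction: kernels whose epilogues already fit comfortably in registers see identical codegen, and only the register-saturated cases pay the small latency cost of narrower messages.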

the fp4 runtime with on device TMA descriptors is only ~10% slower than the host side equivalent.

That still feels like a big gap. Is it for a persistent-loop kernel where the descriptor setup is pulled outside the loop?

That was not for a persistent kernel, just the block scaled tutorial kernel. It's on my perf backlog now to evaluate and I plan to prioritize it based on estimated perf impact relative to other investigations.

Collaborator

@ThomasRaoux ThomasRaoux left a comment


LGTM

@csullivan csullivan merged commit bb78fae into triton-lang:main Feb 21, 2025
7 checks passed
@csullivan csullivan deleted the 2025-02-14/tmem_access_lowering branch February 21, 2025 22:35