[LAYOUTS] Allow DistributedEncoding attributes to override get[Total]ElemsPerThread() #5980
base: main
Conversation
Change ODS to provide defaultImplementation, not methodBody. Judging from the comment, this was accidental. In the get[Total]ElemsPerThread() free functions, use the DistributedEncodingTrait if it is implemented. Downstream projects (specifically, OpenXLA's [SparseDotMetaEncoding](https://github.com/openxla/xla/blob/6772834e77115e7368a418ef71024540274c93b2/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.td#L22)) want to override these functions, but this was accidentally broken by triton-lang@61b5674.
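For illustration, here is a minimal sketch of the kind of downstream override this re-enables. Everything in it is hypothetical: the attribute name, its getParent() parameter, and the halving are stand-ins for the real OpenXLA code; it only shows an attribute that implements DistributedEncodingTrait shadowing the interface's default method.

```cpp
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton::gpu;

// Hypothetical attribute, assumed to be declared elsewhere (e.g. via ODS)
// with DistributedEncodingTrait and a `parent` encoding parameter.
unsigned MySparseMetaEncodingAttr::getTotalElemsPerThread(
    ArrayRef<int64_t> shape) const {
  // Illustrative only: report a different count than the parent layout would,
  // e.g. because each stored value packs metadata for several elements.
  auto parent = cast<DistributedEncodingTrait>(getParent());
  return parent.getTotalElemsPerThread(shape) / 2;
}
```

The ODS distinction matters here because a methodBody is emitted unconditionally into the interface and cannot be shadowed by an implementing attribute, whereas a defaultImplementation is used only when the attribute does not provide the method itself.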
One nit, otherwise LGTM
lib/Dialect/TritonGPU/IR/Dialect.cpp
Outdated
```cpp
if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
  return distLayout.getTotalElemsPerThread(shape);
}
```
You can even assert that this is indeed a DistributedEncoding as otherwise this function does not make any sense. Either that or directly change the input argument to have the right type.
something like:
```cpp
unsigned getTotalElemsPerThread(RankedTensorType t) {
  auto layout = cast<DistributedEncodingTrait>(t.getEncoding());
  return layout.getTotalElemsPerThread(t.getShape());
}
```
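For completeness, the other option from the comment above (keep the Attribute argument but assert that it is a distributed encoding) could look roughly like this; a rough sketch, not code from the PR:

```cpp
#include <cassert>

unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
  // dyn_cast + assert makes the precondition explicit instead of silently
  // falling back to the linear-layout path for non-distributed encodings.
  auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout);
  assert(distLayout && "expected a distributed encoding");
  return distLayout.getTotalElemsPerThread(shape);
}
```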
Done.
lib/Dialect/TritonGPU/IR/Dialect.cpp
Outdated
```
@@ -39,11 +39,17 @@ LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape) {
}

unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
  if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
    return distLayout.getTotalElemsPerThread(shape);
```
so it returns a different value than what the linear layout does? This means you have a distributed layout that is not a linear layout?
lib/Dialect/TritonGPU/IR/Dialect.cpp
Outdated
```
@@ -39,11 +39,17 @@ LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape) {
}

unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
  if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
    return distLayout.getTotalElemsPerThread(shape);
  }
  return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape);
```
it feels like this path is not reachable, can we remove it?
Done, if you mean the one where layout is not a DistributedEncodingTrait.
lib/Dialect/TritonGPU/IR/Dialect.cpp
Outdated
```
  return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape);
}

SmallVector<unsigned> getElemsPerThread(Attribute layout,
                                        ArrayRef<int64_t> shape) {
  if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
    return distLayout.getElemsPerThread(shape);
  }
  return toLinearEncoding(layout, shape).getElemsPerThread(shape);
```
ditto
@chsigg, could you clarify whether your downstream layout is a linear layout or not?
In particular, whether it's a linear layout such that all its bases are either a power of two or zero (i.e. a DistributedEncoding).
It is, but the number of elements is different. That's why we would like to overload this function.
how can the layout be a linear layout but the number of elements be different than what is calculated by linear layout?
Sorry, I'm not really able to answer your questions well. The sparse metadata converts to linear layout by converting its parent layout. Maybe this is not correct; I'm not really familiar with the code. The LLVM type converter used to go through the interface's getTotalElemsPerThread(). I really don't mind going either way with this change, I was just trying to fix a bug.
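(For context, "converting via the parent layout" could look something like the sketch below; the attribute name, its getParent() accessor, and the exact toLinearLayout signature are assumptions, not the actual XLA implementation.)

```cpp
// Assumed sketch: the metadata encoding's linear-layout conversion simply
// reuses its parent encoding's conversion for the same shape.
LinearLayout MySparseMetaEncodingAttr::toLinearLayout(
    ArrayRef<int64_t> shape) const {
  return triton::gpu::toLinearLayout(shape, getParent());
}
```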
Yes, it makes sense to have the interface overridable. Since we are moving all our layouts to be based on linear layout, having layouts that don't follow the same property is likely to have other problems, which is why I ask. If the layout is not completely representable by LinearLayout, I expect other things will break. This patch itself is fine, but I think supporting layouts that don't map to linear layout is going to be a challenge (if this is what you need).
Thanks Thomas. I think the sparse metadata encoding should work fine with your plan to move to linear layouts everywhere. My relatively random attempts to fix …