
[LAYOUTS] Allow DistributedEncoding attributes to override get[Total]ElemsPerThread() #5980

Open · wants to merge 4 commits into main

Conversation

@chsigg (Collaborator) commented Feb 21, 2025

Change ODS to provide defaultImplementation, not methodBody. Judging from the comment, this was accidental.

In the get[Total]ElemsPerThread() free functions, use the DistributedEncodingTrait if implemented.

Downstream projects (specifically, OpenXLA's [SparseDotMetaEncoding](https://github.com/openxla/xla/blob/6772834e77115e7368a418ef71024540274c93b2/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.td#L22)) want to override these functions, but this was accidentally broken by 61b5674.
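
For illustration, a minimal sketch of what a downstream override could look like once the interface method is a defaultImplementation; MySparseMetaEncodingAttr, getParent(), and kMetaPackingFactor are hypothetical stand-ins, not OpenXLA's actual code:

// Hypothetical downstream attribute standing in for SparseDotMetaEncoding.
// Once the ODS method is a default (rather than a pinned methodBody), an
// attribute implementing DistributedEncodingTrait can supply its own count.
unsigned MySparseMetaEncodingAttr::getTotalElemsPerThread(
    ArrayRef<int64_t> shape) const {
  // Start from the parent layout's count, then adjust for the metadata
  // packing factor (an illustrative constant, not real code).
  auto parent = mlir::cast<DistributedEncodingTrait>(getParent());
  return parent.getTotalElemsPerThread(shape) / kMetaPackingFactor;
}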

@chsigg chsigg requested a review from ptillet as a code owner February 21, 2025 12:48
@chsigg chsigg requested review from lezcano and removed request for ptillet February 21, 2025 12:49
@lezcano (Contributor) left a comment

One nit, otherwise LGTM

Comment on lines 42 to 44
if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
  return distLayout.getTotalElemsPerThread(shape);
}

Contributor:

You can even assert that this is indeed a DistributedEncoding as otherwise this function does not make any sense. Either that or directly change the input argument to have the right type.

Contributor:

something like:

unsigned getTotalElemsPerThread(RankedTensorType t) {
  auto layout = cast<DistributedEncodingTrait>(t.getEncoding());
  return layout.getTotalElemsPerThread(t.getShape());
}

@chsigg (Collaborator, Author):

Done.

@@ -39,11 +39,17 @@ LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape) {
}

unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
  if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
    return distLayout.getTotalElemsPerThread(shape);

Collaborator:

So it returns a different value than what the linear layout does? This means you have a distributed layout that is not a linear layout?

@@ -39,11 +39,17 @@ LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape) {
}

unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
  if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
    return distLayout.getTotalElemsPerThread(shape);
  }
  return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape);

Collaborator:

It feels like this path is not reachable; can we remove it?

@chsigg (Collaborator, Author):

Done, if you mean the one where layout is not a DistributedEncodingTrait.

  return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape);
}

SmallVector<unsigned> getElemsPerThread(Attribute layout,
                                        ArrayRef<int64_t> shape) {
  if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
    return distLayout.getElemsPerThread(shape);
  }
  return toLinearEncoding(layout, shape).getElemsPerThread(shape);

Collaborator:

ditto

@ThomasRaoux (Collaborator):

@chsigg, could you clarify whether your downstream layout is a linear layout or not?

@lezcano (Contributor) commented Feb 21, 2025

In particular, whether it's a linear layout such that all its bases are either a power of two or zero (i.e. a DistributedEncoding)
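
As a minimal sketch of that condition (not the actual LinearLayout API), each basis entry must be zero or have exactly one set bit:

// Illustrative check only: a basis entry qualifies if it is zero or a
// power of two, i.e. it has at most one bit set.
bool isValidDistributedBasis(int32_t b) {
  return b == 0 || (b & (b - 1)) == 0;
}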

@chsigg (Collaborator, Author) commented Feb 21, 2025

It is, but the number of elements is different. That's why we would like to overload this function.

@ThomasRaoux (Collaborator):

> It is, but the number of elements is different. That's why we would like to overload this function.

How can the layout be a linear layout, yet the number of elements differ from what the linear layout calculates?

@chsigg (Collaborator, Author) commented Feb 21, 2025

Sorry, I'm not really able to answer your questions well. The sparse metadata converts to linear layout by converting its parent layout. Maybe this is not correct; I'm not really familiar with the code.

The LLVM type converter used to go through the interface's getTotalNumElements(), and I'm not sure why it no longer does, but the method is still there. Also, from the code it looks like the interface method is intended to be overridable, but it's not.

I really don't mind going either way with this change; I was just trying to fix a bug.

@ThomasRaoux (Collaborator):

> Also, from the code it looks like the interface method is intended to be overridable, but it's not.

Yes, it makes sense to have the interface overridable. Since we are moving all our layouts to be based on linear layout, layouts that don't follow the same property are likely to have other problems, which is why I ask. If the layout is not completely representable by LinearLayout, I expect other things will break.

This patch itself is fine, but I think supporting layouts that don't map to linear layout is going to be a challenge (if that is what you need).

@chsigg (Collaborator, Author) commented Feb 22, 2025

Thanks Thomas. I think the sparse metadata encoding should work fine with your plan to move to linear layouts everywhere. My relatively random attempts to fix toLinearLayout() for our attribute haven't been successful yet, and there are too many other time-sensitive tasks for me to look into it properly at the moment. I understand that we will need to do this soon.
