Skip to content

Commit

Permalink
[NFC] Generalize some DotTrait handling (triton-lang#4098)
Browse files Browse the repository at this point in the history
Remove some unnecessary restrictions on code related to DotLike.

Only Hopper Dot can take shared memory operands so remove some redundant
checks.
  • Loading branch information
ThomasRaoux authored Jun 7, 2024
1 parent b90b3a0 commit da99faf
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 6 deletions.
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/Traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class DotLike : public TraitBase<ConcreteType, DotLike> {
return op->emitOpError("expected 3 operands");
auto aTy = cast<TensorOrMemDesc>(op->getOperand(0).getType());
auto bTy = cast<TensorOrMemDesc>(op->getOperand(1).getType());
auto cTy = cast<TensorType>(op->getOperand(2).getType());
auto cTy = cast<TensorOrMemDesc>(op->getOperand(2).getType());
auto aShape = aTy.getShape();
auto bShape = bTy.getShape();
auto cShape = cTy.getShape();
Expand Down
5 changes: 0 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,6 @@ class FuseTransHopper : public OpRewritePattern<LocalAllocOp> {
return failure();

auto dot = *allocOp->getUsers().begin();
auto dotEnc = dyn_cast<NvidiaMmaEncodingAttr>(
cast<RankedTensorType>(dot->getResult(0).getType()).getEncoding());
if (!dotEnc || dotEnc.getVersionMajor() != 3)
return failure();

if (!allocOp.getSrc())
return failure();

Expand Down

0 comments on commit da99faf

Please sign in to comment.