[SLPVectorizer, TargetTransformInfo, SystemZ] Improve SLP getGatherCost(). (llvm#112491)

As vector element loads are free on SystemZ, this patch improves the cost
computation in getGatherCost() to reflect this.

getScalarizationOverhead() gets an optional parameter which can hold the actual
Values so that they in turn can be passed (by BasicTTIImpl) to
getVectorInstrCost().

SystemZTTIImpl::getVectorInstrCost() will now recognize a LoadInst and
typically return a 0 cost for it, with some exceptions.
JonPsson1 authored Nov 29, 2024
1 parent cbf495f commit 0ad6be1
Showing 14 changed files with 209 additions and 88 deletions.
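Before the per-file diffs, a brief illustration of the new API surface. The sketch below is not part of the commit; the function name, cost kind, and setup are assumptions chosen for illustration. It shows how a cost-model client could hand the gathered scalars (VL) to getScalarizationOverhead() so that a target such as SystemZ can treat single-use element loads as free inserts.

    #include "llvm/ADT/APInt.h"
    #include "llvm/Analysis/TargetTransformInfo.h"
    #include "llvm/IR/DerivedTypes.h"
    using namespace llvm;

    // Hypothetical helper (not from this patch): cost of building VecTy from
    // the scalars in VL, with every lane demanded and inserted.
    static InstructionCost gatherCostSketch(const TargetTransformInfo &TTI,
                                            FixedVectorType *VecTy,
                                            ArrayRef<Value *> VL) {
      APInt DemandedElts = APInt::getAllOnes(VecTy->getNumElements());
      // Passing VL is the new part of the interface; BasicTTIImpl forwards
      // VL[i] to getVectorInstrCost() as the value inserted into lane i.
      return TTI.getScalarizationOverhead(
          VecTy, DemandedElts, /*Insert=*/true, /*Extract=*/false,
          TargetTransformInfo::TCK_RecipThroughput, VL);
    }

With this commit, lanes whose entry in VL is a single-use load are expected to be modeled as free on SystemZ (VLE), so the returned cost drops for gathers fed by loads.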
19 changes: 11 additions & 8 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -909,11 +909,13 @@ class TargetTransformInfo {

  /// Estimate the overhead of scalarizing an instruction. Insert and Extract
  /// are set if the demanded result elements need to be inserted and/or
- /// extracted from vectors.
+ /// extracted from vectors. The involved values may be passed in VL if
+ /// Insert is true.
  InstructionCost getScalarizationOverhead(VectorType *Ty,
  const APInt &DemandedElts,
  bool Insert, bool Extract,
- TTI::TargetCostKind CostKind) const;
+ TTI::TargetCostKind CostKind,
+ ArrayRef<Value *> VL = {}) const;

  /// Estimate the overhead of scalarizing an instructions unique
  /// non-constant operands. The (potentially vector) types to use for each of
@@ -2001,10 +2003,10 @@ class TargetTransformInfo::Concept {
  unsigned ScalarOpdIdx) = 0;
  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
  int ScalarOpdIdx) = 0;
- virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
- const APInt &DemandedElts,
- bool Insert, bool Extract,
- TargetCostKind CostKind) = 0;
+ virtual InstructionCost
+ getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
+ bool Insert, bool Extract, TargetCostKind CostKind,
+ ArrayRef<Value *> VL = {}) = 0;
  virtual InstructionCost
  getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
  ArrayRef<Type *> Tys,
@@ -2585,9 +2587,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
  InstructionCost getScalarizationOverhead(VectorType *Ty,
  const APInt &DemandedElts,
  bool Insert, bool Extract,
- TargetCostKind CostKind) override {
+ TargetCostKind CostKind,
+ ArrayRef<Value *> VL = {}) override {
  return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
- CostKind);
+ CostKind, VL);
  }
  InstructionCost
  getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
3 changes: 2 additions & 1 deletion llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -404,7 +404,8 @@ class TargetTransformInfoImplBase {
  InstructionCost getScalarizationOverhead(VectorType *Ty,
  const APInt &DemandedElts,
  bool Insert, bool Extract,
- TTI::TargetCostKind CostKind) const {
+ TTI::TargetCostKind CostKind,
+ ArrayRef<Value *> VL = {}) const {
  return 0;
  }

10 changes: 7 additions & 3 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -780,24 +780,28 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
  InstructionCost getScalarizationOverhead(VectorType *InTy,
  const APInt &DemandedElts,
  bool Insert, bool Extract,
- TTI::TargetCostKind CostKind) {
+ TTI::TargetCostKind CostKind,
+ ArrayRef<Value *> VL = {}) {
  /// FIXME: a bitfield is not a reasonable abstraction for talking about
  /// which elements are needed from a scalable vector
  if (isa<ScalableVectorType>(InTy))
  return InstructionCost::getInvalid();
  auto *Ty = cast<FixedVectorType>(InTy);

  assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
+ (VL.empty() || VL.size() == Ty->getNumElements()) &&
  "Vector size mismatch");

  InstructionCost Cost = 0;

  for (int i = 0, e = Ty->getNumElements(); i < e; ++i) {
  if (!DemandedElts[i])
  continue;
- if (Insert)
+ if (Insert) {
+ Value *InsertedVal = VL.empty() ? nullptr : VL[i];
  Cost += thisT()->getVectorInstrCost(Instruction::InsertElement, Ty,
- CostKind, i, nullptr, nullptr);
+ CostKind, i, nullptr, InsertedVal);
+ }
  if (Extract)
  Cost += thisT()->getVectorInstrCost(Instruction::ExtractElement, Ty,
  CostKind, i, nullptr, nullptr);
4 changes: 2 additions & 2 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -622,9 +622,9 @@ bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(

  InstructionCost TargetTransformInfo::getScalarizationOverhead(
  VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
- TTI::TargetCostKind CostKind) const {
+ TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) const {
  return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
- CostKind);
+ CostKind, VL);
  }

  InstructionCost TargetTransformInfo::getOperandsScalarizationOverhead(
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3363,7 +3363,7 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,

  InstructionCost AArch64TTIImpl::getScalarizationOverhead(
  VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
- TTI::TargetCostKind CostKind) {
+ TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
  if (isa<ScalableVectorType>(Ty))
  return InstructionCost::getInvalid();
  if (Ty->getElementType()->isFloatingPointTy())
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -423,7 +423,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
  InstructionCost getScalarizationOverhead(VectorType *Ty,
  const APInt &DemandedElts,
  bool Insert, bool Extract,
- TTI::TargetCostKind CostKind);
+ TTI::TargetCostKind CostKind,
+ ArrayRef<Value *> VL = {});

  /// Return the cost of the scaling factor used in the addressing
  /// mode represented by AM for this target, for a load/store
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -669,7 +669,7 @@ static unsigned isM1OrSmaller(MVT VT) {

  InstructionCost RISCVTTIImpl::getScalarizationOverhead(
  VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
- TTI::TargetCostKind CostKind) {
+ TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
  if (isa<ScalableVectorType>(Ty))
  return InstructionCost::getInvalid();

3 changes: 2 additions & 1 deletion llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -149,7 +149,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
  InstructionCost getScalarizationOverhead(VectorType *Ty,
  const APInt &DemandedElts,
  bool Insert, bool Extract,
- TTI::TargetCostKind CostKind);
+ TTI::TargetCostKind CostKind,
+ ArrayRef<Value *> VL = {});

  InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
  TTI::TargetCostKind CostKind);
76 changes: 60 additions & 16 deletions llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -468,6 +468,42 @@ bool SystemZTTIImpl::hasDivRemOp(Type *DataType, bool IsSigned) {
  return (VT.isScalarInteger() && TLI->isTypeLegal(VT));
  }

+ static bool isFreeEltLoad(Value *Op) {
+ if (isa<LoadInst>(Op) && Op->hasOneUse()) {
+ const Instruction *UserI = cast<Instruction>(*Op->user_begin());
+ return !isa<StoreInst>(UserI); // Prefer MVC
+ }
+ return false;
+ }
+
+ InstructionCost SystemZTTIImpl::getScalarizationOverhead(
+ VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
+ TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
+ unsigned NumElts = cast<FixedVectorType>(Ty)->getNumElements();
+ InstructionCost Cost = 0;
+
+ if (Insert && Ty->isIntOrIntVectorTy(64)) {
+ // VLVGP will insert two GPRs with one instruction, while VLE will load
+ // an element directly with no extra cost
+ assert((VL.empty() || VL.size() == NumElts) &&
+ "Type does not match the number of values.");
+ InstructionCost CurrVectorCost = 0;
+ for (unsigned Idx = 0; Idx < NumElts; ++Idx) {
+ if (DemandedElts[Idx] && !(VL.size() && isFreeEltLoad(VL[Idx])))
+ ++CurrVectorCost;
+ if (Idx % 2 == 1) {
+ Cost += std::min(InstructionCost(1), CurrVectorCost);
+ CurrVectorCost = 0;
+ }
+ }
+ Insert = false;
+ }
+
+ Cost += BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
+ CostKind, VL);
+ return Cost;
+ }
+
  // Return the bit size for the scalar type or vector element
  // type. getScalarSizeInBits() returns 0 for a pointer type.
  static unsigned getScalarSizeInBits(Type *Ty) {
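To make the pairing arithmetic in the hunk above concrete, here is a toy restatement (not part of the diff); it assumes an i64 element type, four lanes, and all elements demanded.

    #include <algorithm>
    #include <array>

    // Mirrors the VLVGP pairing loop above for a hypothetical <4 x i64> build
    // vector: ElemIsFreeLoad[i] stands in for isFreeEltLoad(VL[i]).
    static int vlvgpInsertCost(const std::array<bool, 4> &ElemIsFreeLoad) {
      int Cost = 0, CurrPair = 0;
      for (unsigned Idx = 0; Idx < 4; ++Idx) {
        if (!ElemIsFreeLoad[Idx])
          ++CurrPair;                    // this lane still needs a GPR insert
        if (Idx % 2 == 1) {              // one VLVGP covers two adjacent lanes
          Cost += std::min(1, CurrPair); // so charge at most 1 per pair
          CurrPair = 0;
        }
      }
      return Cost;
    }

For example, {false, false, true, true} yields 1: the first pair is charged a single VLVGP, while the second pair is built entirely from free element loads (VLE).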
@@ -609,7 +645,7 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
  if (DivRemConst) {
  SmallVector<Type *> Tys(Args.size(), Ty);
  return VF * DivMulSeqCost +
- getScalarizationOverhead(VTy, Args, Tys, CostKind);
+ BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
  }
  if ((SignedDivRem || UnsignedDivRem) && VF > 4)
  // Temporary hack: disable high vectorization factors with integer
@@ -636,7 +672,7 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
  SmallVector<Type *> Tys(Args.size(), Ty);
  InstructionCost Cost =
  (VF * ScalarCost) +
- getScalarizationOverhead(VTy, Args, Tys, CostKind);
+ BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
  // FIXME: VF 2 for these FP operations are currently just as
  // expensive as for VF 4.
  if (VF == 2)
@@ -654,8 +690,9 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
  // There is no native support for FRem.
  if (Opcode == Instruction::FRem) {
  SmallVector<Type *> Tys(Args.size(), Ty);
- InstructionCost Cost = (VF * LIBCALL_COST) +
- getScalarizationOverhead(VTy, Args, Tys, CostKind);
+ InstructionCost Cost =
+ (VF * LIBCALL_COST) +
+ BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
  // FIXME: VF 2 for float is currently just as expensive as for VF 4.
  if (VF == 2 && ScalarBits == 32)
  Cost *= 2;
@@ -975,10 +1012,10 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
  (Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI))
  NeedsExtracts = false;

- TotCost += getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
- NeedsExtracts, CostKind);
- TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts,
- /*Extract*/ false, CostKind);
+ TotCost += BaseT::getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
+ NeedsExtracts, CostKind);
+ TotCost += BaseT::getScalarizationOverhead(DstVecTy, NeedsInserts,
+ /*Extract*/ false, CostKind);

  // FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4.
  if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32)
@@ -990,8 +1027,8 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
  if (Opcode == Instruction::FPTrunc) {
  if (SrcScalarBits == 128) // fp128 -> double/float + inserts of elements.
  return VF /*ldxbr/lexbr*/ +
- getScalarizationOverhead(DstVecTy, /*Insert*/ true,
- /*Extract*/ false, CostKind);
+ BaseT::getScalarizationOverhead(DstVecTy, /*Insert*/ true,
+ /*Extract*/ false, CostKind);
  else // double -> float
  return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/);
  }
@@ -1004,8 +1041,8 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
  return VF * 2;
  }
  // -> fp128. VF * lxdb/lxeb + extraction of elements.
- return VF + getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
- /*Extract*/ true, CostKind);
+ return VF + BaseT::getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
+ /*Extract*/ true, CostKind);
  }
  }

@@ -1114,10 +1151,17 @@ InstructionCost SystemZTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
  TTI::TargetCostKind CostKind,
  unsigned Index, Value *Op0,
  Value *Op1) {
- // vlvgp will insert two grs into a vector register, so only count half the
- // number of instructions.
- if (Opcode == Instruction::InsertElement && Val->isIntOrIntVectorTy(64))
- return ((Index % 2 == 0) ? 1 : 0);
+ if (Opcode == Instruction::InsertElement) {
+ // Vector Element Load.
+ if (Op1 != nullptr && isFreeEltLoad(Op1))
+ return 0;
+
+ // vlvgp will insert two grs into a vector register, so count half the
+ // number of instructions as an estimate when we don't have the full
+ // picture (as in getScalarizationOverhead()).
+ if (Val->isIntOrIntVectorTy(64))
+ return ((Index % 2 == 0) ? 1 : 0);
+ }

  if (Opcode == Instruction::ExtractElement) {
  int Cost = ((getScalarSizeInBits(Val) == 1) ? 2 /*+test-under-mask*/ : 1);
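As a usage sketch (again not part of the commit, names are assumptions): querying this hook directly with a load as the would-be inserted operand should now come back as 0 on SystemZ, provided the load has a single user that is not a store.

    #include "llvm/Analysis/TargetTransformInfo.h"
    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // Probe the per-element insert cost for a loaded value; with this patch a
    // SystemZ target is expected to answer 0, since VLE can load the element
    // straight into the vector register.
    static InstructionCost probeInsertCost(const TargetTransformInfo &TTI,
                                           FixedVectorType *VecTy,
                                           LoadInst *Ld, unsigned Lane) {
      return TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
                                    TargetTransformInfo::TCK_RecipThroughput,
                                    Lane, /*Op0=*/nullptr, /*Op1=*/Ld);
    }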
5 changes: 5 additions & 0 deletions llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -81,6 +81,11 @@ class SystemZTTIImpl : public BasicTTIImplBase<SystemZTTIImpl> {
  bool hasDivRemOp(Type *DataType, bool IsSigned);
  bool prefersVectorizedAddressing() { return false; }
  bool LSRWithInstrQueries() { return true; }
+ InstructionCost getScalarizationOverhead(VectorType *Ty,
+ const APInt &DemandedElts,
+ bool Insert, bool Extract,
+ TTI::TargetCostKind CostKind,
+ ArrayRef<Value *> VL = {});
  bool supportsEfficientVectorElementLoadStore() { return true; }
  bool enableInterleavedAccessVectorization() { return true; }

7 changes: 3 additions & 4 deletions llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -4854,10 +4854,9 @@ InstructionCost X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
  RegisterFileMoveCost;
  }

- InstructionCost
- X86TTIImpl::getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
- bool Insert, bool Extract,
- TTI::TargetCostKind CostKind) {
+ InstructionCost X86TTIImpl::getScalarizationOverhead(
+ VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
+ TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
  assert(DemandedElts.getBitWidth() ==
  cast<FixedVectorType>(Ty)->getNumElements() &&
  "Vector size mismatch");
3 changes: 2 additions & 1 deletion llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -169,7 +169,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
  InstructionCost getScalarizationOverhead(VectorType *Ty,
  const APInt &DemandedElts,
  bool Insert, bool Extract,
- TTI::TargetCostKind CostKind);
+ TTI::TargetCostKind CostKind,
+ ArrayRef<Value *> VL = {});
  InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
  int VF,
  const APInt &DemandedDstElts,
10 changes: 5 additions & 5 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3110,9 +3110,8 @@ class BoUpSLP {
  SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries,
  unsigned NumParts, bool ForOrder = false);

- /// \returns the scalarization cost for this list of values. Assuming that
- /// this subtree gets vectorized, we may need to extract the values from the
- /// roots. This method calculates the cost of extracting the values.
+ /// \returns the cost of gathering (inserting) the values in \p VL into a
+ /// vector.
  /// \param ForPoisonSrc true if initial vector is poison, false otherwise.
  InstructionCost getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
  Type *ScalarTy) const;
@@ -13498,9 +13497,10 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
  TTI::SK_InsertSubvector, VecTy, std::nullopt, CostKind,
  I * ScalarTyNumElements, cast<FixedVectorType>(ScalarTy));
  } else {
- Cost = TTI->getScalarizationOverhead(VecTy, ~ShuffledElements,
+ Cost = TTI->getScalarizationOverhead(VecTy,
+ /*DemandedElts*/ ~ShuffledElements,
  /*Insert*/ true,
- /*Extract*/ false, CostKind);
+ /*Extract*/ false, CostKind, VL);
  }
  }
  if (DuplicateNonConst)