Skip to content

Commit

Permalink
[SCEV][LAA] Re-commit r260085 and r260086, this time with a fix for t…
Browse files Browse the repository at this point in the history
…he memory

sanitizer issue. The PredicatedScalarEvolution's copy constructor
wasn't copying the Generation value, and was leaving it un-initialized.

Original commit message:

[SCEV][LAA] Add no wrap SCEV predicates and use use them to improve strided pointer detection

Summary:
This change adds no wrap SCEV predicates with:
  - support for runtime checking
  - support for expression rewriting:
      (sext ({x,+,y}) -> {sext(x),+,sext(y)}
      (zext ({x,+,y}) -> {zext(x),+,sext(y)}

Note that we are sign extending the increment of the SCEV, even for
the zext case. This is needed to cover the fairly common case where y would
be a (small) negative integer. In order to do this, this change adds two new
flags: nusw and nssw that are applicable to AddRecExprs and permit the
transformations above.

We also change isStridedPtr in LAA to be able to make use of
these predicates. With this feature we should now always be able to
work around overflow issues in the dependence analysis.

Reviewers: mzolotukhin, sanjoy, anemet

Subscribers: mzolotukhin, sanjoy, llvm-commits, rengolin, jmolloy, hfinkel

Differential Revision: http://reviews.llvm.org/D15412



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@260112 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
sbaranga-arm committed Feb 8, 2016
1 parent c44fb84 commit e942cf8
Show file tree
Hide file tree
Showing 10 changed files with 728 additions and 41 deletions.
4 changes: 3 additions & 1 deletion include/llvm/Analysis/LoopAccessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,10 @@ const SCEV *replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
///
/// If necessary this method will version the stride of the pointer according
/// to \p PtrToStride and therefore add a new predicate to \p Preds.
/// The \p Assume parameter indicates if we are allowed to make additional
/// run-time assumptions.
int isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, const Loop *Lp,
const ValueToValueMap &StridesMap);
const ValueToValueMap &StridesMap, bool Assume = false);

/// \brief Returns true if the memory operations \p A and \p B are consecutive.
/// This is a simple API that does not depend on the analysis pass.
Expand Down
124 changes: 120 additions & 4 deletions include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/IR/ValueMap.h"
#include "llvm/Pass.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/DataTypes.h"
Expand Down Expand Up @@ -179,7 +180,7 @@ namespace llvm {
FoldingSetNodeIDRef FastID;

public:
enum SCEVPredicateKind { P_Union, P_Equal };
enum SCEVPredicateKind { P_Union, P_Equal, P_Wrap };

protected:
SCEVPredicateKind Kind;
Expand Down Expand Up @@ -269,6 +270,98 @@ namespace llvm {
}
};

/// SCEVWrapPredicate - This class represents an assumption
/// made on an AddRec expression. Given an affine AddRec expression
/// {a,+,b}, we assume that it has the nssw or nusw flags (defined
/// below).
class SCEVWrapPredicate final : public SCEVPredicate {
public:
/// Similar to SCEV::NoWrapFlags, but with slightly different semantics
/// for FlagNUSW. The increment is considered to be signed, and a + b
/// (where b is the increment) is considered to wrap if:
/// zext(a + b) != zext(a) + sext(b)
///
/// If Signed is a function that takes an n-bit tuple and maps to the
/// integer domain as the tuples value interpreted as twos complement,
/// and Unsigned a function that takes an n-bit tuple and maps to the
/// integer domain as as the base two value of input tuple, then a + b
/// has IncrementNUSW iff:
///
/// 0 <= Unsigned(a) + Signed(b) < 2^n
///
/// The IncrementNSSW flag has identical semantics with SCEV::FlagNSW.
///
/// Note that the IncrementNUSW flag is not commutative: if base + inc
/// has IncrementNUSW, then inc + base doesn't neccessarily have this
/// property. The reason for this is that this is used for sign/zero
/// extending affine AddRec SCEV expressions when a SCEVWrapPredicate is
/// assumed. A {base,+,inc} expression is already non-commutative with
/// regards to base and inc, since it is interpreted as:
/// (((base + inc) + inc) + inc) ...
enum IncrementWrapFlags {
IncrementAnyWrap = 0, // No guarantee.
IncrementNUSW = (1 << 0), // No unsigned with signed increment wrap.
IncrementNSSW = (1 << 1), // No signed with signed increment wrap
// (equivalent with SCEV::NSW)
IncrementNoWrapMask = (1 << 2) - 1
};

/// Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT
clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags,
SCEVWrapPredicate::IncrementWrapFlags OffFlags) {
assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!");
assert((OffFlags & IncrementNoWrapMask) == OffFlags &&
"Invalid flags value!");
return (SCEVWrapPredicate::IncrementWrapFlags)(Flags & ~OffFlags);
}

static SCEVWrapPredicate::IncrementWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT
maskFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, int Mask) {
assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!");
assert((Mask & IncrementNoWrapMask) == Mask && "Invalid mask value!");

return (SCEVWrapPredicate::IncrementWrapFlags)(Flags & Mask);
}

static SCEVWrapPredicate::IncrementWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT
setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags,
SCEVWrapPredicate::IncrementWrapFlags OnFlags) {
assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!");
assert((OnFlags & IncrementNoWrapMask) == OnFlags &&
"Invalid flags value!");

return (SCEVWrapPredicate::IncrementWrapFlags)(Flags | OnFlags);
}

/// \brief Returns the set of SCEVWrapPredicate no wrap flags implied
/// by a SCEVAddRecExpr.
static SCEVWrapPredicate::IncrementWrapFlags
getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE);

private:
const SCEVAddRecExpr *AR;
IncrementWrapFlags Flags;

public:
explicit SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
const SCEVAddRecExpr *AR,
IncrementWrapFlags Flags);

/// \brief Returns the set assumed no overflow flags.
IncrementWrapFlags getFlags() const { return Flags; }
/// Implementation of the SCEVPredicate interface
const SCEV *getExpr() const override;
bool implies(const SCEVPredicate *N) const override;
void print(raw_ostream &OS, unsigned Depth = 0) const override;
bool isAlwaysTrue() const override;

/// Methods for support type inquiry through isa, cast, and dyn_cast:
static inline bool classof(const SCEVPredicate *P) {
return P->getKind() == P_Wrap;
}
};

/// SCEVUnionPredicate - This class represents a composition of other
/// SCEV predicates, and is the class that most clients will interact with.
/// This is equivalent to a logical "AND" of all the predicates in the union.
Expand Down Expand Up @@ -1280,8 +1373,18 @@ namespace llvm {
const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
const SCEVConstant *RHS);

const SCEVPredicate *
getWrapPredicate(const SCEVAddRecExpr *AR,
SCEVWrapPredicate::IncrementWrapFlags AddedFlags);

/// Re-writes the SCEV according to the Predicates in \p Preds.
const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A);
const SCEV *rewriteUsingPredicate(const SCEV *Scev, const Loop *L,
SCEVUnionPredicate &A);
/// Tries to convert the \p Scev expression to an AddRec expression,
/// adding additional predicates to \p Preds as required.
const SCEV *convertSCEVToAddRecWithPredicates(const SCEV *Scev,
const Loop *L,
SCEVUnionPredicate &Preds);

private:
/// Compute the backedge taken count knowing the interval difference, the
Expand Down Expand Up @@ -1372,7 +1475,7 @@ namespace llvm {
/// - lowers the number of expression rewrites.
class PredicatedScalarEvolution {
public:
PredicatedScalarEvolution(ScalarEvolution &SE);
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L);
const SCEVUnionPredicate &getUnionPredicate() const;
/// \brief Returns the SCEV expression of V, in the context of the current
/// SCEV predicate.
Expand All @@ -1382,9 +1485,18 @@ namespace llvm {
const SCEV *getSCEV(Value *V);
/// \brief Adds a new predicate.
void addPredicate(const SCEVPredicate &Pred);
/// \brief Attempts to produce an AddRecExpr for V by adding additional
/// SCEV predicates.
const SCEV *getAsAddRec(Value *V);
/// \brief Proves that V doesn't overflow by adding SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags);
/// \brief Returns true if we've proved that V doesn't wrap by means of a
/// SCEV predicate.
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags);
/// \brief Returns the ScalarEvolution analysis used.
ScalarEvolution *getSE() const { return &SE; }

/// We need to explicitly define the copy constructor because of FlagsMap.
PredicatedScalarEvolution(const PredicatedScalarEvolution&);
private:
/// \brief Increments the version number of the predicate.
/// This needs to be called every time the SCEV predicate changes.
Expand All @@ -1398,8 +1510,12 @@ namespace llvm {
/// rewrites, we will rewrite the previous result instead of the original
/// SCEV.
DenseMap<const SCEV *, RewriteEntry> RewriteMap;
/// Records what NoWrap flags we've added to a Value *.
ValueMap<Value *, SCEVWrapPredicate::IncrementWrapFlags> FlagsMap;
/// The ScalarEvolution analysis.
ScalarEvolution &SE;
/// The analyzed Loop.
const Loop &L;
/// The SCEVPredicate that forms our context. We will rewrite all
/// expressions assuming that this predicate true.
SCEVUnionPredicate Preds;
Expand Down
9 changes: 9 additions & 0 deletions include/llvm/Analysis/ScalarEvolutionExpander.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ namespace llvm {
Value *expandEqualPredicate(const SCEVEqualPredicate *Pred,
Instruction *Loc);

/// \brief Generates code that evaluates if the \p AR expression will
/// overflow.
Value *generateOverflowCheck(const SCEVAddRecExpr *AR, Instruction *Loc,
bool Signed);

/// \brief A specialized variant of expandCodeForPredicate, handling the
/// case when we are expanding code for a SCEVWrapPredicate.
Value *expandWrapPredicate(const SCEVWrapPredicate *P, Instruction *Loc);

/// \brief A specialized variant of expandCodeForPredicate, handling the
/// case when we are expanding code for a SCEVUnionPredicate.
Value *expandUnionPredicate(const SCEVUnionPredicate *Pred,
Expand Down
61 changes: 43 additions & 18 deletions lib/Analysis/LoopAccessAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ static bool isInBoundsGep(Value *Ptr) {
/// \brief Return true if an AddRec pointer \p Ptr is unsigned non-wrapping,
/// i.e. monotonically increasing/decreasing.
static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
ScalarEvolution *SE, const Loop *L) {
PredicatedScalarEvolution &PSE, const Loop *L) {
// FIXME: This should probably only return true for NUW.
if (AR->getNoWrapFlags(SCEV::NoWrapMask))
return true;
Expand Down Expand Up @@ -809,7 +809,7 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
// Assume constant for other the operand so that the AddRec can be
// easily found.
isa<ConstantInt>(OBO->getOperand(1))) {
auto *OpScev = SE->getSCEV(OBO->getOperand(0));
auto *OpScev = PSE.getSCEV(OBO->getOperand(0));

if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpScev))
return OpAR->getLoop() == L && OpAR->getNoWrapFlags(SCEV::FlagNSW);
Expand All @@ -820,31 +820,35 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,

/// \brief Check whether the access through \p Ptr has a constant stride.
int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr,
const Loop *Lp, const ValueToValueMap &StridesMap) {
const Loop *Lp, const ValueToValueMap &StridesMap,
bool Assume) {
Type *Ty = Ptr->getType();
assert(Ty->isPointerTy() && "Unexpected non-ptr");

// Make sure that the pointer does not point to aggregate types.
auto *PtrTy = cast<PointerType>(Ty);
if (PtrTy->getElementType()->isAggregateType()) {
DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type"
<< *Ptr << "\n");
DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type" << *Ptr
<< "\n");
return 0;
}

const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);

const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
if (Assume && !AR)
AR = dyn_cast<SCEVAddRecExpr>(PSE.getAsAddRec(Ptr));

if (!AR) {
DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer "
<< *Ptr << " SCEV: " << *PtrScev << "\n");
DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr
<< " SCEV: " << *PtrScev << "\n");
return 0;
}

// The accesss function must stride over the innermost loop.
if (Lp != AR->getLoop()) {
DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop " <<
*Ptr << " SCEV: " << *PtrScev << "\n");
*Ptr << " SCEV: " << *AR << "\n");
return 0;
}

Expand All @@ -856,12 +860,23 @@ int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr,
// to access the pointer value "0" which is undefined behavior in address
// space 0, therefore we can also vectorize this case.
bool IsInBoundsGEP = isInBoundsGep(Ptr);
bool IsNoWrapAddRec = isNoWrapAddRec(Ptr, AR, PSE.getSE(), Lp);
bool IsNoWrapAddRec =
PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) ||
isNoWrapAddRec(Ptr, AR, PSE, Lp);
bool IsInAddressSpaceZero = PtrTy->getAddressSpace() == 0;
if (!IsNoWrapAddRec && !IsInBoundsGEP && !IsInAddressSpaceZero) {
DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
<< *Ptr << " SCEV: " << *PtrScev << "\n");
return 0;
if (Assume) {
PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
IsNoWrapAddRec = true;
DEBUG(dbgs() << "LAA: Pointer may wrap in the address space:\n"
<< "LAA: Pointer: " << *Ptr << "\n"
<< "LAA: SCEV: " << *AR << "\n"
<< "LAA: Added an overflow assumption\n");
} else {
DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
<< *Ptr << " SCEV: " << *AR << "\n");
return 0;
}
}

// Check the step is constant.
Expand All @@ -871,7 +886,7 @@ int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr,
const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
if (!C) {
DEBUG(dbgs() << "LAA: Bad stride - Not a constant strided " << *Ptr <<
" SCEV: " << *PtrScev << "\n");
" SCEV: " << *AR << "\n");
return 0;
}

Expand All @@ -895,8 +910,18 @@ int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr,
// know we can't "wrap around the address space". In case of address space
// zero we know that this won't happen without triggering undefined behavior.
if (!IsNoWrapAddRec && (IsInBoundsGEP || IsInAddressSpaceZero) &&
Stride != 1 && Stride != -1)
return 0;
Stride != 1 && Stride != -1) {
if (Assume) {
// We can avoid this case by adding a run-time check.
DEBUG(dbgs() << "LAA: Non unit strided pointer which is not either "
<< "inbouds or in address space 0 may wrap:\n"
<< "LAA: Pointer: " << *Ptr << "\n"
<< "LAA: SCEV: " << *AR << "\n"
<< "LAA: Added an overflow assumption\n");
PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
} else
return 0;
}

return Stride;
}
Expand Down Expand Up @@ -1123,8 +1148,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
const SCEV *AScev = replaceSymbolicStrideSCEV(PSE, Strides, APtr);
const SCEV *BScev = replaceSymbolicStrideSCEV(PSE, Strides, BPtr);

int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides);
int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides);
int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides, true);
int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides, true);

const SCEV *Src = AScev;
const SCEV *Sink = BScev;
Expand Down Expand Up @@ -1824,7 +1849,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
const TargetLibraryInfo *TLI, AliasAnalysis *AA,
DominatorTree *DT, LoopInfo *LI,
const ValueToValueMap &Strides)
: PSE(*SE), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL),
: PSE(*SE, *L), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL),
TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
MaxSafeDepDistBytes(-1U), CanVecMem(false),
StoreToLoopInvariantAddress(false) {
Expand Down
Loading

0 comments on commit e942cf8

Please sign in to comment.