Skip to content

Commit

Permalink
[RewriteStatepointsForGC] Extend base pointer inference to handle ins…
Browse files Browse the repository at this point in the history
…ertelement

This change is simply enhancing the existing inference algorithm to handle insertelement instructions by conservatively inserting a new instruction to propagate the vector of associated base pointers. In the process, I'm ripping out the peephole optimizations which mostly helped cover the fact this hadn't been done.

Note that most of the newly inserted nodes will be nearly immediately removed by the post insertion optimization pass introduced in 246718. Arguably, we should be trying harder to avoid the malloc traffic here, but I'd rather get the code correct, then worry about compile time.

Unlike previous extensions of the algorithm to handle more cases, I discovered the existing code was causing miscompiles in some cases. In particular, we had an implicit assumption that the peephole covered *all* insert element instructions, so if we had a value directly based on an insert element the peephole didn't cover, we proceeded as if it were a base anyway. Not good. I believe we had the same issue with shufflevector, which is why I adjusted the predicate for them as well.

Differential Revision: http://reviews.llvm.org/D12583



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@247210 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
preames committed Sep 9, 2015
1 parent b4f6a50 commit b04fde3
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 59 deletions.
119 changes: 61 additions & 58 deletions lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I);
/// If the later, the return pointer is a BDV (or possibly a base) for the
/// particular element in 'I'.
static BaseDefiningValueResult
findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
findBaseDefiningValueOfVector(Value *I) {
assert(I->getType()->isVectorTy() &&
cast<VectorType>(I->getType())->getElementType()->isPointerTy() &&
"Illegal to ask for the base pointer of a non-pointer type");
Expand Down Expand Up @@ -362,35 +362,12 @@ findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {

if (isa<LoadInst>(I))
return BaseDefiningValueResult(I, true);

// For an insert element, we might be able to look through it if we know
// something about the indexes.
if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(I)) {
if (Index) {
Value *InsertIndex = IEI->getOperand(2);
// This index is inserting the value, look for its BDV
if (InsertIndex == Index)
return findBaseDefiningValue(IEI->getOperand(1));
// Both constant, and can't be equal per above. This insert is definitely
// not relevant, look back at the rest of the vector and keep trying.
if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex))
return findBaseDefiningValueOfVector(IEI->getOperand(0), Index);
}

// If both inputs to the insertelement are known bases, then so is the
// insertelement itself. NOTE: This should be handled within the generic
// base pointer inference code and after http://reviews.llvm.org/D12583,
// will be. However, when strengthening asserts I needed to add this to
// keep an existing test passing which was 'working'. FIXME
if (findBaseDefiningValue(IEI->getOperand(0)).IsKnownBase &&
findBaseDefiningValue(IEI->getOperand(1)).IsKnownBase)
return BaseDefiningValueResult(IEI, true);

if (isa<InsertElementInst>(I))
// We don't know whether this vector contains entirely base pointers or
// not. To be conservatively correct, we treat it as a BDV and will
// duplicate code as needed to construct a parallel vector of bases.
return BaseDefiningValueResult(IEI, false);
}
return BaseDefiningValueResult(I, false);

if (isa<ShuffleVectorInst>(I))
// We don't know whether this vector contains entirely base pointers or
Expand Down Expand Up @@ -528,27 +505,11 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) {
// We may need to insert a parallel instruction to extract the appropriate
// element out of the base vector corresponding to the input. Given this,
// it's analogous to the phi and select case even though it's not a merge.
if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
Value *VectorOperand = EEI->getVectorOperand();
Value *Index = EEI->getIndexOperand();
auto VecResult = findBaseDefiningValueOfVector(VectorOperand, Index);
Value *VectorBase = VecResult.BDV;
if (VectorBase->getType()->isPointerTy())
// We found a BDV for this specific element with the vector. This is an
// optimization, but in practice it covers most of the useful cases
// created via scalarization. Note: The peephole optimization here is
// currently needed for correctness since the general algorithm doesn't
// yet handle insertelements. That will change shortly.
return BaseDefiningValueResult(VectorBase, VecResult.IsKnownBase);
else {
assert(VectorBase->getType()->isVectorTy());
// Otherwise, we have an instruction which potentially produces a
// derived pointer and we need findBasePointers to clone code for us
// such that we can create an instruction which produces the
// accompanying base pointer.
return BaseDefiningValueResult(I, VecResult.IsKnownBase);
}
}
if (isa<ExtractElementInst>(I))
// Note: There are a lot of obvious peephole cases here. These are deliberately
// handled after the main base pointer inference algorithm to make writing
// test cases to exercise that code easier.
return BaseDefiningValueResult(I, false);

// The last two cases here don't return a base pointer. Instead, they
// return a value which dynamically selects from among several base
Expand Down Expand Up @@ -587,7 +548,9 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
/// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
/// is it known to be a base pointer? Or do we need to continue searching.
static bool isKnownBaseResult(Value *V) {
if (!isa<PHINode>(V) && !isa<SelectInst>(V) && !isa<ExtractElementInst>(V)) {
if (!isa<PHINode>(V) && !isa<SelectInst>(V) &&
!isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) &&
!isa<ShuffleVectorInst>(V)) {
// no recursion possible
return true;
}
Expand Down Expand Up @@ -755,7 +718,8 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {

#ifndef NDEBUG
auto isExpectedBDVType = [](Value *BDV) {
return isa<PHINode>(BDV) || isa<SelectInst>(BDV) || isa<ExtractElementInst>(BDV);
return isa<PHINode>(BDV) || isa<SelectInst>(BDV) ||
isa<ExtractElementInst>(BDV) || isa<InsertElementInst>(BDV);
};
#endif

Expand Down Expand Up @@ -795,10 +759,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
visitIncomingValue(Sel->getFalseValue());
} else if (auto *EE = dyn_cast<ExtractElementInst>(Current)) {
visitIncomingValue(EE->getVectorOperand());
} else if (auto *IE = dyn_cast<InsertElementInst>(Current)) {
visitIncomingValue(IE->getOperand(0)); // vector operand
visitIncomingValue(IE->getOperand(1)); // scalar operand
} else {
// There are two classes of instructions we know we don't handle.
assert(isa<ShuffleVectorInst>(Current) ||
isa<InsertElementInst>(Current));
// There is one known class of instructions we know we don't handle.
assert(isa<ShuffleVectorInst>(Current));
llvm_unreachable("unimplemented instruction case");
}
}
Expand Down Expand Up @@ -849,11 +815,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
} else if (PHINode *Phi = dyn_cast<PHINode>(v)) {
for (Value *Val : Phi->incoming_values())
calculateMeet.meetWith(getStateForInput(Val));
} else {
} else if (auto *EE = dyn_cast<ExtractElementInst>(v)) {
// The 'meet' for an extractelement is slightly trivial, but it's still
// useful in that it drives us to conflict if our input is.
auto *EE = cast<ExtractElementInst>(v);
calculateMeet.meetWith(getStateForInput(EE->getVectorOperand()));
} else {
// Given there's an inherent type mismatch between the operands, this will
// *always* produce Conflict.
auto *IE = cast<InsertElementInst>(v);
calculateMeet.meetWith(getStateForInput(IE->getOperand(0)));
calculateMeet.meetWith(getStateForInput(IE->getOperand(1)));
}

BDVState oldState = states[v];
Expand Down Expand Up @@ -899,6 +870,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
states[I] = BDVState(BDVState::Base, BaseInst);
}

// Since we're joining a vector and scalar base, they can never be the
// same. As a result, we should always see insert element having reached
// the conflict state.
if (isa<InsertElementInst>(I)) {
assert(State.isConflict());
}

if (!State.isConflict())
continue;
Expand All @@ -920,14 +898,22 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
(I->getName() + ".base").str() : "base_select";
return SelectInst::Create(Sel->getCondition(), Undef,
Undef, Name, Sel);
} else {
auto *EE = cast<ExtractElementInst>(I);
} else if (auto *EE = dyn_cast<ExtractElementInst>(I)) {
UndefValue *Undef = UndefValue::get(EE->getVectorOperand()->getType());
std::string Name = I->hasName() ?
(I->getName() + ".base").str() : "base_ee";
return ExtractElementInst::Create(Undef, EE->getIndexOperand(), Name,
EE);
} else {
auto *IE = cast<InsertElementInst>(I);
UndefValue *VecUndef = UndefValue::get(IE->getOperand(0)->getType());
UndefValue *ScalarUndef = UndefValue::get(IE->getOperand(1)->getType());
std::string Name = I->hasName() ?
(I->getName() + ".base").str() : "base_ie";
return InsertElementInst::Create(VecUndef, ScalarUndef,
IE->getOperand(2), Name, IE);
}

};
Instruction *BaseInst = MakeBaseInstPlaceholder(I);
// Add metadata marking this as a base value
Expand Down Expand Up @@ -1029,14 +1015,31 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
Value *Base = getBaseForInput(InVal, BaseSel);
BaseSel->setOperand(i, Base);
}
} else {
auto *BaseEE = cast<ExtractElementInst>(state.getBase());
} else if (auto *BaseEE = dyn_cast<ExtractElementInst>(state.getBase())) {
Value *InVal = cast<ExtractElementInst>(v)->getVectorOperand();
// Find the instruction which produces the base for each input. We may
// need to insert a bitcast.
Value *Base = getBaseForInput(InVal, BaseEE);
BaseEE->setOperand(0, Base);
} else {
auto *BaseIE = cast<InsertElementInst>(state.getBase());
auto *BdvIE = cast<InsertElementInst>(v);
auto UpdateOperand = [&](int OperandIdx) {
Value *InVal = BdvIE->getOperand(OperandIdx);
Value *Base = findBaseOrBDV(InVal, cache);
if (!isKnownBaseResult(Base)) {
// Either conflict or base.
assert(states.count(Base));
Base = states[Base].getBase();
assert(Base != nullptr && "unknown BDVState!");
}
assert(Base && "can't be null");
BaseIE->setOperand(OperandIdx, Base);
};
UpdateOperand(0); // vector operand
UpdateOperand(1); // scalar operand
}

}

// Now that we're done with the algorithm, see if we can optimize the
Expand Down
84 changes: 83 additions & 1 deletion test/Transforms/RewriteStatepointsForGC/base-vector.ll
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ entry:
; CHECK: extractelement
; CHECK: statepoint
; CHECK: gc.relocate
; CHECK-DAG: ; (%ptr, %obj)
; CHECK-DAG: (%obj, %obj)
%safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
ret i64 addrspace(1)* %obj
}
Expand All @@ -80,6 +80,88 @@ entry:
ret i64 addrspace(1)* %obj
}

declare void @use(i64 addrspace(1)*)

; The case where we can optimize an extractelement from a known
; index and avoid introducing new base pointer instructions
define void @test5(i1 %cnd, i64 addrspace(1)* %obj)
gc "statepoint-example" {
; CHECK-LABEL: @test5
; CHECK: gc.relocate
; CHECK-DAG: (%obj, %bdv)
entry:
%gep = getelementptr i64, i64 addrspace(1)* %obj, i64 1
%vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %gep, i32 0
%bdv = extractelement <2 x i64 addrspace(1)*> %vec, i32 0
%safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 5, i32 0, i32 -1, i32 0, i32 0, i32 0)
call void @use(i64 addrspace(1)* %bdv)
ret void
}

; When we fundamentally have to duplicate
define void @test6(i1 %cnd, i64 addrspace(1)* %obj, i64 %idx)
gc "statepoint-example" {
; CHECK-LABEL: @test6
; CHECK: %gep = getelementptr i64, i64 addrspace(1)* %obj, i64 1
; CHECK: %vec.base = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %obj, i32 0, !is_base_value !0
; CHECK: %vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %gep, i32 0
; CHECK: %bdv.base = extractelement <2 x i64 addrspace(1)*> %vec.base, i64 %idx, !is_base_value !0
; CHECK: %bdv = extractelement <2 x i64 addrspace(1)*> %vec, i64 %idx
; CHECK: gc.statepoint
; CHECK: gc.relocate
; CHECK-DAG: (%bdv.base, %bdv)
entry:
%gep = getelementptr i64, i64 addrspace(1)* %obj, i64 1
%vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %gep, i32 0
%bdv = extractelement <2 x i64 addrspace(1)*> %vec, i64 %idx
%safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 5, i32 0, i32 -1, i32 0, i32 0, i32 0)
call void @use(i64 addrspace(1)* %bdv)
ret void
}

; A more complicated example involving vector and scalar bases.
; This is derived from a failing test case when we didn't have correct
; insertelement handling.
define i64 addrspace(1)* @test7(i1 %cnd, i64 addrspace(1)* %obj,
i64 addrspace(1)* %obj2)
gc "statepoint-example" {
; CHECK-LABEL: @test7
entry:
%vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %obj2, i32 0
br label %merge1
merge1:
; CHECK-LABEL: merge1:
; CHECK: vec2.base
; CHECK: vec2
; CHECK: gep
; CHECK: vec3.base
; CHECK: vec3
%vec2 = phi <2 x i64 addrspace(1)*> [ %vec, %entry ], [ %vec3, %merge1 ]
%gep = getelementptr i64, i64 addrspace(1)* %obj2, i64 1
%vec3 = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %gep, i32 0
br i1 %cnd, label %merge1, label %next1
next1:
; CHECK-LABEL: next1:
; CHECK: bdv.base =
; CHECK: bdv =
%bdv = extractelement <2 x i64 addrspace(1)*> %vec2, i32 0
br label %merge
merge:
; CHECK-LABEL: merge:
; CHECK: %objb.base
; CHECK: %objb
; CHECK: gc.statepoint
; CHECK: gc.relocate
; CHECK-DAG: (%objb.base, %objb)

%objb = phi i64 addrspace(1)* [ %obj, %next1 ], [ %bdv, %merge ]
br i1 %cnd, label %merge, label %next
next:
%safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 5, i32 0, i32 -1, i32 0, i32 0, i32 0)
ret i64 addrspace(1)* %objb
}


declare void @do_safepoint()

declare i32 @llvm.experimental.gc.statepoint.p0f_isVoidf(i64, i32, void ()*, i32, i32, ...)

0 comments on commit b04fde3

Please sign in to comment.