Skip to content

Commit

Permalink
merge explode: fix eager loop exits on natural loop exit branch
Browse files Browse the repository at this point in the history
  • Loading branch information
davleopo committed Aug 18, 2020
1 parent 5f28d9b commit 58d9192
Show file tree
Hide file tree
Showing 2 changed files with 308 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1905,6 +1905,92 @@ private void findLoopExits(Loop loop) {
}
}
}
/*-
* Special case: loop exits that merge on a common merge node. If the original merge-exploded
* loop contains loop exit paths (paths that already leave the loop; see loop exits vs. natural
* loop exits) that are natural loop exits and that merge on a loop explosion merge, we run
* into trouble with phi nodes and proxy nodes.
*
* Consider the following piece of code outlining a loop exit path of a merge explode annotated loop
*
* <pre>
* // merge exploded loop
* mergeExplodedLoop: while(....)
* ...
* if(condition effectively exiting the loop) // natural loop exit
*
* break mergeExplodedLoop;
* ...
*
* // outerLoopContinueCode that uses values proxied inside the loop
* </pre>
*
*
* However, once the exit path contains control flow like below
* <pre>
* // merge exploded loop
* mergeExplodedLoop: while(....)
* ...
* if(condition effectively exiting the loop) // natural loop exit
* if(some unrelated condition) {
* ...
* } else {
* ...
* }
* break mergeExplodedLoop;
* ...
*
* // outerLoopContinueCode that uses values proxied inside the loop
* </pre>
*
* We would produce two loop exits that both merge on the outerLoopContinueCode.
* This would require the generation of complex phi and proxy constructs, thus we include the merge inside the
* loop if we find a subsequent loop explosion merge.
*/
EconomicSet<MergeNode> merges = EconomicSet.create(Equivalence.IDENTITY_WITH_SYSTEM_HASHCODE);
EconomicSet<MergeNode> mergesToRemove = EconomicSet.create(Equivalence.IDENTITY_WITH_SYSTEM_HASHCODE);

for (AbstractEndNode end : loop.exits) {
AbstractMergeNode merge = end.merge();
assert merge instanceof MergeNode;
if (merges.contains((MergeNode) merge)) {
mergesToRemove.add((MergeNode) merge);
} else {
merges.add((MergeNode) merge);
}
}
merges.clear();
merges.addAll(mergesToRemove);
mergesToRemove.clear();
outer: for (MergeNode merge : merges) {
for (EndNode end : merge.ends) {
if (!loop.exits.contains(end)) {
continue outer;
}
}
mergesToRemove.add(merge);
}
// we found a shared merge as outlined above
if (mergesToRemove.size() > 0) {
assert merges.size() < loop.exits.size();
for (MergeNode merge : mergesToRemove) {
FixedNode current = merge;
while (current != null) {
if (current instanceof FixedWithNextNode) {
current = ((FixedWithNextNode) current).next();
continue;
}
if (current instanceof EndNode && methodScope.loopExplosionMerges.contains(((EndNode) current).merge())) {
// we found the place for the loop exit introduction since the subsequent
// merge has a frame state
loop.exits.removeIf(x -> x.merge() == merge);
loop.exits.add((EndNode) current);
break;
}
GraalError.shouldNotReachHere("Merge explode with complex exit branch: natural vs regular loop exit.");
}
}
}
}

private void insertLoopNodes(Loop loop) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,226 @@ public void testLoopControlVariableProxy() {

partialEval((OptimizedCallTarget) callee);
}

/**
 * Root node containing a small bytecode interpreter whose dispatch loop is annotated with
 * {@code LoopExplosionKind.MERGE_EXPLODE}.
 *
 * The {@code IFZERO} handler contains an if/else whose two branches join again before the loop
 * is continued. According to the comment inside that handler, this particular shape is what
 * the fixture is about, so the control flow of {@link #execute} should not be restructured.
 */
public static class WrongLoopExitMerge extends RootNode {
// Value returned by toString().
private final String name;
// The program that execute() interprets.
@CompilationFinal(dimensions = 1) private final byte[] bytecodes;
// Frame slots named "local<i>"; only locals[0] is accessed in this class (by the boolean helpers).
@CompilationFinal(dimensions = 1) private final FrameSlot[] locals;
// Frame slots named "stack<i>" that back the interpreter's operand stack.
@CompilationFinal(dimensions = 1) private final FrameSlot[] stack;

/**
 * Creates the interpreter and registers {@code maxLocals} local slots and {@code maxStack}
 * stack slots in the frame descriptor, all with slot kind {@code FrameSlotKind.Int}.
 *
 * @param name value returned by {@link #toString()}
 * @param bytecodes program to interpret
 * @param maxLocals number of "local" frame slots to create
 * @param maxStack number of "stack" frame slots to create
 */
public WrongLoopExitMerge(String name, byte[] bytecodes, int maxLocals, int maxStack) {
super(null);
this.name = name;
this.bytecodes = bytecodes;
locals = new FrameSlot[maxLocals];
stack = new FrameSlot[maxStack];
for (int i = 0; i < maxLocals; ++i) {
locals[i] = this.getFrameDescriptor().addFrameSlot("local" + i);
this.getFrameDescriptor().setFrameSlotKind(locals[i], FrameSlotKind.Int);
}
for (int i = 0; i < maxStack; ++i) {
stack[i] = this.getFrameDescriptor().addFrameSlot("stack" + i);
this.getFrameDescriptor().setFrameSlotKind(stack[i], FrameSlotKind.Int);
}
}

/** Stores {@code value} into the stack slot at {@code stackIndex}. */
protected void setInt(VirtualFrame frame, int stackIndex, int value) {
frame.setInt(stack[stackIndex], value);
}

/**
 * Stores {@code value} into {@code locals[0]}. Not called from within this class.
 */
// NOTE(review): the constructor sets the kind of locals[0] to Int, yet a boolean is written here - confirm this is intended.
protected void setBoolean(VirtualFrame frame, boolean value) {
frame.setBoolean(locals[0], value);
}

/**
 * Reads {@code locals[0]} as a boolean. Not called from within this class.
 *
 * @return the stored boolean, or {@code false} if the read throws {@code FrameSlotTypeException}
 */
protected boolean getBoolean(VirtualFrame frame) {
try {
return frame.getBoolean(locals[0]);
} catch (FrameSlotTypeException e) {
// The type mismatch is deliberately mapped to false instead of being propagated.
return false;
}
}

/**
 * Reads the stack slot at {@code stackIndex} as an int.
 *
 * @throws IllegalStateException if the read throws {@code FrameSlotTypeException}
 */
protected int getInt(VirtualFrame frame, int stackIndex) {
try {
return frame.getInt(stack[stackIndex]);
} catch (FrameSlotTypeException e) {
// Note: the original FrameSlotTypeException is not attached as the cause.
throw new IllegalStateException("Error accessing stack slot " + stackIndex);
}
}

/** Returns the name passed to the constructor. */
@Override
public String toString() {
return name;
}

// Read by the IFZERO handler to choose between its two branches; never assigned in this class.
public static int SideEffect;

// Not referenced anywhere else in this class.
@CompilationFinal int iterations = 2;

/**
 * Interprets {@code bytecodes} starting at bci 0 until a {@code RETURN} opcode is executed.
 *
 * @return {@code result} (boxed); it is {@code false} unless an {@code IFZERO} saw a zero on
 *         top of the stack
 */
@Override
@ExplodeLoop(kind = LoopExplosionKind.MERGE_EXPLODE)
public Object execute(VirtualFrame frame) {
boolean result = false;
// Index of the topmost used stack slot; -1 means the operand stack is empty.
int topOfStack = -1;
int bci = 0;
boolean running = true;
// There is no default case: for an opcode not listed below neither bci nor running changes,
// so the loop would repeat with the same bci.
outer: while (running) {
CompilerAsserts.partialEvaluationConstant(bci);
switch (bytecodes[bci]) {
case Bytecode.CONST: {
// Push the immediate operand that follows the opcode.
byte value = bytecodes[bci + 1];
topOfStack++;
setInt(frame, topOfStack, value);
bci = bci + 2;
continue;
}
case Bytecode.RETURN: {
// Clearing the flag makes the loop condition fail on the next check.
running = false;
continue outer;
}
case Bytecode.ADD: {
// Replace the two topmost values with their sum.
int left = getInt(frame, topOfStack);
int right = getInt(frame, topOfStack - 1);
topOfStack--;
setInt(frame, topOfStack, left + right);
bci = bci + 1;
continue;
}
case Bytecode.IFZERO: {
// Pop the top value; jump to the operand bci if it is zero, otherwise fall through.
int value = getInt(frame, topOfStack);
byte trueBci = bytecodes[bci + 1];
topOfStack--;
if (value == 0) {
bci = trueBci;
// value == 0 is known to hold on this branch, so this assigns true.
result = value == 0;
if (SideEffect == 42) {
GraalDirectives.sideEffect(result ? 12 : 14);
} else {
GraalDirectives.sideEffect(2);
}
// Uncommenting the statement below makes the problem disappear, because the
// merge after both branches is then no longer considered part of the loop
// explosion.
// GraalDirectives.sideEffect(3);
} else {
bci = bci + 2;
}
continue;
}
case Bytecode.POP: {
// Read the top value and discard it.
getInt(frame, topOfStack);
topOfStack--;
bci++;
continue;
}
case Bytecode.JMP: {
// Unconditional jump to the operand bci.
byte newBci = bytecodes[bci + 1];
bci = newBci;
continue;
}
case Bytecode.DUP: {
// Push a copy of the top value.
int dupValue = getInt(frame, topOfStack);
topOfStack++;
setInt(frame, topOfStack, dupValue);
bci++;
continue;
}
}
}
return result;
}
}

/**
 * Root node that invokes a fixed call target through a {@code DirectCallNode} (on which
 * {@code forceInlining()} has been requested) and maps the returned boolean to an int.
 */
public static class Caller extends RootNode {

    @Child DirectCallNode callee;

    /**
     * @param ct the call target to invoke from {@link #execute}
     */
    protected Caller(CallTarget ct) {
        super(null);
        this.callee = DirectCallNode.create(ct);
        this.callee.forceInlining();
    }

    /**
     * Calls the callee with this frame's arguments.
     *
     * @return {@code 0} if the callee returned {@code true}, {@code 10} if it returned
     *         {@code false}
     */
    @Override
    @ExplodeLoop
    public Object execute(VirtualFrame frame) {
        Object returned = this.callee.call(frame.getArguments());
        boolean isBoolean = returned instanceof Boolean;
        if (!isBoolean) {
            // Unexpected result type: leave compiled code before the cast below fails.
            CompilerDirectives.transferToInterpreter();
        }
        boolean flag = (boolean) returned;
        return flag ? 0 : 10;
    }

}

/**
 * Builds the counting-loop program, runs it four times directly and four times through a
 * {@code Caller} wrapper, and then hands the caller's call target to {@code partialEval}.
 */
@Test
public void test01() {
    final byte[] program = {
                    /* 0: */Bytecode.CONST,
                    /* 1: */42,
                    /* 2: */Bytecode.CONST,
                    /* 3: */-12,

                    // loop
                    /* 4: */Bytecode.CONST,
                    /* 5: */1,
                    /* 6: */Bytecode.ADD,
                    /* 7: */Bytecode.DUP,
                    /* 8: */Bytecode.IFZERO,
                    /* 9: */12,
                    // backedge
                    /* 10: */Bytecode.JMP,
                    /* 11: */4,

                    // loop exit
                    /* 12: */Bytecode.POP,
                    /* 13: */Bytecode.RETURN};

    final WrongLoopExitMerge interpreter = new WrongLoopExitMerge("mergedLoopExitProgram", program, 1, 3);
    final CallTarget calleeTarget = Truffle.getRuntime().createCallTarget(interpreter);
    // Execute the interpreter on its own a few times first.
    for (int run = 0; run < 4; run++) {
        calleeTarget.call();
    }

    final CallTarget callerTarget = Truffle.getRuntime().createCallTarget(new Caller(calleeTarget));
    // Then execute it through the wrapping caller.
    for (int run = 0; run < 4; run++) {
        callerTarget.call();
    }

    partialEval((OptimizedCallTarget) callerTarget);
}

/**
 * Builds the counting-loop program, runs it four times directly, and then hands the
 * interpreter's own call target (without a {@code Caller} wrapper) to {@code partialEval}.
 */
@Test
public void test01Caller() {
    final byte[] program = {
                    /* 0: */Bytecode.CONST,
                    /* 1: */42,
                    /* 2: */Bytecode.CONST,
                    /* 3: */-12,

                    // loop
                    /* 4: */Bytecode.CONST,
                    /* 5: */1,
                    /* 6: */Bytecode.ADD,
                    /* 7: */Bytecode.DUP,
                    /* 8: */Bytecode.IFZERO,
                    /* 9: */12,
                    // backedge
                    /* 10: */Bytecode.JMP,
                    /* 11: */4,

                    // loop exit
                    /* 12: */Bytecode.POP,
                    /* 13: */Bytecode.RETURN};

    final WrongLoopExitMerge interpreter = new WrongLoopExitMerge("mergedLoopExitProgram", program, 1, 3);
    final CallTarget calleeTarget = Truffle.getRuntime().createCallTarget(interpreter);
    // Execute the interpreter a few times before partial evaluation.
    for (int run = 0; run < 4; run++) {
        calleeTarget.call();
    }

    partialEval((OptimizedCallTarget) calleeTarget);
}

}

0 comments on commit 58d9192

Please sign in to comment.