Skip to content

Commit

Permalink
Improve conditional elimination for bounds checks
Browse files Browse the repository at this point in the history
* Try to merge some deopt actions in conditional elimination
* Improve `<.getSucceedingStampForX` in unsigned cases
  • Loading branch information
gilles-duboscq committed Feb 2, 2018
1 parent 04fe5cc commit ee9b171
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ private Stamp join0(Stamp otherStamp, boolean improve) {
boolean joinExactType = exactType || other.exactType;
if (Objects.equals(type, other.type)) {
joinType = type;
} else if (type == null && other.type == null) {
joinType = null;
} else if (type == null) {
joinType = other.type;
} else if (other.type == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public abstract class AbstractFixedGuardNode extends DeoptimizingFixedWithNextNo

public static final NodeClass<AbstractFixedGuardNode> TYPE = NodeClass.create(AbstractFixedGuardNode.class);
@Input(InputType.Condition) protected LogicNode condition;
protected final DeoptimizationReason reason;
protected final DeoptimizationAction action;
protected DeoptimizationReason reason;
protected DeoptimizationAction action;
protected JavaConstant speculation;
protected boolean negated;

Expand Down Expand Up @@ -134,4 +134,14 @@ public DeoptimizeNode lowerToIf() {
public boolean canDeoptimize() {
return true;
}

@Override
public void setAction(DeoptimizationAction action) {
this.action = action;
}

@Override
public void setReason(DeoptimizationReason reason) {
this.reason = reason;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public final class DeoptimizeNode extends AbstractDeoptimizeNode implements Lowe
public static final int DEFAULT_DEBUG_ID = 0;

public static final NodeClass<DeoptimizeNode> TYPE = NodeClass.create(DeoptimizeNode.class);
protected final DeoptimizationAction action;
protected final DeoptimizationReason reason;
protected DeoptimizationAction action;
protected DeoptimizationReason reason;
protected int debugId;
protected final JavaConstant speculation;

Expand Down Expand Up @@ -72,11 +72,21 @@ public DeoptimizationAction getAction() {
return action;
}

@Override
public void setAction(DeoptimizationAction action) {
this.action = action;
}

@Override
public DeoptimizationReason getReason() {
return reason;
}

@Override
public void setReason(DeoptimizationReason reason) {
this.reason = reason;
}

@Override
public void lower(LoweringTool tool) {
tool.getLowerer().lower(this, tool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ public class GuardNode extends FloatingAnchoredNode implements Canonicalizable,

public static final NodeClass<GuardNode> TYPE = NodeClass.create(GuardNode.class);
@Input(Condition) protected LogicNode condition;
protected final DeoptimizationReason reason;
protected JavaConstant speculation;
protected DeoptimizationReason reason;
protected DeoptimizationAction action;
protected JavaConstant speculation;
protected boolean negated;

public GuardNode(LogicNode condition, AnchoringNode anchor, DeoptimizationReason reason, DeoptimizationAction action, boolean negated, JavaConstant speculation) {
Expand Down Expand Up @@ -149,7 +149,13 @@ public void negate() {
negated = !negated;
}

@Override
public void setAction(DeoptimizationAction invalidaterecompile) {
this.action = invalidaterecompile;
}

@Override
public void setReason(DeoptimizationReason reason) {
this.reason = reason;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ public interface StaticDeoptimizingNode extends ValueNodeInterface {

DeoptimizationReason getReason();

void setReason(DeoptimizationReason reason);

DeoptimizationAction getAction();

void setAction(DeoptimizationAction action);

JavaConstant getSpeculation();

/**
Expand Down Expand Up @@ -75,4 +79,15 @@ default GuardPriority computePriority() {
}
throw GraalError.shouldNotReachHere();
}

static DeoptimizationAction mergeActions(DeoptimizationAction a1, DeoptimizationAction a2) {
if (a1 == a2) {
return a1;
}
if (a1 == DeoptimizationAction.InvalidateRecompile && a2 == DeoptimizationAction.InvalidateReprofile ||
a1 == DeoptimizationAction.InvalidateReprofile && a2 == DeoptimizationAction.InvalidateRecompile) {
return DeoptimizationAction.InvalidateReprofile;
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
*/
package org.graalvm.compiler.nodes.calc;

import static jdk.vm.ci.code.CodeUtil.mask;

import org.graalvm.compiler.core.common.calc.CanonicalCondition;
import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.core.common.type.Stamp;
Expand Down Expand Up @@ -278,7 +280,7 @@ protected IntegerStamp getSucceedingStampForX(IntegerStamp xStamp, IntegerStamp
}
low += 1;
}
if (compare(low, lowerBound(xStamp)) > 0) {
if (compare(low, lowerBound(xStamp)) > 0 || upperBound(xStamp) != (xStamp.upperBound() & mask(xStamp.getBits()))) {
return forInteger(bits, low, upperBound(xStamp));
}
} else {
Expand All @@ -290,7 +292,7 @@ protected IntegerStamp getSucceedingStampForX(IntegerStamp xStamp, IntegerStamp
}
low -= 1;
}
if (compare(low, upperBound(xStamp)) < 0) {
if (compare(low, upperBound(xStamp)) < 0 || lowerBound(xStamp) != (xStamp.lowerBound() & mask(xStamp.getBits()))) {
return forInteger(bits, lowerBound(xStamp), low);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
*/
package org.graalvm.compiler.phases.common;

import static org.graalvm.compiler.nodes.StaticDeoptimizingNode.mergeActions;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
Expand All @@ -30,6 +32,7 @@
import org.graalvm.collections.Equivalence;
import org.graalvm.collections.MapCursor;
import org.graalvm.collections.Pair;
import org.graalvm.compiler.core.common.cfg.AbstractControlFlowGraph;
import org.graalvm.compiler.core.common.cfg.BlockMap;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp;
Expand Down Expand Up @@ -92,6 +95,7 @@
import org.graalvm.compiler.phases.schedule.SchedulePhase.SchedulingStrategy;
import org.graalvm.compiler.phases.tiers.PhaseContext;

import jdk.vm.ci.meta.DeoptimizationAction;
import jdk.vm.ci.meta.JavaConstant;
import jdk.vm.ci.meta.TriState;

Expand Down Expand Up @@ -146,8 +150,8 @@ protected BlockMap<List<Node>> getBlockToNodes(@SuppressWarnings("unused") Contr
}

protected ControlFlowGraph.RecursiveVisitor<?> createVisitor(StructuredGraph graph, @SuppressWarnings("unused") ControlFlowGraph cfg, BlockMap<List<Node>> blockToNodes,
@SuppressWarnings("unused") NodeMap<Block> nodeToBlock, PhaseContext context) {
return new Instance(graph, blockToNodes, context);
NodeMap<Block> nodeToBlock, PhaseContext context) {
return new Instance(graph, blockToNodes, nodeToBlock, context);
}

public static class MoveGuardsUpwards implements ControlFlowGraph.RecursiveVisitor<Block> {
Expand Down Expand Up @@ -244,6 +248,7 @@ public InfoElement get(EndNode end) {
public static class Instance implements ControlFlowGraph.RecursiveVisitor<Integer> {
protected final NodeMap<InfoElement> map;
protected final BlockMap<List<Node>> blockToNodes;
protected final NodeMap<Block> nodeToBlock;
protected final CanonicalizerTool tool;
protected final NodeStack undoOperations;
protected final StructuredGraph graph;
Expand All @@ -255,10 +260,11 @@ public static class Instance implements ControlFlowGraph.RecursiveVisitor<Intege
*/
private Deque<DeoptimizingGuard> pendingTests;

public Instance(StructuredGraph graph, BlockMap<List<Node>> blockToNodes, PhaseContext context) {
public Instance(StructuredGraph graph, BlockMap<List<Node>> blockToNodes, NodeMap<Block> nodeToBlock, PhaseContext context) {
this.graph = graph;
this.debug = graph.getDebug();
this.blockToNodes = blockToNodes;
this.nodeToBlock = nodeToBlock;
this.undoOperations = new NodeStack();
this.map = graph.createNodeMap();
pendingTests = new ArrayDeque<>();
Expand Down Expand Up @@ -614,7 +620,7 @@ private static Stamp getSafeStamp(ValueNode x) {
* never be replaced with a pi node via canonicalization).
*/
private static Stamp getOtherSafeStamp(ValueNode x) {
if (x.isConstant()) {
if (x.isConstant() || x.graph().isAfterFixedReadPhase()) {
return x.stamp(NodeView.DEFAULT);
}
return x.stamp(NodeView.DEFAULT).unrestricted();
Expand All @@ -633,6 +639,23 @@ Pair<InfoElement, Stamp> recursiveFoldStampFromInfo(Node node) {
return recursiveFoldStamp(node);
}

/**
* Look for a preceding guard whose condition is implied by {@code thisGuard}. If we find
* one, try to move this guard just above that preceding guard so that we can fold it:
*
* <pre>
* guard(C1); // preceding guard
* ...
* guard(C2); // thisGuard
* </pre>
*
* If C2 => C1, transform to:
*
* <pre>
* guard(C2);
* ...
* </pre>
*/
protected boolean foldPendingTest(DeoptimizingGuard thisGuard, ValueNode original, Stamp newStamp, GuardRewirer rewireGuardFunction) {
for (DeoptimizingGuard pendingGuard : pendingTests) {
LogicNode pendingCondition = pendingGuard.getCondition();
Expand Down Expand Up @@ -661,21 +684,41 @@ protected boolean foldPendingTest(DeoptimizingGuard thisGuard, ValueNode origina
if (result.isKnown()) {
/*
* The test case be folded using the information available but the test can only
* be moved up if we're sure there's no schedule dependence. For now limit it to
* the original node and constants.
* be moved up if we're sure there's no schedule dependence.
*/
InputFilter v = new InputFilter(original);
thisGuard.getCondition().applyInputs(v);
if (v.ok && foldGuard(thisGuard, pendingGuard, result.toBoolean(), newStamp, rewireGuardFunction)) {
if (canScheduleAbove(thisGuard.getCondition(), pendingGuard.asNode(), original) && foldGuard(thisGuard, pendingGuard, result.toBoolean(), newStamp, rewireGuardFunction)) {
return true;
}
}
}
return false;
}

private boolean canScheduleAbove(Node n, Node target, ValueNode knownToBeAbove) {
Block targetBlock = nodeToBlock.get(target);
Block testBlock = nodeToBlock.get(n);
if (targetBlock != null && testBlock != null) {
if (targetBlock == testBlock) {
for (FixedNode fixed : targetBlock.getNodes()) {
if (fixed == n) {
return true;
} else if (fixed == target) {
return false;
}
}
} else if (AbstractControlFlowGraph.dominates(testBlock, targetBlock)) {
return true;
}
return false;
}
InputFilter v = new InputFilter(knownToBeAbove);
n.applyInputs(v);
return v.ok;
}

protected boolean foldGuard(DeoptimizingGuard thisGuard, DeoptimizingGuard otherGuard, boolean outcome, Stamp guardedValueStamp, GuardRewirer rewireGuardFunction) {
if (otherGuard.getAction() == thisGuard.getAction() && otherGuard.getSpeculation() == thisGuard.getSpeculation()) {
DeoptimizationAction action = mergeActions(otherGuard.getAction(), thisGuard.getAction());
if (action != null && otherGuard.getSpeculation() == thisGuard.getSpeculation()) {
LogicNode condition = (LogicNode) thisGuard.getCondition().copyWithInputs();
/*
* We have ...; guard(C1); guard(C2);...
Expand All @@ -688,12 +731,16 @@ protected boolean foldGuard(DeoptimizingGuard thisGuard, DeoptimizingGuard other
*
* - If C2 => !C1, `mustDeopt` is true and we transform to ..; guard(C1); deopt;
*/
// for the second case, the action of the deopt is copied from there:
thisGuard.setAction(action);
GuardRewirer rewirer = (guard, result, innerGuardedValueStamp, newInput) -> {
// `result` is `outcome`, `guard` is `otherGuard`
boolean mustDeopt = result == otherGuard.isNegated();
if (rewireGuardFunction.rewire(guard, mustDeopt == thisGuard.isNegated(), innerGuardedValueStamp, newInput)) {
if (!mustDeopt) {
otherGuard.setCondition(condition, thisGuard.isNegated());
otherGuard.setAction(action);
otherGuard.setReason(thisGuard.getReason());
}
return true;
}
Expand Down Expand Up @@ -783,16 +830,6 @@ protected boolean tryProveGuardCondition(DeoptimizingGuard thisGuard, LogicNode
}
} else if (node instanceof BinaryOpLogicNode) {
BinaryOpLogicNode binaryOpLogicNode = (BinaryOpLogicNode) node;
infoElement = getInfoElements(binaryOpLogicNode);
while (infoElement != null) {
if (infoElement.getStamp().equals(StampFactory.contradiction())) {
return rewireGuards(infoElement.getGuard(), false, infoElement.getProxifiedInput(), null, rewireGuardFunction);
} else if (infoElement.getStamp().equals(StampFactory.tautology())) {
return rewireGuards(infoElement.getGuard(), true, infoElement.getProxifiedInput(), null, rewireGuardFunction);
}
infoElement = nextElement(infoElement);
}

ValueNode x = binaryOpLogicNode.getX();
ValueNode y = binaryOpLogicNode.getY();
infoElement = getInfoElements(x);
Expand Down

0 comments on commit ee9b171

Please sign in to comment.