Skip to content

Commit

Permalink
effects phase: avoid multiple iterations on nested loops if the forwa…
Browse files Browse the repository at this point in the history
…rd end states are stable
  • Loading branch information
davleopo committed Nov 21, 2016
1 parent 821b2bf commit 94a1ccd
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ public final class GraalOptions {
@Option(help = "", type = OptionType.Debug)
public static final OptionValue<Boolean> OptReadElimination = new OptionValue<>(true);

// NOTE(review): the help text is empty. Judging by the name only, this presumably caps how often
// read elimination revisits a loop; the consumer is not visible here - confirm and fill in help.
@Option(help = "", type = OptionType.Debug)
public static final OptionValue<Integer> ReadEliminationMaxLoopVisits = new OptionValue<>(5);

@Option(help = "", type = OptionType.Debug)
public static final OptionValue<Boolean> OptDeoptimizationGrouping = new OptionValue<>(true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import static com.oracle.graal.phases.common.DeadCodeEliminationPhase.Optionality.Optional;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;

Expand All @@ -42,65 +43,92 @@
import com.oracle.graal.phases.common.CanonicalizerPhase;
import com.oracle.graal.phases.common.DeadCodeEliminationPhase;
import com.oracle.graal.phases.common.inlining.InliningUtil;
import com.oracle.graal.phases.schedule.SchedulePhase;
import com.oracle.graal.phases.tiers.HighTierContext;
import com.oracle.graal.virtual.phases.ea.EarlyReadEliminationPhase;
import com.oracle.graal.virtual.phases.ea.PartialEscapePhase;

import jdk.vm.ci.meta.ResolvedJavaMethod;
import sun.misc.Unsafe;

public class RecursiveInliningTest extends GraalCompilerTest {

public static int SideEffectI;
public static int[] Memory = new int[]{1, 2};

/** Shared {@link Unsafe} instance, obtained reflectively from the private static {@code theUnsafe} field. */
public static final Unsafe UNSAFE;
static {
    Unsafe instance;
    try {
        Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
        unsafeField.setAccessible(true);
        // theUnsafe is a static field, so the receiver passed to Field.get is ignored.
        instance = (Unsafe) unsafeField.get(null);
    } catch (Exception e) {
        throw new RuntimeException("Exception while trying to get Unsafe", e);
    }
    UNSAFE = instance;
}

/**
 * Recursive test snippet whose exit condition is an unsafe read of the first element of
 * {@link #Memory}. Returns immediately when that element is 0; otherwise calls itself once for
 * every {@code i} in {@code [0, a)}.
 *
 * @param a exclusive upper bound of the loop that issues the recursive calls
 */
public static void recursiveLoopMethodUnsafeLoad(int a) {
    // Memory is an int[], so its first element lives at ARRAY_INT_BASE_OFFSET; the long[] base
    // offset is not guaranteed to be the same. The (long) cast selects getInt(Object, long).
    if (UNSAFE.getInt(Memory, (long) Unsafe.ARRAY_INT_BASE_OFFSET) == 0) {
        return;
    }
    for (int i = 0; i < a; i++) {
        recursiveLoopMethodUnsafeLoad(i);
    }
}

// Recursive test snippet; this span shows both the removed (recursiveLoopMethodSlow) and the
// added (recursiveLoopMethodFieldLoad) lines of the rename. Returns immediately when the static
// field SideEffectI is 0; otherwise calls itself once for every i in [0, a).
public static void recursiveLoopMethodSlow(int a) {
public static void recursiveLoopMethodFieldLoad(int a) {
if (SideEffectI == 0) {
return;
}
for (int i = 0; i < a; i++) {
recursiveLoopMethodSlow(i);
recursiveLoopMethodFieldLoad(i);
}
}

// Recursive test snippet without any memory read; this span shows both the removed
// (recursiveLoopMethodFast) and the added (recursiveLoopMethod) lines of the rename.
// Returns immediately when a is 0; otherwise calls itself once for every i in [0, a).
public static void recursiveLoopMethodFast(int a) {
public static void recursiveLoopMethod(int a) {
if (a == 0) {
return;
}
for (int i = 0; i < a; i++) {
recursiveLoopMethodFast(i);
recursiveLoopMethod(i);
}
}

public static int IterationsStart = 8/* Increase to escalate early read elimination and PEA */;
public static int IterationsEnd = 22/* Increase to escalate early read elimination and PEA */;
public static final boolean LOG = true;

@Test
public void inlineDirectRecursiveLoopCallFast() {
for (int i = IterationsStart; i < IterationsEnd; i++) {
StructuredGraph graph = getGraph("recursiveLoopMethodFast", i);
long elapsed = runAndTimeEarlyReadEliminationPhase(graph);
System.out.printf("Needed %dms to run early read elimination on a graph with %d recursive inlined calls of method %s\n", elapsed, i, graph.method());
}
for (int i = IterationsStart; i < IterationsEnd; i++) {
StructuredGraph graph = getGraph("recursiveLoopMethodFast", i);
long elapsed = runAndTimePartialEscapeAnalysis(graph);
System.out.printf("Needed %dms to run early partial escape analysis on a graph with %d recursive inlined calls of method %s\n", elapsed, i, graph.method());
}
public static int IterationsStart = 1;
public static int IterationsEnd = 128;

/**
 * Runs the shared timing driver on the {@code recursiveLoopMethodUnsafeLoad} snippet, with a
 * JUnit timeout of 120 seconds.
 */
@Test(timeout = 120_000)
public void inlineDirectRecursiveLoopCallUnsafeLoad() {
    final String snippetName = "recursiveLoopMethodUnsafeLoad";
    testAndTime(snippetName);
}

/**
 * Runs the shared timing driver on the {@code recursiveLoopMethodFieldLoad} snippet, with a
 * JUnit timeout of 120 seconds.
 */
@Test(timeout = 120_000)
public void inlineDirectRecursiveLoopCallFieldLoad() {
    final String snippetName = "recursiveLoopMethodFieldLoad";
    testAndTime(snippetName);
}

/**
 * Runs the shared timing driver on the {@code recursiveLoopMethod} snippet, with a JUnit
 * timeout of 120 seconds.
 */
@Test(timeout = 120_000)
public void inlineDirectRecursiveLoopCallNoReads() {
    final String snippetName = "recursiveLoopMethod";
    testAndTime(snippetName);
}

// Shared driver for the timed tests; this span interleaves the removed lines of the former
// inlineDirectRecursiveLoopCallSlow test with the added lines of testAndTime(String).
// For every i in [IterationsStart, IterationsEnd) it builds a graph via getGraph and times the
// early read elimination phase; a second loop over the same range builds fresh graphs and times
// partial escape analysis. The added lines print the timings only when LOG is true.
@Test
public void inlineDirectRecursiveLoopCallSlow() {
private void testAndTime(String snippet) {
for (int i = IterationsStart; i < IterationsEnd; i++) {
StructuredGraph graph = getGraph("recursiveLoopMethodSlow", i);
StructuredGraph graph = getGraph(snippet, i);
long elapsed = runAndTimeEarlyReadEliminationPhase(graph);
System.out.printf("Needed %dms to run early read elimination on a graph with %d recursive inlined calls of method %s\n", elapsed, i, graph.method());
if (LOG) {
System.out.printf("Needed %dms to run early read elimination on a graph with %d recursive inlined calls of method %s\n", elapsed, i, graph.method());
}
}
for (int i = IterationsStart; i < IterationsEnd; i++) {
StructuredGraph graph = getGraph("recursiveLoopMethodSlow", i);
StructuredGraph graph = getGraph(snippet, i);
long elapsed = runAndTimePartialEscapeAnalysis(graph);
System.out.printf("Needed %dms to run early partial escape analysis on a graph with %d recursive inlined calls of method %s\n", elapsed, i, graph.method());
if (LOG) {
System.out.printf("Needed %dms to run early partial escape analysis on a graph with %d recursive inlined calls of method %s\n", elapsed, i, graph.method());
}
}

}

private long runAndTimePartialEscapeAnalysis(StructuredGraph g) {
Expand All @@ -123,6 +151,7 @@ private long runAndTimeEarlyReadEliminationPhase(StructuredGraph g) {
return end - start;
}

@SuppressWarnings("try")
private StructuredGraph getGraph(final String snippet, int nrOfInlinings) {
try (Scope s = Debug.scope("RecursiveInliningTest", new DebugDumpScope(snippet, true))) {
ResolvedJavaMethod callerMethod = getResolvedJavaMethod(snippet);
Expand All @@ -140,7 +169,7 @@ private StructuredGraph getGraph(final String snippet, int nrOfInlinings) {
canonicalizer.applyIncremental(callerGraph, context, canonicalizeNodes);
Debug.dump(Debug.BASIC_LOG_LEVEL, callerGraph, "After inlining %s into %s iteration %d", calleeMethod, callerMethod, i);
}

new SchedulePhase().apply(callerGraph);
return callerGraph;
} catch (Throwable e) {
throw Debug.handle(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@
package com.oracle.graal.virtual.phases.ea;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.oracle.graal.compiler.common.CollectionsFactory;
import com.oracle.graal.compiler.common.LocationIdentity;
import com.oracle.graal.compiler.common.cfg.BlockMap;
import com.oracle.graal.compiler.common.cfg.Loop;
import com.oracle.graal.compiler.common.type.Stamp;
import com.oracle.graal.debug.Debug;
import com.oracle.graal.debug.GraalError;
import com.oracle.graal.debug.Indent;
import com.oracle.graal.graph.Node;
import com.oracle.graal.graph.NodeBitMap;
import com.oracle.graal.graph.NodeMap;
Expand Down Expand Up @@ -68,6 +72,8 @@ public abstract class EffectsClosure<BlockT extends EffectsBlockState<BlockT>> e
protected final NodeMap<ValueNode> aliases;
protected final BlockMap<GraphEffectList> blockEffects;
private final Map<Loop<Block>, GraphEffectList> loopMergeEffects = CollectionsFactory.newIdentityMap();
// Intended to be used by read-eliminating phases based on the effects phase.
protected final Map<Loop<Block>, LoopKillCache> loopLocationKillCache = CollectionsFactory.newIdentityMap();
private final Map<LoopBeginNode, BlockT> loopEntryStates = Node.newIdentityMap();
private final NodeBitMap hasScalarReplacedInputs;

Expand Down Expand Up @@ -220,6 +226,7 @@ protected BlockT merge(Block merge, List<BlockT> states) {
}

@Override
@SuppressWarnings("try")
protected final List<BlockT> processLoop(Loop<Block> loop, BlockT initialState) {
if (initialState.isDead()) {
ArrayList<BlockT> states = new ArrayList<>();
Expand All @@ -228,45 +235,73 @@ protected final List<BlockT> processLoop(Loop<Block> loop, BlockT initialState)
}
return states;
}

BlockT loopEntryState = initialState;
BlockT lastMergedState = cloneState(initialState);
/*
* Special case nested loops: To avoid an exponential runtime for nested loops we try to
* process them as few times as possible.
*
* In the first iteration of an outermost loop we go into the innermost loop(s). We run
* the first iteration of the innermost loop and then, if necessary, a second iteration.
*
* We return from the recursion and finish the first iteration of the outermost loop. If we
* have to do a second iteration in the outermost loop, we go again into the innermost
* loop(s), but this time we already know all states that are killed by the loop, so inside
* the loop we will only have those changes that propagate from the first iteration of the
* outermost loop into the current loop. We strip the initial loop state for the innermost
* loops and do the first iteration with the (possible) changes from outer loops. If there
* are no changes, we only have to do 1 iteration and are done.
*
*/
BlockT initialStateRemovedKilledLocations = stripKilledLoopLocations(loop, cloneState(initialState));
BlockT loopEntryState = initialStateRemovedKilledLocations;
BlockT lastMergedState = cloneState(initialStateRemovedKilledLocations);
processInitialLoopState(loop, lastMergedState);
MergeProcessor mergeProcessor = createMergeProcessor(loop.getHeader());
for (int iteration = 0; iteration < 10; iteration++) {
LoopInfo<BlockT> info = ReentrantBlockIterator.processLoop(this, loop, cloneState(lastMergedState));

List<BlockT> states = new ArrayList<>();
states.add(initialState);
states.addAll(info.endStates);
doMergeWithoutDead(mergeProcessor, states);

Debug.log("================== %s", loop.getHeader());
Debug.log("%s", mergeProcessor.newState);
Debug.log("===== vs.");
Debug.log("%s", lastMergedState);

if (mergeProcessor.newState.equivalentTo(lastMergedState)) {
mergeProcessor.commitEnds(states);

blockEffects.get(loop.getHeader()).insertAll(mergeProcessor.mergeEffects, 0);
loopMergeEffects.put(loop, mergeProcessor.afterMergeEffects);

assert info.exitStates.size() == loop.getExits().size();
loopEntryStates.put((LoopBeginNode) loop.getHeader().getBeginNode(), loopEntryState);
assert assertExitStatesNonEmpty(loop, info);

return info.exitStates;
} else {
lastMergedState = mergeProcessor.newState;
for (Block block : loop.getBlocks()) {
blockEffects.get(block).clear();
try (Indent i = Debug.logAndIndent("================== Process Loop Effects Closure: block:%s begin node:%s", loop.getHeader(), loop.getHeader().getBeginNode())) {
LoopInfo<BlockT> info = ReentrantBlockIterator.processLoop(this, loop, cloneState(lastMergedState));

List<BlockT> states = new ArrayList<>();
states.add(initialStateRemovedKilledLocations);
states.addAll(info.endStates);
doMergeWithoutDead(mergeProcessor, states);

Debug.log("MergeProcessor New State: %s", mergeProcessor.newState);
Debug.log("===== vs.");
Debug.log("Last Merged State: %s", lastMergedState);

if (mergeProcessor.newState.equivalentTo(lastMergedState)) {
mergeProcessor.commitEnds(states);

blockEffects.get(loop.getHeader()).insertAll(mergeProcessor.mergeEffects, 0);
loopMergeEffects.put(loop, mergeProcessor.afterMergeEffects);

assert info.exitStates.size() == loop.getExits().size();
loopEntryStates.put((LoopBeginNode) loop.getHeader().getBeginNode(), loopEntryState);
assert assertExitStatesNonEmpty(loop, info);

processKilledLoopLocations(loop, initialStateRemovedKilledLocations, mergeProcessor.newState);
return info.exitStates;
} else {
lastMergedState = mergeProcessor.newState;
for (Block block : loop.getBlocks()) {
blockEffects.get(block).clear();
}
}
}
}
throw new GraalError("too many iterations at %s", loop);
}

/**
 * Hook invoked by {@code processLoop} on a clone of the loop's incoming state before the loop
 * is iterated. The returned state is used as the loop entry state and as the first input of the
 * loop header merge. This default implementation returns {@code initialState} unchanged.
 *
 * @param loop the loop that is about to be processed
 * @param initialState a clone of the state flowing into the loop
 * @return the state the loop processing starts from
 */
@SuppressWarnings("unused")
protected BlockT stripKilledLoopLocations(Loop<Block> loop, BlockT initialState) {
return initialState;
}

/**
 * Hook invoked by {@code processLoop} once the merged loop header state is equivalent to the
 * previously merged state, just before the loop's exit states are returned. This default
 * implementation does nothing.
 *
 * @param loop the loop that was processed
 * @param initialState the loop entry state as returned by {@code stripKilledLoopLocations}
 * @param mergedStates the merge processor's final state for the loop header
 */
@SuppressWarnings("unused")
protected void processKilledLoopLocations(Loop<Block> loop, BlockT initialState, BlockT mergedStates) {
// nothing to do
}

@SuppressWarnings("unused")
protected void processInitialLoopState(Loop<Block> loop, BlockT initialState) {
// nothing to do
Expand Down Expand Up @@ -408,4 +443,64 @@ public ValueNode getScalarAlias(ValueNode node) {
ValueNode result = aliases.get(node);
return (result == null || result instanceof VirtualObjectNode) ? node : result;
}

/**
 * Per-loop record of a visit counter and of the locations remembered as killed by the loop.
 * The first remembered location is kept in a dedicated field; a set is only allocated once a
 * second, different location is remembered. After {@link #setKillsAll()} every location is
 * reported as killed.
 */
protected static class LoopKillCache {
    private int visitCount;
    private LocationIdentity singleLocation;
    private Set<LocationIdentity> furtherLocations;
    private boolean everythingKilled;

    /**
     * @param visits initial value of the visit counter
     */
    protected LoopKillCache(int visits) {
        this.visitCount = visits;
    }

    /** Increments the visit counter by one. */
    protected void visited() {
        visitCount += 1;
    }

    /** @return the current value of the visit counter */
    protected int visits() {
        return visitCount;
    }

    /** Marks every location as killed and drops the individually remembered locations. */
    protected void setKillsAll() {
        furtherLocations = null;
        singleLocation = null;
        everythingKilled = true;
    }

    /**
     * @param locationIdentity the location to look up
     * @return {@code true} if all locations are marked as killed or the given location was
     *         remembered before
     */
    protected boolean containsLocation(LocationIdentity locationIdentity) {
        if (everythingKilled) {
            return true;
        }
        if (singleLocation == null) {
            return false;
        }
        if (singleLocation.equals(locationIdentity)) {
            return true;
        }
        return furtherLocations != null && furtherLocations.contains(locationIdentity);
    }

    /**
     * Remembers the given location as killed; has no effect once all locations are marked as
     * killed.
     *
     * @param locationIdentity the location to remember
     */
    protected void rememberLoopKilledLocation(LocationIdentity locationIdentity) {
        if (everythingKilled) {
            return;
        }
        if (singleLocation == null || singleLocation.equals(locationIdentity)) {
            singleLocation = locationIdentity;
            return;
        }
        // A second, distinct location: allocate the overflow set on first use.
        if (furtherLocations == null) {
            furtherLocations = new HashSet<>();
        }
        furtherLocations.add(locationIdentity);
    }

    /** @return {@code true} if all locations are marked as killed or at least one was remembered */
    protected boolean loopKillsLocations() {
        return everythingKilled || singleLocation != null;
    }
}

}
Loading

0 comments on commit 94a1ccd

Please sign in to comment.