Skip to content

Commit

Permalink
[GR-19646] Vectorized LLVM getelementptr.
Browse files Browse the repository at this point in the history
PullRequest: graal/4940
rschatz committed Dec 4, 2019
2 parents cad026e + 22b04f7 commit ab8fe0c
Showing 19 changed files with 567 additions and 155 deletions.
Original file line number Diff line number Diff line change
@@ -198,6 +198,7 @@
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMNativeVarargsAreaStackAllocationNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMStructByValueNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMVarArgCompoundAddressNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMVectorizedGetElementPtrNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.NativeMemSetNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.NativeProfiledMemMoveNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.ProtectReadOnlyGlobalsBlockNode;
@@ -731,8 +732,13 @@ public LLVMLoadNode createExtractValue(Type type, LLVMExpressionNode targetAddre
}

@Override
public LLVMExpressionNode createTypedElementPointer(LLVMExpressionNode aggregateAddress, LLVMExpressionNode index, long indexedTypeLength, Type targetType) {
return LLVMGetElementPtrNodeGen.create(aggregateAddress, index, indexedTypeLength, targetType);
public LLVMExpressionNode createTypedElementPointer(long indexedTypeLength, Type targetType, LLVMExpressionNode aggregateAddress, LLVMExpressionNode index) {
return LLVMGetElementPtrNodeGen.create(indexedTypeLength, targetType, aggregateAddress, index);
}

@Override
public LLVMExpressionNode createVectorizedTypedElementPointer(long indexedTypeLength, Type targetType, LLVMExpressionNode aggregateAddress, LLVMExpressionNode index) {
return LLVMVectorizedGetElementPtrNodeGen.create(indexedTypeLength, targetType, aggregateAddress, index);
}

@Override
@@ -1523,12 +1529,12 @@ protected LLVMArithmeticNode createScalarArithmeticOp(ArithmeticOperation op, Ty
case X86_FP80:
return LLVMFP80ArithmeticNodeGen.create(op, left, right);
default:
throw new AssertionError(type);
throw new AssertionError("Unknown primitive type: " + type);
}
} else if (type instanceof VariableBitWidthType) {
return LLVMIVarBitArithmeticNodeGen.create(op, left, right);
} else {
throw new AssertionError(type);
throw new AssertionError("Unknown type: " + type);
}
}

Original file line number Diff line number Diff line change
@@ -778,23 +778,15 @@ private void createGetElementPointer(RecordBuffer buffer) {
int pointer = readIndex(buffer);
Type base = readValueType(buffer, pointer);
int[] indices = readIndices(buffer);
Type type;
if (base instanceof VectorType) {
VectorType vector = (VectorType) base;
type = new VectorType(new PointerType(getElementPointerType(vector.getElementType(), indices)), vector.getNumberOfElements());
} else {
type = new PointerType(getElementPointerType(base, indices));
}

Type type = getElementPointerType(base, indices);
emit(GetElementPointerInstruction.fromSymbols(scope.getSymbols(), type, pointer, indices, isInbounds));
}

private void createGetElementPointerOld(RecordBuffer buffer, boolean isInbounds) {
int pointer = readIndex(buffer);
Type base = readValueType(buffer, pointer);
int[] indices = readIndices(buffer);
Type type = new PointerType(getElementPointerType(base, indices));

Type type = getElementPointerType(base, indices);
emit(GetElementPointerInstruction.fromSymbols(scope.getSymbols(), type, pointer, indices, isInbounds));
}

@@ -902,8 +894,12 @@ private static int getAlign(long argument) {
}

private Type getElementPointerType(Type type, int[] indices) {
Type elementType = type;
boolean vectorized = type instanceof VectorType;
int length = vectorized ? ((VectorType) type).getNumberOfElements() : 0;
Type elementType = vectorized ? ((VectorType) type).getElementType() : type;
for (int indexIndex : indices) {
Type indexType = scope.getValueType(indexIndex);

if (elementType instanceof PointerType) {
elementType = ((PointerType) elementType).getPointeeType();
} else if (elementType instanceof ArrayType) {
@@ -912,7 +908,6 @@ private Type getElementPointerType(Type type, int[] indices) {
elementType = ((VectorType) elementType).getElementType();
} else if (elementType instanceof StructureType) {
StructureType structure = (StructureType) elementType;
Type indexType = scope.getValueType(indexIndex);
if (!(indexType instanceof PrimitiveType)) {
throw new LLVMParserException("Cannot infer structure element from " + indexType);
}
@@ -922,8 +917,26 @@ private Type getElementPointerType(Type type, int[] indices) {
} else {
throw new LLVMParserException("Cannot index type: " + elementType);
}

if (indexType instanceof VectorType) {
int indexVectorLength = ((VectorType) indexType).getNumberOfElements();
if (vectorized) {
if (indexVectorLength != length) {
throw new LLVMParserException(String.format("Vectors of different lengths are not supported: %d != %d", indexVectorLength, length));
}
} else {
vectorized = true;
length = indexVectorLength;
}
}
}

Type pointer = new PointerType(elementType);
if (vectorized) {
return new VectorType(pointer, length);
} else {
return pointer;
}
return elementType;
}

private int readIndex(RecordBuffer buffer) {
Original file line number Diff line number Diff line change
@@ -713,7 +713,7 @@ public void visit(ExtractValueInstruction extract) {

if (offset != 0) {
final LLVMExpressionNode oneLiteralNode = nodeFactory.createLiteral(1, PrimitiveType.I32);
targetAddress = nodeFactory.createTypedElementPointer(targetAddress, oneLiteralNode, offset, extract.getType());
targetAddress = nodeFactory.createTypedElementPointer(offset, extract.getType(), targetAddress, oneLiteralNode);
}

final LLVMExpressionNode result = nodeFactory.createExtractValue(resultType, targetAddress);
Original file line number Diff line number Diff line change
@@ -72,6 +72,9 @@
import com.oracle.truffle.llvm.runtime.except.LLVMParserException;
import com.oracle.truffle.llvm.runtime.global.LLVMGlobal;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMExpressionNode;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMVectorizedGetElementPtrNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMVectorizedGetElementPtrNodeGen.IndexVectorBroadcastNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMVectorizedGetElementPtrNodeGen.ResultVectorBroadcastNodeGen;
import com.oracle.truffle.llvm.runtime.pointer.LLVMManagedPointer;
import com.oracle.truffle.llvm.runtime.pointer.LLVMNativePointer;
import com.oracle.truffle.llvm.runtime.types.AggregateType;
@@ -479,6 +482,69 @@ public interface OptimizedResolver {
LLVMExpressionNode resolve(SymbolImpl symbol, int excludeOtherIndex, SymbolImpl other, SymbolImpl... others);
}

/**
* The rules for whether to build a scalar-getelementptr or vector-getelementptr node:
*
* S = scalar node
*
* V = vector node
*
* BC = broadcast node
*
* GEP = scalar getelementptr node
*
* VGEP = vector getelementptr node
*
* The BasePointer (BP) could either be a scalar, vector or GEP/VGEP node coming from the
* previous indexing dimension. The Index can only either be a scalar or a vector.
*
* Pointers, arrays and structures are considered scalars.
*
* (BP, Idx) --> Next OP(PTR, IDX)
*
* -------------------------------
*
* 0: (S, S) --> GEP(S, S)
*
* 1: (S, V) --> VGEP(BC(S), V)
*
* 2: (V, S) --> VGEP(V, BC(S))
*
* 3: (V, V) --> VGEP(V, V)
*
* 4: (GEP, S) --> GEP(GEP, S)
*
* 5: (GEP, V) --> VGEP(BC(GEP), V)
*
* 6: (VGEP, S) --> VGEP(VGEP, BC(S))
*
* 7: (VGEP, V) --> VGEP(VGEP, V)
*/
private LLVMExpressionNode createElementPointer(long indexedTypeLength, Type currentType, LLVMExpressionNode currentAddress, LLVMExpressionNode indexNode, Type indexType,
final boolean wasVectorized) {
if (wasVectorized) {
// Cases 2, 3, 6, 7
if (indexType instanceof VectorType) {
// Cases 3, 7
return nodeFactory.createVectorizedTypedElementPointer(indexedTypeLength, currentType, currentAddress, indexNode);
} else {
// Cases 2, 6
int length = ((VectorType) currentType).getNumberOfElements();
return nodeFactory.createVectorizedTypedElementPointer(indexedTypeLength, currentType, currentAddress, IndexVectorBroadcastNodeGen.create(length, indexNode));
}
} else {
// Cases 0, 1, 4, 5
if (indexType instanceof VectorType) {
// Cases 1, 5
int length = ((VectorType) indexType).getNumberOfElements();
return nodeFactory.createVectorizedTypedElementPointer(indexedTypeLength, currentType, ResultVectorBroadcastNodeGen.create(length, currentAddress), indexNode);
} else {
// Cases 0, 4
return nodeFactory.createTypedElementPointer(indexedTypeLength, currentType, currentAddress, indexNode);
}
}
}

/**
* Turns a base value and a list of indices into a list of "get element pointer" operations, and
* allows callers to intercept the resolution of values to nodes (used for frame slot
@@ -497,7 +563,13 @@ public LLVMExpressionNode resolveElementPointer(SymbolImpl base, SymbolImpl[] in
LLVMExpressionNode currentAddress = resolver.resolve(base, -1, null, indices);
Type currentType = base.getType();

for (int i = 0, indicesSize = indices.length; i < indicesSize; i++) {
boolean wasVectorized = currentType instanceof VectorType;
if (wasVectorized) {
VectorType vectorType = (VectorType) currentType;
currentType = vectorType.getElementType();
}

for (int i = 0; i < indices.length; i++) {
SymbolImpl indexSymbol = indices[i];
Type indexType = indexSymbol.getType();

@@ -511,7 +583,8 @@ public LLVMExpressionNode resolveElementPointer(SymbolImpl base, SymbolImpl[] in
AggregateType aggregate = (AggregateType) currentType;
long indexedTypeLength = aggregate.getOffsetOf(1, dataLayout);
currentType = aggregate.getElementType(1);
currentAddress = nodeFactory.createTypedElementPointer(currentAddress, indexNodes[i], indexedTypeLength, currentType);
currentAddress = createElementPointer(indexedTypeLength, currentType, currentAddress, indexNodes[i], indexType, wasVectorized);
wasVectorized = currentAddress instanceof LLVMVectorizedGetElementPtrNodeGen;
} else {
// the index is a constant integer
AggregateType aggregate = (AggregateType) currentType;
@@ -520,7 +593,7 @@ public LLVMExpressionNode resolveElementPointer(SymbolImpl base, SymbolImpl[] in

// creating a pointer inserts type information, this needs to happen for the address
// computed by getelementptr even if it is the same as the basepointer
if (addressOffset != 0 || i == indicesSize - 1) {
if (addressOffset != 0 || i == indices.length - 1) {
LLVMExpressionNode indexNode;
if (indexType == PrimitiveType.I32) {
indexNode = nodeFactory.createLiteral(1, PrimitiveType.I32);
@@ -529,7 +602,8 @@ public LLVMExpressionNode resolveElementPointer(SymbolImpl base, SymbolImpl[] in
} else {
throw new AssertionError(indexType);
}
currentAddress = nodeFactory.createTypedElementPointer(currentAddress, indexNode, addressOffset, currentType);
currentAddress = createElementPointer(addressOffset, currentType, currentAddress, indexNode, indexType, wasVectorized);
wasVectorized = currentAddress instanceof LLVMVectorizedGetElementPtrNodeGen;
}
}
}
Original file line number Diff line number Diff line change
@@ -104,7 +104,9 @@ LLVMControlFlowNode createFunctionInvoke(FrameSlot resultLocation, LLVMExpressio

LLVMExpressionNode createExtractValue(Type type, LLVMExpressionNode targetAddress);

LLVMExpressionNode createTypedElementPointer(LLVMExpressionNode aggregateAddress, LLVMExpressionNode index, long indexedTypeLength, Type targetType);
LLVMExpressionNode createTypedElementPointer(long indexedTypeLength, Type targetType, LLVMExpressionNode aggregateAddress, LLVMExpressionNode index);

LLVMExpressionNode createVectorizedTypedElementPointer(long indexedTypeLength, Type targetType, LLVMExpressionNode aggregateAddress, LLVMExpressionNode index);

LLVMExpressionNode createSelect(Type type, LLVMExpressionNode condition, LLVMExpressionNode trueValue, LLVMExpressionNode falseValue);

Original file line number Diff line number Diff line change
@@ -41,14 +41,12 @@
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.llvm.runtime.nodes.intrinsics.interop.LLVMReadCharsetNode.LLVMCharset;
import com.oracle.truffle.llvm.runtime.nodes.intrinsics.llvm.LLVMIntrinsic;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMGetElementPtrNode.LLVMIncrementPointerNode;
import com.oracle.truffle.llvm.runtime.except.LLVMPolyglotException;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMStoreNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMExpressionNode;
import com.oracle.truffle.llvm.runtime.nodes.api.LLVMNode;
import com.oracle.truffle.llvm.runtime.nodes.intrinsics.interop.LLVMPolyglotAsStringNodeGen.EncodeStringNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.intrinsics.interop.LLVMPolyglotAsStringNodeGen.WriteStringNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMGetElementPtrNodeGen.LLVMIncrementPointerNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.store.LLVMI8StoreNodeGen;
import com.oracle.truffle.llvm.runtime.pointer.LLVMManagedPointer;
import com.oracle.truffle.llvm.runtime.pointer.LLVMPointer;
@@ -111,7 +109,6 @@ ByteBuffer doBoxed(Object object, LLVMCharset charset,

abstract static class WriteStringNode extends LLVMNode {

@Child private LLVMIncrementPointerNode inc = LLVMIncrementPointerNodeGen.create();
@Child private LLVMStoreNode write = LLVMI8StoreNodeGen.create(null, null);

protected abstract long execute(VirtualFrame frame, ByteBuffer source, Object target, long targetLen, int zeroTerminatorLen);
@@ -125,15 +122,15 @@ long doWrite(ByteBuffer srcBuffer, LLVMPointer target, long targetLen, int zeroT
LLVMPointer ptr = target;
while (source.hasRemaining() && bytesWritten < targetLen) {
write.executeWithTarget(ptr, source.get());
ptr = inc.executeWithTarget(ptr, Byte.BYTES);
ptr = ptr.increment(Byte.BYTES);
bytesWritten++;
}

long ret = bytesWritten;

for (int i = 0; i < zeroTerminatorLen && bytesWritten < targetLen; i++) {
write.executeWithTarget(ptr, (byte) 0);
ptr = inc.executeWithTarget(ptr, Byte.BYTES);
ptr = ptr.increment(Byte.BYTES);
bytesWritten++;
}

Original file line number Diff line number Diff line change
@@ -47,8 +47,6 @@
import com.oracle.truffle.llvm.runtime.nodes.intrinsics.interop.LLVMPolyglotFromStringNodeGen.ReadZeroTerminatedBytesNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.intrinsics.interop.LLVMReadCharsetNode.LLVMCharset;
import com.oracle.truffle.llvm.runtime.nodes.intrinsics.llvm.LLVMIntrinsic;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMGetElementPtrNode.LLVMIncrementPointerNode;
import com.oracle.truffle.llvm.runtime.nodes.memory.LLVMGetElementPtrNodeGen.LLVMIncrementPointerNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.load.LLVMI16LoadNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.load.LLVMI32LoadNodeGen;
import com.oracle.truffle.llvm.runtime.nodes.memory.load.LLVMI64LoadNodeGen;
@@ -88,7 +86,6 @@ abstract static class ReadBytesNode extends LLVMNode {
abstract static class ReadBytesWithLengthNode extends ReadBytesNode {

@Child private LLVMLoadNode load = LLVMI8LoadNodeGen.create(null);
@Child private LLVMIncrementPointerNode inc = LLVMIncrementPointerNodeGen.create();

@Specialization
ByteBuffer doRead(@SuppressWarnings("unused") LLVMCharset charset, LLVMPointer string, long len) {
@@ -97,7 +94,7 @@ ByteBuffer doRead(@SuppressWarnings("unused") LLVMCharset charset, LLVMPointer s
LLVMPointer ptr = string;
for (int i = 0; i < len; i++) {
byte value = (byte) load.executeWithTarget(ptr);
ptr = inc.executeWithTarget(ptr, Byte.BYTES);
ptr = ptr.increment(Byte.BYTES);
buffer.put(value);
}

@@ -112,8 +109,6 @@ abstract static class ReadZeroTerminatedBytesNode extends ReadBytesNode {

@CompilationFinal int bufferSize = 8;

@Child private LLVMIncrementPointerNode inc = LLVMIncrementPointerNodeGen.create();

@Specialization(limit = "4", guards = "charset.zeroTerminatorLen == increment")
ByteBuffer doRead(@SuppressWarnings("unused") LLVMCharset charset, LLVMPointer string,
@Cached("charset.zeroTerminatorLen") int increment,
@@ -126,7 +121,7 @@ ByteBuffer doRead(@SuppressWarnings("unused") LLVMCharset charset, LLVMPointer s
Object value;
do {
value = load.executeWithTarget(ptr);
ptr = inc.executeWithTarget(ptr, increment);
ptr = ptr.increment(increment);

if (result.remaining() < increment) {
// buffer overflow, allocate a bigger buffer
Loading

0 comments on commit ab8fe0c

Please sign in to comment.