Skip to content

Commit

Permalink
Add SqlMap get raw key and value blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Oct 12, 2023
1 parent 8127fd5 commit ba155d5
Show file tree
Hide file tree
Showing 42 changed files with 486 additions and 314 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ public interface MapAggregationState
default void merge(MapAggregationState other)
{
SqlMap serializedState = ((SingleMapAggregationState) other).removeTempSerializedState();
for (int i = 0; i < serializedState.getPositionCount(); i += 2) {
add(serializedState, i, serializedState, i + 1);
int rawOffset = serializedState.getRawOffset();
Block rawKeyBlock = serializedState.getRawKeyBlock();
Block rawValueBlock = serializedState.getRawValueBlock();

for (int i = 0; i < serializedState.getSize(); i++) {
add(rawKeyBlock, rawOffset + i, rawValueBlock, rawOffset + i);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.operator.aggregation;

import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.MapBlockBuilder;
import io.trino.spi.block.SqlMap;
Expand Down Expand Up @@ -40,8 +41,12 @@ public static void input(
@AggregationState({"K", "V"}) MapAggregationState state,
@SqlType("map(K,V)") SqlMap value)
{
for (int i = 0; i < value.getPositionCount(); i += 2) {
state.add(value, i, value, i + 1);
int rawOffset = value.getRawOffset();
Block rawKeyBlock = value.getRawKeyBlock();
Block rawValueBlock = value.getRawValueBlock();

for (int i = 0; i < value.getSize(); i++) {
state.add(rawKeyBlock, rawOffset + i, rawValueBlock, rawOffset + i);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ public interface HistogramState
default void merge(HistogramState other)
{
SqlMap serializedState = ((SingleHistogramState) other).removeTempSerializedState();
for (int i = 0; i < serializedState.getPositionCount(); i += 2) {
add(serializedState, i, BIGINT.getLong(serializedState, i + 1));
int rawOffset = serializedState.getRawOffset();
Block rawKeyBlock = serializedState.getRawKeyBlock();
Block rawValueBlock = serializedState.getRawValueBlock();

for (int i = 0; i < serializedState.getSize(); i++) {
add(rawKeyBlock, rawOffset + i, BIGINT.getLong(rawValueBlock, rawOffset + i));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,14 @@ private void serializeEntry(BlockBuilder keyBuilder, ArrayBlockBuilder valueBuil

protected void deserialize(int groupId, SqlMap serializedState)
{
for (int i = 0; i < serializedState.getPositionCount(); i += 2) {
int keyId = putKeyIfAbsent(groupId, serializedState, i);
Block array = new ArrayType(valueArrayBuilder.type()).getObject(serializedState, i + 1);
int rawOffset = serializedState.getRawOffset();
Block rawKeyBlock = serializedState.getRawKeyBlock();
Block rawValueBlock = serializedState.getRawValueBlock();

ArrayType arrayType = new ArrayType(valueArrayBuilder.type());
for (int i = 0; i < serializedState.getSize(); i++) {
int keyId = putKeyIfAbsent(groupId, rawKeyBlock, rawOffset + i);
Block array = arrayType.getObject(rawValueBlock, rawOffset + i);
verify(array.getPositionCount() > 0, "array is empty");
for (int arrayIndex = 0; arrayIndex < array.getPositionCount(); arrayIndex++) {
addKeyValue(keyId, array, arrayIndex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ private MapCardinalityFunction() {}
@SqlType(StandardTypes.BIGINT)
public static long mapCardinality(@SqlType("map(K,V)") SqlMap sqlMap)
{
return sqlMap.getPositionCount() / 2;
return sqlMap.getSize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.trino.metadata.SqlScalarFunction;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BufferedMapValueBuilder;
import io.trino.spi.block.SqlMap;
import io.trino.spi.function.BoundSignature;
Expand Down Expand Up @@ -117,8 +118,9 @@ public static SqlMap mapConcat(MapType mapType, BlockPositionIsDistinctFrom keys
int lastMapIndex = maps.length - 1;
int firstMapIndex = lastMapIndex;
for (int i = 0; i < maps.length; i++) {
maxEntries += maps[i].getPositionCount() / 2;
if (maps[i].getPositionCount() > 0) {
int size = maps[i].getSize();
if (size > 0) {
maxEntries += size;
lastMapIndex = i;
firstMapIndex = min(firstMapIndex, i);
}
Expand All @@ -136,30 +138,44 @@ public static SqlMap mapConcat(MapType mapType, BlockPositionIsDistinctFrom keys
BlockSet set = new BlockSet(keyType, keysDistinctOperator, keyHashCode, maxEntries);
return mapValueBuilder.build(maxEntries, (keyBuilder, valueBuilder) -> {
// the last map
Block map = maps[last];
for (int i = 0; i < map.getPositionCount(); i += 2) {
set.add(map, i);
keyType.appendTo(map, i, keyBuilder);
valueType.appendTo(map, i + 1, valueBuilder);
SqlMap map = maps[last];
int rawOffset = map.getRawOffset();
Block rawKeyBlock = map.getRawKeyBlock();
Block rawValueBlock = map.getRawValueBlock();
for (int i = 0; i < map.getSize(); i++) {
set.add(rawKeyBlock, rawOffset + i);
writeEntry(keyType, valueType, keyBuilder, valueBuilder, rawKeyBlock, rawValueBlock, rawOffset + i);
}

// the map between the last and the first
for (int idx = last - 1; idx > first; idx--) {
map = maps[idx];
for (int i = 0; i < map.getPositionCount(); i += 2) {
if (set.add(map, i)) {
keyType.appendTo(map, i, keyBuilder);
valueType.appendTo(map, i + 1, valueBuilder);
rawOffset = map.getRawOffset();
rawKeyBlock = map.getRawKeyBlock();
rawValueBlock = map.getRawValueBlock();
for (int i = 0; i < map.getSize(); i++) {
if (set.add(rawKeyBlock, rawOffset + i)) {
writeEntry(keyType, valueType, keyBuilder, valueBuilder, rawKeyBlock, rawValueBlock, rawOffset + i);
}
}
}

// the first map
map = maps[first];
for (int i = 0; i < map.getPositionCount(); i += 2) {
if (!set.contains(map, i)) {
keyType.appendTo(map, i, keyBuilder);
valueType.appendTo(map, i + 1, valueBuilder);
rawOffset = map.getRawOffset();
rawKeyBlock = map.getRawKeyBlock();
rawValueBlock = map.getRawValueBlock();
for (int i = 0; i < map.getSize(); i++) {
if (!set.contains(rawKeyBlock, rawOffset + i)) {
writeEntry(keyType, valueType, keyBuilder, valueBuilder, rawKeyBlock, rawValueBlock, rawOffset + i);
}
}
});
}

private static void writeEntry(Type keyType, Type valueType, BlockBuilder keyBuilder, BlockBuilder valueBuilder, Block rawKeyBlock, Block rawValueBlock, int rawIndex)
{
keyType.appendTo(rawKeyBlock, rawIndex, keyBuilder);
valueType.appendTo(rawValueBlock, rawIndex, valueBuilder);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,40 +102,40 @@ else if (keyType.getJavaType() == double.class) {
@UsedByGeneratedCode
public static Object elementAt(Type valueType, SqlMap sqlMap, boolean key)
{
int valuePosition = sqlMap.seekKeyExact(key);
if (valuePosition == -1) {
int index = sqlMap.seekKeyExact(key);
if (index == -1) {
return null;
}
return readNativeValue(valueType, sqlMap, valuePosition);
return readNativeValue(valueType, sqlMap.getRawValueBlock(), sqlMap.getRawOffset() + index);
}

@UsedByGeneratedCode
public static Object elementAt(Type valueType, SqlMap sqlMap, long key)
{
int valuePosition = sqlMap.seekKeyExact(key);
if (valuePosition == -1) {
int index = sqlMap.seekKeyExact(key);
if (index == -1) {
return null;
}
return readNativeValue(valueType, sqlMap, valuePosition);
return readNativeValue(valueType, sqlMap.getRawValueBlock(), sqlMap.getRawOffset() + index);
}

@UsedByGeneratedCode
public static Object elementAt(Type valueType, SqlMap sqlMap, double key)
{
int valuePosition = sqlMap.seekKeyExact(key);
if (valuePosition == -1) {
int index = sqlMap.seekKeyExact(key);
if (index == -1) {
return null;
}
return readNativeValue(valueType, sqlMap, valuePosition);
return readNativeValue(valueType, sqlMap.getRawValueBlock(), sqlMap.getRawOffset() + index);
}

@UsedByGeneratedCode
public static Object elementAt(Type valueType, SqlMap sqlMap, Object key)
{
int valuePosition = sqlMap.seekKeyExact(key);
if (valuePosition == -1) {
int index = sqlMap.seekKeyExact(key);
if (index == -1) {
return null;
}
return readNativeValue(valueType, sqlMap, valuePosition);
return readNativeValue(valueType, sqlMap.getRawValueBlock(), sqlMap.getRawOffset() + index);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,21 @@ public Block mapFromEntries(
@SqlType("map(K,V)") SqlMap sqlMap)
{
verify(rowType.getTypeParameters().size() == 2);
verify(sqlMap.getPositionCount() % 2 == 0);

Type keyType = rowType.getTypeParameters().get(0);
Type valueType = rowType.getTypeParameters().get(1);

int entryCount = sqlMap.getPositionCount() / 2;
return arrayValueBuilder.build(entryCount, valueBuilder -> {
for (int i = 0; i < entryCount; i++) {
int position = 2 * i;
int size = sqlMap.getSize();
int rawOffset = sqlMap.getRawOffset();
Block rawKeyBlock = sqlMap.getRawKeyBlock();
Block rawValueBlock = sqlMap.getRawValueBlock();

return arrayValueBuilder.build(size, valueBuilder -> {
for (int i = 0; i < size; i++) {
int offset = rawOffset + i;
((RowBlockBuilder) valueBuilder).buildEntry(fieldBuilders -> {
keyType.appendTo(sqlMap, position, fieldBuilders.get(0));
valueType.appendTo(sqlMap, position + 1, fieldBuilders.get(1));
keyType.appendTo(rawKeyBlock, offset, fieldBuilders.get(0));
valueType.appendTo(rawValueBlock, offset, fieldBuilders.get(1));
});
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@
import static io.airlift.bytecode.expression.BytecodeExpressions.and;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantNull;
import static io.airlift.bytecode.expression.BytecodeExpressions.divide;
import static io.airlift.bytecode.expression.BytecodeExpressions.lessThan;
import static io.airlift.bytecode.expression.BytecodeExpressions.notEqual;
import static io.airlift.bytecode.instruction.VariableInstruction.incrementVariable;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
Expand Down Expand Up @@ -140,8 +138,7 @@ private static MethodHandle generateFilter(MapType mapType)
body.append(mapValueBuilder.set(state.cast(BufferedMapValueBuilder.class)));

BytecodeExpression mapEntryBuilder = generateMetafactory(MapValueBuilder.class, filterKeyValue, ImmutableList.of(map, function));
BytecodeExpression entryCount = divide(map.invoke("getPositionCount", int.class), constantInt(2));
body.append(mapValueBuilder.invoke("build", SqlMap.class, entryCount, mapEntryBuilder).ret());
body.append(mapValueBuilder.invoke("build", SqlMap.class, map.invoke("getSize", int.class), mapEntryBuilder).ret());

Class<?> generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapFilterFunction.class.getClassLoader());
return methodHandle(generatedClass, "filter", Object.class, SqlMap.class, BinaryFunctionInterface.class);
Expand All @@ -167,20 +164,21 @@ private static MethodDefinition generateFilterInner(ClassDefinition definition,
Class<?> keyJavaType = Primitives.wrap(keyType.getJavaType());
Class<?> valueJavaType = Primitives.wrap(valueType.getJavaType());

Variable positionCount = scope.declareVariable(int.class, "positionCount");
Variable position = scope.declareVariable(int.class, "position");
Variable size = scope.declareVariable("size", body, map.invoke("getSize", int.class));
Variable rawOffset = scope.declareVariable("rawOffset", body, map.invoke("getRawOffset", int.class));
Variable rawKeyBlock = scope.declareVariable("rawKeyBlock", body, map.invoke("getRawKeyBlock", Block.class));
Variable rawValueBlock = scope.declareVariable("rawValueBlock", body, map.invoke("getRawValueBlock", Block.class));

Variable index = scope.declareVariable(int.class, "index");
Variable keyElement = scope.declareVariable(keyJavaType, "keyElement");
Variable valueElement = scope.declareVariable(valueJavaType, "valueElement");
Variable keep = scope.declareVariable(Boolean.class, "keep");

// invoke map.getPositionCount()
body.append(positionCount.set(map.invoke("getPositionCount", int.class)));

SqlTypeBytecodeExpression keySqlType = constantType(binder, keyType);
BytecodeNode loadKeyElement;
if (!keyType.equals(UNKNOWN)) {
// key element must be non-null
loadKeyElement = new BytecodeBlock().append(keyElement.set(keySqlType.getValue(map.cast(Block.class), position).cast(keyJavaType)));
loadKeyElement = keyElement.set(keySqlType.getValue(rawKeyBlock, add(index, rawOffset)).cast(keyJavaType));
}
else {
loadKeyElement = new BytecodeBlock().append(keyElement.set(constantNull(keyJavaType)));
Expand All @@ -190,27 +188,27 @@ private static MethodDefinition generateFilterInner(ClassDefinition definition,
BytecodeNode loadValueElement;
if (!valueType.equals(UNKNOWN)) {
loadValueElement = new IfStatement()
.condition(map.invoke("isNull", boolean.class, add(position, constantInt(1))))
.condition(rawValueBlock.invoke("isNull", boolean.class, add(index, rawOffset)))
.ifTrue(valueElement.set(constantNull(valueJavaType)))
.ifFalse(valueElement.set(valueSqlType.getValue(map.cast(Block.class), add(position, constantInt(1))).cast(valueJavaType)));
.ifFalse(valueElement.set(valueSqlType.getValue(rawValueBlock, add(index, rawOffset)).cast(valueJavaType)));
}
else {
loadValueElement = new BytecodeBlock().append(valueElement.set(constantNull(valueJavaType)));
}

body.append(new ForLoop()
.initialize(position.set(constantInt(0)))
.condition(lessThan(position, positionCount))
.update(incrementVariable(position, (byte) 2))
.initialize(index.set(constantInt(0)))
.condition(lessThan(index, size))
.update(index.increment())
.body(new BytecodeBlock()
.append(loadKeyElement)
.append(loadValueElement)
.append(keep.set(function.invoke("apply", Object.class, keyElement.cast(Object.class), valueElement.cast(Object.class)).cast(Boolean.class)))
.append(new IfStatement("if (keep != null && keep) ...")
.condition(and(notEqual(keep, constantNull(Boolean.class)), keep.cast(boolean.class)))
.ifTrue(new BytecodeBlock()
.append(keySqlType.invoke("appendTo", void.class, map.cast(Block.class), position, keyBuilder))
.append(valueSqlType.invoke("appendTo", void.class, map.cast(Block.class), add(position, constantInt(1)), valueBuilder))))));
.append(keySqlType.invoke("appendTo", void.class, rawKeyBlock, add(index, rawOffset), keyBuilder))
.append(valueSqlType.invoke("appendTo", void.class, rawValueBlock, add(index, rawOffset), valueBuilder))))));
body.ret();

return method;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino.operator.scalar;

import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.SqlMap;
import io.trino.spi.function.Description;
import io.trino.spi.function.ScalarFunction;
Expand All @@ -35,10 +34,6 @@ public static Block getKeys(
@TypeParameter("K") Type keyType,
@SqlType("map(K,V)") SqlMap sqlMap)
{
BlockBuilder blockBuilder = keyType.createBlockBuilder(null, sqlMap.getPositionCount() / 2);
for (int i = 0; i < sqlMap.getPositionCount(); i += 2) {
keyType.appendTo(sqlMap, i, blockBuilder);
}
return blockBuilder.build();
return sqlMap.getRawKeyBlock().getRegion(sqlMap.getRawOffset(), sqlMap.getSize());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,41 +108,41 @@ else if (keyType.getJavaType() == double.class) {
@UsedByGeneratedCode
public static Object subscript(MissingKeyExceptionFactory missingKeyExceptionFactory, Type keyType, Type valueType, ConnectorSession session, SqlMap sqlMap, boolean key)
{
int valuePosition = sqlMap.seekKeyExact(key);
if (valuePosition == -1) {
int index = sqlMap.seekKeyExact(key);
if (index == -1) {
throw missingKeyExceptionFactory.create(session, key);
}
return readNativeValue(valueType, sqlMap, valuePosition);
return readNativeValue(valueType, sqlMap.getRawValueBlock(), sqlMap.getRawOffset() + index);
}

@UsedByGeneratedCode
public static Object subscript(MissingKeyExceptionFactory missingKeyExceptionFactory, Type keyType, Type valueType, ConnectorSession session, SqlMap sqlMap, long key)
{
int valuePosition = sqlMap.seekKeyExact(key);
if (valuePosition == -1) {
int index = sqlMap.seekKeyExact(key);
if (index == -1) {
throw missingKeyExceptionFactory.create(session, key);
}
return readNativeValue(valueType, sqlMap, valuePosition);
return readNativeValue(valueType, sqlMap.getRawValueBlock(), sqlMap.getRawOffset() + index);
}

@UsedByGeneratedCode
public static Object subscript(MissingKeyExceptionFactory missingKeyExceptionFactory, Type keyType, Type valueType, ConnectorSession session, SqlMap sqlMap, double key)
{
int valuePosition = sqlMap.seekKeyExact(key);
if (valuePosition == -1) {
int index = sqlMap.seekKeyExact(key);
if (index == -1) {
throw missingKeyExceptionFactory.create(session, key);
}
return readNativeValue(valueType, sqlMap, valuePosition);
return readNativeValue(valueType, sqlMap.getRawValueBlock(), sqlMap.getRawOffset() + index);
}

@UsedByGeneratedCode
public static Object subscript(MissingKeyExceptionFactory missingKeyExceptionFactory, Type keyType, Type valueType, ConnectorSession session, SqlMap sqlMap, Object key)
{
int valuePosition = sqlMap.seekKeyExact(key);
if (valuePosition == -1) {
int index = sqlMap.seekKeyExact(key);
if (index == -1) {
throw missingKeyExceptionFactory.create(session, key);
}
return readNativeValue(valueType, sqlMap, valuePosition);
return readNativeValue(valueType, sqlMap.getRawValueBlock(), sqlMap.getRawOffset() + index);
}

private static class MissingKeyExceptionFactory
Expand Down
Loading

0 comments on commit ba155d5

Please sign in to comment.