Skip to content

Commit

Permalink
Fix outer index joins when multiple overlapping keys are supplied
Browse files Browse the repository at this point in the history
  • Loading branch information
erichwang committed Jun 18, 2014
1 parent 3aa8f76 commit c636171
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.index;

import com.facebook.presto.spi.RecordCursor;
import com.facebook.presto.spi.RecordSet;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;

import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

/**
* Only retains rows that have identical values for each respective fieldSet.
*/
public class FieldSetFilteringRecordSet
implements RecordSet
{
private final RecordSet delegate;
private final List<Set<Integer>> fieldSets;

public FieldSetFilteringRecordSet(RecordSet delegate, List<Set<Integer>> fieldSets)
{
this.delegate = checkNotNull(delegate, "delegate is null");
this.fieldSets = ImmutableList.copyOf(checkNotNull(fieldSets, "fieldSets is null"));
}

@Override
public List<Type> getColumnTypes()
{
return delegate.getColumnTypes();
}

@Override
public RecordCursor cursor()
{
return new FieldSetFilteringRecordCursor(delegate.cursor(), fieldSets);
}

private static class FieldSetFilteringRecordCursor
implements RecordCursor
{
private final RecordCursor delegate;
private final List<Set<Integer>> fieldSets;

private FieldSetFilteringRecordCursor(RecordCursor delegate, List<Set<Integer>> fieldSets)
{
this.delegate = delegate;
this.fieldSets = fieldSets;
}

@Override
public long getTotalBytes()
{
return delegate.getTotalBytes();
}

@Override
public long getCompletedBytes()
{
return delegate.getCompletedBytes();
}

@Override
public long getReadTimeNanos()
{
return delegate.getReadTimeNanos();
}

@Override
public Type getType(int field)
{
return delegate.getType(field);
}

@Override
public boolean advanceNextPosition()
{
while (delegate.advanceNextPosition()) {
if (fieldSetsEqual(delegate, fieldSets)) {
return true;
}
}
return false;
}

private static boolean fieldSetsEqual(RecordCursor cursor, List<Set<Integer>> fieldSets)
{
for (Set<Integer> fieldSet : fieldSets) {
if (!fieldsEquals(cursor, fieldSet)) {
return false;
}
}
return true;
}

private static boolean fieldsEquals(RecordCursor cursor, Set<Integer> fields)
{
if (fields.size() < 2) {
return true; // Nothing to compare
}
Iterator<Integer> fieldIterator = fields.iterator();
int firstField = fieldIterator.next();
while (fieldIterator.hasNext()) {
if (!fieldEquals(cursor, firstField, fieldIterator.next())) {
return false;
}
}
return true;
}

private static boolean fieldEquals(RecordCursor cursor, int field1, int field2)
{
checkArgument(cursor.getType(field1).equals(cursor.getType(field2)), "Should only be comparing fields of the same type");

if (cursor.isNull(field1) || cursor.isNull(field2)) {
return false;
}

Class<?> javaType = cursor.getType(field1).getJavaType();
if (javaType == long.class) {
return cursor.getLong(field1) == cursor.getLong(field2);
}
else if (javaType == double.class) {
return cursor.getDouble(field1) == cursor.getDouble(field2);
}
else if (javaType == boolean.class) {
return cursor.getBoolean(field1) == cursor.getBoolean(field2);
}
else if (javaType == Slice.class) {
return cursor.getSlice(field1).equals(cursor.getSlice(field2));
}
else {
throw new IllegalArgumentException("Unknown java type: " + javaType);
}
}

@Override
public boolean getBoolean(int field)
{
return delegate.getBoolean(field);
}

@Override
public long getLong(int field)
{
return delegate.getLong(field);
}

@Override
public double getDouble(int field)
{
return delegate.getDouble(field);
}

@Override
public Slice getSlice(int field)
{
return delegate.getSlice(field);
}

@Override
public boolean isNull(int field)
{
return delegate.isNull(field);
}

@Override
public void close()
{
delegate.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import com.facebook.presto.spi.Index;
import com.facebook.presto.spi.RecordSet;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.split.MappedRecordSet;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.base.Function;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
Expand All @@ -49,21 +49,21 @@ public static class IndexSourceOperatorFactory
private final PlanNodeId sourceId;
private final Index index;
private final List<Type> types;
private final List<Integer> probeKeyRemap;
private final Function<RecordSet, RecordSet> probeKeyNormalizer;
private boolean closed;

public IndexSourceOperatorFactory(
int operatorId,
PlanNodeId sourceId,
Index index,
List<Type> types,
List<Integer> probeKeyRemap)
Function<RecordSet, RecordSet> probeKeyNormalizer)
{
this.operatorId = operatorId;
this.sourceId = checkNotNull(sourceId, "sourceId is null");
this.index = checkNotNull(index, "index is null");
this.types = checkNotNull(types, "types is null");
this.probeKeyRemap = ImmutableList.copyOf(checkNotNull(probeKeyRemap, "probeKeyRemap is null"));
this.probeKeyNormalizer = checkNotNull(probeKeyNormalizer, "probeKeyNormalizer is null");
}

@Override
Expand All @@ -88,7 +88,7 @@ public SourceOperator createOperator(DriverContext driverContext)
sourceId,
index,
types,
probeKeyRemap);
probeKeyNormalizer);
}

@Override
Expand All @@ -102,7 +102,7 @@ public void close()
private final PlanNodeId planNodeId;
private final Index index;
private final List<Type> types;
private final List<Integer> probeKeyRemap;
private final Function<RecordSet, RecordSet> probeKeyNormalizer;

@GuardedBy("this")
private Operator source;
Expand All @@ -112,13 +112,13 @@ public IndexSourceOperator(
PlanNodeId planNodeId,
Index index,
List<Type> types,
List<Integer> probeKeyRemap)
Function<RecordSet, RecordSet> probeKeyNormalizer)
{
this.operatorContext = checkNotNull(operatorContext, "operatorContext is null");
this.planNodeId = checkNotNull(planNodeId, "planNodeId is null");
this.index = checkNotNull(index, "index is null");
this.types = ImmutableList.copyOf(checkNotNull(types, "types is null"));
this.probeKeyRemap = ImmutableList.copyOf(checkNotNull(probeKeyRemap, "probeKeyRemap is null"));
this.probeKeyNormalizer = checkNotNull(probeKeyNormalizer, "probeKeyNormalizer is null");
}

@Override
Expand All @@ -142,9 +142,9 @@ public synchronized void addSplit(Split split)

IndexSplit indexSplit = (IndexSplit) split.getConnectorSplit();

// Remap the record set into the format the index is expecting
RecordSet recordSet = new MappedRecordSet(indexSplit.getKeyRecordSet(), probeKeyRemap);
RecordSet result = index.lookup(recordSet);
// Normalize the incoming RecordSet to something that can be consumed by the index
RecordSet normalizedRecordSet = probeKeyNormalizer.apply(indexSplit.getKeyRecordSet());
RecordSet result = index.lookup(normalizedRecordSet);
source = new RecordProjectOperator(operatorContext, result);

operatorContext.setInfoSupplier(Suppliers.ofInstance(split.getInfo()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,18 @@
import com.facebook.presto.operator.ValuesOperator.ValuesOperatorFactory;
import com.facebook.presto.operator.WindowFunctionDefinition;
import com.facebook.presto.operator.WindowOperator.WindowOperatorFactory;
import com.facebook.presto.operator.index.FieldSetFilteringRecordSet;
import com.facebook.presto.operator.index.IndexLookupSourceSupplier;
import com.facebook.presto.operator.index.IndexSourceOperator;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.Index;
import com.facebook.presto.spi.RecordSet;
import com.facebook.presto.spi.RecordSink;
import com.facebook.presto.spi.block.BlockCursor;
import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.split.DataStreamProvider;
import com.facebook.presto.split.MappedRecordSet;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer;
import com.facebook.presto.sql.planner.plan.AggregationNode;
Expand Down Expand Up @@ -849,21 +852,41 @@ public PhysicalOperation visitIndexSource(IndexSourceNode node, LocalExecutionPl
List<Symbol> lookupSymbolSchema = ImmutableList.copyOf(node.getLookupSymbols());

// Identify how to remap the probe key Input to match the source index lookup layout
List<Integer> remappedProbeKeyChannels = new ArrayList<>();
ImmutableList.Builder<Integer> remappedProbeKeyChannelsBuilder = ImmutableList.builder();
// Identify overlapping fields that can produce the same lookup symbol.
// We will filter incoming keys to ensure that overlapping fields will have the same value.
ImmutableList.Builder<Set<Integer>> overlappingFieldSetsBuilder = ImmutableList.builder();
for (Symbol lookupSymbol : lookupSymbolSchema) {
// TODO: add additional optimization when there are multiple mappings for one lookup symbol (e.g. index key filtering)
// Currently just pick the first field that can supply this symbol
Input probeInput = Iterables.getFirst(indexLookupToProbeInput.get(lookupSymbol), null);
remappedProbeKeyChannels.add(probeInput.getChannel());
Set<Input> potentialProbeInputs = indexLookupToProbeInput.get(lookupSymbol);
checkState(!potentialProbeInputs.isEmpty(), "Must have at least one source from the probe input");
if (potentialProbeInputs.size() > 1) {
overlappingFieldSetsBuilder.add(FluentIterable.from(potentialProbeInputs)
.transform(Input.channelGetter())
.toSet());
}
remappedProbeKeyChannelsBuilder.add(Iterables.getFirst(potentialProbeInputs, null).getChannel());
}
final List<Set<Integer>> overlappingFieldSets = overlappingFieldSetsBuilder.build();
final List<Integer> remappedProbeKeyChannels = remappedProbeKeyChannelsBuilder.build();
Function<RecordSet, RecordSet> probeKeyNormalizer = new Function<RecordSet, RecordSet>()
{
@Override
public RecordSet apply(RecordSet recordSet)
{
if (!overlappingFieldSets.isEmpty()) {
recordSet = new FieldSetFilteringRecordSet(recordSet, overlappingFieldSets);
}
return new MappedRecordSet(recordSet, remappedProbeKeyChannels);
}
};

// Declare the input and output schemas for the index and acquire the actual Index
List<ColumnHandle> lookupSchema = Lists.transform(lookupSymbolSchema, Functions.forMap(node.getAssignments()));
List<ColumnHandle> outputSchema = Lists.transform(node.getOutputSymbols(), Functions.forMap(node.getAssignments()));
Index index = indexManager.getIndex(node.getIndexHandle(), lookupSchema, outputSchema);

List<Type> types = getSourceOperatorTypes(node, context.getTypes());
OperatorFactory operatorFactory = new IndexSourceOperator.IndexSourceOperatorFactory(context.getNextOperatorId(), node.getId(), index, types, remappedProbeKeyChannels);
OperatorFactory operatorFactory = new IndexSourceOperator.IndexSourceOperatorFactory(context.getNextOperatorId(), node.getId(), index, types, probeKeyNormalizer);
return new PhysicalOperation(operatorFactory, outputMappings.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,20 @@ public void testOverlappingIndexJoinLookupSymbol()
" ON l.orderkey % 1024 = o.orderkey AND l.partkey % 1024 = o.orderkey");
}

@Test
public void testOverlappingSourceOuterIndexJoinLookupSymbol()
        throws Exception
{
    // LEFT JOIN in which two different probe-side expressions
    // (l.orderkey % 1024 and l.partkey % 1024) are both equated to o.orderkey.
    final String sql = "SELECT *\n" +
            "FROM (\n" +
            " SELECT *\n" +
            " FROM lineitem\n" +
            " WHERE partkey % 8 = 0) l\n" +
            "LEFT JOIN orders o\n" +
            " ON l.orderkey % 1024 = o.orderkey AND l.partkey % 1024 = o.orderkey";
    assertQuery(sql);
}

@Test
public void testOverlappingIndexJoinProbeSymbol()
throws Exception
Expand All @@ -327,6 +341,20 @@ public void testOverlappingIndexJoinProbeSymbol()
" ON l.orderkey = o.orderkey AND l.orderkey = o.custkey");
}

@Test
public void testOverlappingSourceOuterIndexJoinProbeSymbol()
        throws Exception
{
    // LEFT JOIN in which a single probe-side column (l.orderkey) is equated
    // to two different columns of orders (o.orderkey and o.custkey).
    final String sql = "SELECT *\n" +
            "FROM (\n" +
            " SELECT *\n" +
            " FROM lineitem\n" +
            " WHERE partkey % 8 = 0) l\n" +
            "LEFT JOIN orders o\n" +
            " ON l.orderkey = o.orderkey AND l.orderkey = o.custkey";
    assertQuery(sql);
}

@Test
public void testRepeatedIndexJoinClause()
throws Exception
Expand Down

0 comments on commit c636171

Please sign in to comment.