Skip to content

Commit

Permalink
[FLINK-3234] [dataSet] Add KeySelector support to sortPartition opera…
Browse files Browse the repository at this point in the history
…tion.

This closes apache#1585
  • Loading branch information
chiwanpark authored and fhueske committed Feb 10, 2016
1 parent 572855d commit 0a63797
Show file tree
Hide file tree
Showing 7 changed files with 385 additions and 56 deletions.
18 changes: 18 additions & 0 deletions flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,24 @@ public SortPartitionOperator<T> sortPartition(String field, Order order) {
return new SortPartitionOperator<>(this, field, order, Utils.getCallLocationName());
}

/**
* Locally sorts the partitions of the DataSet on the extracted key in the specified order.
* The DataSet can be sorted on multiple values by returning a tuple from the KeySelector.
*
* Note that no additional sort keys can be appended to a KeySelector sort keys. To sort
* the partitions by multiple values using KeySelector, the KeySelector must return a tuple
* consisting of the values.
*
* @param keyExtractor The KeySelector function which extracts the key values from the DataSet
* on which the DataSet is sorted.
* @param order The order in which the DataSet is sorted.
* @return The DataSet with sorted local partitions.
*/
public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor, Order order) {
final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, getType());
return new SortPartitionOperator<>(this, new Keys.SelectorFunctionKeys<>(clean(keyExtractor), getType(), keyType), order, Utils.getCallLocationName());
}

// --------------------------------------------------------------------------------------------
// Top-K
// --------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.SortPartitionOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;

/**
* This operator represents a DataSet with locally sorted partitions.
Expand All @@ -38,27 +42,58 @@
@Public
public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPartitionOperator<T>> {

private int[] sortKeyPositions;
private List<Keys<T>> keys;

private Order[] sortOrders;
private List<Order> orders;

private final String sortLocationName;

private boolean useKeySelector;

public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String sortLocationName) {
private SortPartitionOperator(DataSet<T> dataSet, String sortLocationName) {
super(dataSet, dataSet.getType());

keys = new ArrayList<>();
orders = new ArrayList<>();
this.sortLocationName = sortLocationName;
}


public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String sortLocationName) {
this(dataSet, sortLocationName);
this.useKeySelector = false;

ensureSortableKey(sortField);

int[] flatOrderKeys = getFlatFields(sortField);
this.appendSorting(flatOrderKeys, sortOrder);
keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
orders.add(sortOrder);
}

public SortPartitionOperator(DataSet<T> dataSet, String sortField, Order sortOrder, String sortLocationName) {
super(dataSet, dataSet.getType());
this.sortLocationName = sortLocationName;
this(dataSet, sortLocationName);
this.useKeySelector = false;

ensureSortableKey(sortField);

keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
orders.add(sortOrder);
}

public <K> SortPartitionOperator(DataSet<T> dataSet, Keys.SelectorFunctionKeys<T, K> sortKey, Order sortOrder, String sortLocationName) {
this(dataSet, sortLocationName);
this.useKeySelector = true;

ensureSortableKey(sortKey);

int[] flatOrderKeys = getFlatFields(sortField);
this.appendSorting(flatOrderKeys, sortOrder);
keys.add(sortKey);
orders.add(sortOrder);
}

/**
* Returns whether using key selector or not.
*/
public boolean useKeySelector() {
return useKeySelector;
}

/**
Expand All @@ -70,9 +105,14 @@ public SortPartitionOperator(DataSet<T> dataSet, String sortField, Order sortOrd
* @return The DataSet with sorted local partitions.
*/
public SortPartitionOperator<T> sortPartition(int field, Order order) {
if (useKeySelector) {
throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector");
}

ensureSortableKey(field);
keys.add(new Keys.ExpressionKeys<>(field, getType()));
orders.add(order);

int[] flatOrderKeys = getFlatFields(field);
this.appendSorting(flatOrderKeys, order);
return this;
}

Expand All @@ -81,58 +121,41 @@ public SortPartitionOperator<T> sortPartition(int field, Order order) {
* local partition sorting of the DataSet.
*
* @param field The field expression referring to the field of the additional sort order of
* the local partition sorting.
* @param order The order of the additional sort order of the local partition sorting.
* the local partition sorting.
* @param order The order of the additional sort order of the local partition sorting.
* @return The DataSet with sorted local partitions.
*/
public SortPartitionOperator<T> sortPartition(String field, Order order) {
int[] flatOrderKeys = getFlatFields(field);
this.appendSorting(flatOrderKeys, order);
if (useKeySelector) {
throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector");
}

ensureSortableKey(field);
keys.add(new Keys.ExpressionKeys<>(field, getType()));
orders.add(order);

return this;
}

// --------------------------------------------------------------------------------------------
// Key Extraction
// --------------------------------------------------------------------------------------------

private int[] getFlatFields(int field) {
public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor, Order order) {
throw new InvalidProgramException("KeySelector cannot be chained.");
}

if (!Keys.ExpressionKeys.isSortKey(field, super.getType())) {
private void ensureSortableKey(int field) throws InvalidProgramException {
if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
throw new InvalidProgramException("Selected sort key is not a sortable type");
}

Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(field, super.getType());
return ek.computeLogicalKeyPositions();
}

private int[] getFlatFields(String fields) {

if (!Keys.ExpressionKeys.isSortKey(fields, super.getType())) {
private void ensureSortableKey(String field) throws InvalidProgramException {
if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
throw new InvalidProgramException("Selected sort key is not a sortable type");
}

Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(fields, super.getType());
return ek.computeLogicalKeyPositions();
}

private void appendSorting(int[] flatOrderFields, Order order) {

if(this.sortKeyPositions == null) {
// set sorting info
this.sortKeyPositions = flatOrderFields;
this.sortOrders = new Order[flatOrderFields.length];
Arrays.fill(this.sortOrders, order);
} else {
// append sorting info to exising info
int oldLength = this.sortKeyPositions.length;
int newLength = oldLength + flatOrderFields.length;
this.sortKeyPositions = Arrays.copyOf(this.sortKeyPositions, newLength);
this.sortOrders = Arrays.copyOf(this.sortOrders, newLength);

for(int i=0; i<flatOrderFields.length; i++) {
this.sortKeyPositions[oldLength+i] = flatOrderFields[i];
this.sortOrders[oldLength+i] = order;
}
private <K> void ensureSortableKey(Keys.SelectorFunctionKeys<T, K> sortKey) {
if (!sortKey.getKeyType().isSortKeyType()) {
throw new InvalidProgramException("Selected sort key is not a sortable type");
}
}

Expand All @@ -144,16 +167,33 @@ private void appendSorting(int[] flatOrderFields, Order order) {

String name = "Sort at " + sortLocationName;

if (useKeySelector) {
return translateToDataFlowWithKeyExtractor(input, (Keys.SelectorFunctionKeys<T, ?>) keys.get(0), orders.get(0), name);
}

// flatten sort key positions
List<Integer> allKeyPositions = new ArrayList<>();
List<Order> allOrders = new ArrayList<>();
for (int i = 0, length = keys.size(); i < length; i++) {
int[] sortKeyPositions = keys.get(i).computeLogicalKeyPositions();
Order order = orders.get(i);

for (int sortKeyPosition : sortKeyPositions) {
allKeyPositions.add(sortKeyPosition);
allOrders.add(order);
}
}

Ordering partitionOrdering = new Ordering();
for (int i = 0; i < this.sortKeyPositions.length; i++) {
partitionOrdering.appendOrdering(this.sortKeyPositions[i], null, this.sortOrders[i]);
for (int i = 0, length = allKeyPositions.size(); i < length; i++) {
partitionOrdering.appendOrdering(allKeyPositions.get(i), null, allOrders.get(i));
}

// distinguish between partition types
UnaryOperatorInformation<T, T> operatorInfo = new UnaryOperatorInformation<>(getType(), getType());
SortPartitionOperatorBase<T> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
SortPartitionOperatorBase<T> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
noop.setInput(input);
if(this.getParallelism() < 0) {
if (this.getParallelism() < 0) {
// use parallelism of input if not explicitly specified
noop.setParallelism(input.getParallelism());
} else {
Expand All @@ -165,4 +205,32 @@ private void appendSorting(int[] flatOrderFields, Order order) {

}

private <K> org.apache.flink.api.common.operators.SingleInputOperator<?, T, ?> translateToDataFlowWithKeyExtractor(
Operator<T> input, Keys.SelectorFunctionKeys<T, K> keys, Order order, String name) {
TypeInformation<Tuple2<K, T>> typeInfoWithKey = KeyFunctions.createTypeWithKey(keys);
Keys.ExpressionKeys<Tuple2<K, T>> newKey = new Keys.ExpressionKeys<>(0, typeInfoWithKey);

Operator<Tuple2<K, T>> keyedInput = KeyFunctions.appendKeyExtractor(input, keys);

int[] sortKeyPositions = newKey.computeLogicalKeyPositions();
Ordering partitionOrdering = new Ordering();
for (int keyPosition : sortKeyPositions) {
partitionOrdering.appendOrdering(keyPosition, null, order);
}

// distinguish between partition types
UnaryOperatorInformation<Tuple2<K, T>, Tuple2<K, T>> operatorInfo = new UnaryOperatorInformation<>(typeInfoWithKey, typeInfoWithKey);
SortPartitionOperatorBase<Tuple2<K, T>> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
noop.setInput(keyedInput);
if (this.getParallelism() < 0) {
// use parallelism of input if not explicitly specified
noop.setParallelism(input.getParallelism());
} else {
// use explicitly specified parallelism
noop.setParallelism(this.getParallelism());
}

return KeyFunctions.appendKeyRemover(noop, keys);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,88 @@ public void testSortPartitionWithExpressionKeys4() {
tupleDs.sortPartition("f3", Order.ASCENDING);
}

@Test
public void testSortPartitionWithKeySelector1() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);

// should work
try {
tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, Integer>() {
@Override
public Integer getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return value.f0;
}
}, Order.ASCENDING);
} catch (Exception e) {
Assert.fail();
}
}

@Test(expected = InvalidProgramException.class)
public void testSortPartitionWithKeySelector2() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);

// must not work
tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, Long[]>() {
@Override
public Long[] getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return value.f3;
}
}, Order.ASCENDING);
}

@Test(expected = InvalidProgramException.class)
public void testSortPartitionWithKeySelector3() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);

// must not work
tupleDs
.sortPartition("f1", Order.ASCENDING)
.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>() {
@Override
public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return value.f2;
}
}, Order.ASCENDING);
}

@Test
public void testSortPartitionWithKeySelector4() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);

// should work
try {
tupleDs.sortPartition(new KeySelector<Tuple4<Integer,Long,CustomType,Long[]>, Tuple2<Integer, Long>>() {
@Override
public Tuple2<Integer, Long> getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return new Tuple2<>(value.f0, value.f1);
}
}, Order.ASCENDING);
} catch (Exception e) {
Assert.fail();
}
}

@Test(expected = InvalidProgramException.class)
public void testSortPartitionWithKeySelector5() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);

// must not work
tupleDs
.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>() {
@Override
public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return value.f2;
}
}, Order.ASCENDING)
.sortPartition("f1", Order.ASCENDING);
}

public static class CustomType implements Serializable {

public static class Nest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,31 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
new SortPartitionOperator[T](javaSet, field, order, getCallLocationName()))
}

/**
* Locally sorts the partitions of the DataSet on the extracted key in the specified order.
* The DataSet can be sorted on multiple values by returning a tuple from the KeySelector.
*
* Note that no additional sort keys can be appended to a KeySelector sort keys. To sort
* the partitions by multiple values using KeySelector, the KeySelector must return a tuple
* consisting of the values.
*/
def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] ={
val keyExtractor = new KeySelector[T, K] {
val cleanFun = clean(fun)
def getKey(in: T) = cleanFun(in)
}

val keyType = implicitly[TypeInformation[K]]
new PartitionSortedDataSet[T](
new SortPartitionOperator[T](javaSet,
new Keys.SelectorFunctionKeys[T, K](
keyExtractor,
javaSet.getType,
keyType),
order,
getCallLocationName()))
}

// --------------------------------------------------------------------------------------------
// Result writing
// --------------------------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit 0a63797

Please sign in to comment.