Skip to content

Commit

Permalink
[FLINK-2897] [runtime] Use distinct initial indices for OutputEmitter…
Browse files Browse the repository at this point in the history
… round-robin

This closes apache#1292
  • Loading branch information
greghogan authored and StephanEwen committed Dec 7, 2015
1 parent 868f97c commit 22ac65b
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1256,18 +1256,19 @@ public static <T> Collector<T> getOutputCollector(AbstractInvokable task, TaskCo
{
// create the OutputEmitter from output ship strategy
final ShipStrategyType strategy = config.getOutputShipStrategy(i);
final int indexInSubtaskGroup = task.getIndexInSubtaskGroup();
final TypeComparatorFactory<T> compFactory = config.getOutputComparator(i, cl);

final ChannelSelector<SerializationDelegate<T>> oe;
if (compFactory == null) {
oe = new OutputEmitter<T>(strategy);
oe = new OutputEmitter<T>(strategy, indexInSubtaskGroup);
}
else {
final DataDistribution dataDist = config.getOutputDataDistribution(i, cl);
final Partitioner<?> partitioner = config.getOutputPartitioner(i, cl);

final TypeComparator<T> comparator = compFactory.createComparator();
oe = new OutputEmitter<T>(strategy, comparator, partitioner, dataDist);
oe = new OutputEmitter<T>(strategy, indexInSubtaskGroup, comparator, partitioner, dataDist);
}

final RecordWriter<SerializationDelegate<T>> recordWriter =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T>> {

private final ShipStrategyType strategy; // the shipping strategy used by this output emitter

private int[] channels; // the reused array defining target channels

private int nextChannelToSendTo = 0; // counter to go over channels round robin

private final TypeComparator<T> comparator; // the comparator for hashing / sorting
Expand All @@ -47,16 +47,17 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
* Creates a new channel selector that distributes data round robin.
*/
public OutputEmitter() {
this(ShipStrategyType.NONE);
this(ShipStrategyType.NONE, 0);
}

/**
* Creates a new channel selector that uses the given strategy (broadcasting, partitioning, ...).
* Creates a new channel selector that uses the given strategy (broadcasting, partitioning, ...)
* and uses the supplied task index perform a round robin distribution.
*
* @param strategy The distribution strategy to be used.
*/
public OutputEmitter(ShipStrategyType strategy) {
this(strategy, null);
public OutputEmitter(ShipStrategyType strategy, int indexInSubtaskGroup) {
this(strategy, indexInSubtaskGroup, null, null, null);
}

/**
Expand All @@ -67,7 +68,7 @@ public OutputEmitter(ShipStrategyType strategy) {
* @param comparator The comparator used to hash / compare the records.
*/
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator) {
this(strategy, comparator, null, null);
this(strategy, 0, comparator, null, null);
}

/**
Expand All @@ -79,30 +80,33 @@ public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator) {
* @param distr The distribution pattern used in the case of a range partitioning.
*/
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, DataDistribution distr) {
this(strategy, comparator, null, distr);
this(strategy, 0, comparator, null, distr);
}

public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, Partitioner<?> partitioner) {
this(strategy, comparator, partitioner, null);
this(strategy, 0, comparator, partitioner, null);
}

@SuppressWarnings("unchecked")
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, Partitioner<?> partitioner, DataDistribution distr) {
public OutputEmitter(ShipStrategyType strategy, int indexInSubtaskGroup, TypeComparator<T> comparator, Partitioner<?> partitioner, DataDistribution distr) {
if (strategy == null) {
throw new NullPointerException();
}

this.strategy = strategy;
this.nextChannelToSendTo = indexInSubtaskGroup;
this.comparator = comparator;
this.partitioner = (Partitioner<Object>) partitioner;

switch (strategy) {
case PARTITION_CUSTOM:
extractedKeys = new Object[1];
case FORWARD:
case PARTITION_HASH:
case PARTITION_RANGE:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
case PARTITION_CUSTOM:
channels = new int[1];
case BROADCAST:
break;
default:
Expand All @@ -125,6 +129,7 @@ public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, Pa
public final int[] selectChannels(SerializationDelegate<T> record, int numberOfChannels) {
switch (strategy) {
case FORWARD:
return forward();
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
return robin(numberOfChannels);
Expand All @@ -143,16 +148,24 @@ public final int[] selectChannels(SerializationDelegate<T> record, int numberOfC

// --------------------------------------------------------------------------------------------

private int[] forward() {
return this.channels;
}

private int[] robin(int numberOfChannels) {
if (this.channels == null || this.channels.length != 1) {
this.channels = new int[1];
int nextChannel = this.nextChannelToSendTo;

if (nextChannel >= numberOfChannels) {
if (nextChannel == numberOfChannels) {
nextChannel = 0;
} else {
nextChannel %= numberOfChannels;
}
}

int nextChannel = nextChannelToSendTo + 1;
nextChannel = nextChannel < numberOfChannels ? nextChannel : 0;

this.nextChannelToSendTo = nextChannel;

this.channels[0] = nextChannel;
this.nextChannelToSendTo = nextChannel + 1;

return this.channels;
}

Expand All @@ -168,10 +181,6 @@ private int[] broadcast(int numberOfChannels) {
}

private int[] hashPartitionDefault(T record, int numberOfChannels) {
if (channels == null || channels.length != 1) {
channels = new int[1];
}

int hash = this.comparator.hash(record);

hash = murmurHash(hash);
Expand Down Expand Up @@ -212,11 +221,6 @@ private int[] rangePartition(T record, int numberOfChannels) {
}

private int[] customPartition(T record, int numberOfChannels) {
if (channels == null) {
channels = new int[1];
extractedKeys = new Object[1];
}

try {
if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
final Object key = extractedKeys[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,53 +151,115 @@ public void testPartitionHash() {
assertTrue(chans.length == 1);
assertTrue(chans[0] >= 0 && chans[0] <= numChans-1);
}

@Test
public void testForward() {
// Test for IntValue
@SuppressWarnings("unchecked")
final TypeComparator<Record> intComp = new RecordComparatorFactory(new int[] {0}, new Class[] {IntValue.class}).createComparator();
final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.FORWARD, intComp);
final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer());

int numChannels = 100;
int numRecords = 50000;
int numRecords = 50000 + numChannels / 2;

int[] hit = new int[numChannels];

for (int i = 0; i < numRecords; i++) {
IntValue k = new IntValue(i);
Record rec = new Record(k);
delegate.setInstance(rec);

int[] chans = oe1.selectChannels(delegate, hit.length);
for(int j=0; j < chans.length; j++) {
hit[chans[j]]++;
}
}

int cnt = 0;
for (int i = 0; i < hit.length; i++) {
assertTrue(hit[i] == (numRecords/numChannels) || hit[i] == (numRecords/numChannels)-1);
cnt += hit[i];
assertTrue(hit[0] == numRecords);
for (int i = 1; i < hit.length; i++) {
assertTrue(hit[i] == 0);
}
assertTrue(cnt == numRecords);

// Test for StringValue
@SuppressWarnings("unchecked")
final TypeComparator<Record> stringComp = new RecordComparatorFactory(new int[] {0}, new Class[] {StringValue.class}).createComparator();
final ChannelSelector<SerializationDelegate<Record>> oe2 = new OutputEmitter<Record>(ShipStrategyType.FORWARD, stringComp);

numChannels = 100;
numRecords = 10000;
numRecords = 10000 + numChannels / 2;

hit = new int[numChannels];

for (int i = 0; i < numRecords; i++) {
StringValue k = new StringValue(i + "");
Record rec = new Record(k);
delegate.setInstance(rec);


int[] chans = oe2.selectChannels(delegate, hit.length);
for(int j=0; j < chans.length; j++) {
hit[chans[j]]++;
}
}

assertTrue(hit[0] == numRecords);
for (int i = 1; i < hit.length; i++) {
assertTrue(hit[i] == 0);
}
}

@Test
public void testForcedRebalance() {
// Test for IntValue
int numChannels = 100;
int toTaskIndex = numChannels * 6/7;
int fromTaskIndex = toTaskIndex + numChannels;
int extraRecords = numChannels * 1/3;
int numRecords = 50000 + extraRecords;

final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_FORCED_REBALANCE, fromTaskIndex);
final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer());

int[] hit = new int[numChannels];

for (int i = 0; i < numRecords; i++) {
IntValue k = new IntValue(i);
Record rec = new Record(k);
delegate.setInstance(rec);

int[] chans = oe1.selectChannels(delegate, hit.length);
for(int j=0; j < chans.length; j++) {
hit[chans[j]]++;
}
}

int cnt = 0;
for (int i = 0; i < hit.length; i++) {
if (toTaskIndex <= i || i < toTaskIndex+extraRecords-numChannels) {
assertTrue(hit[i] == (numRecords/numChannels)+1);
} else {
assertTrue(hit[i] == numRecords/numChannels);
}
cnt += hit[i];
}
assertTrue(cnt == numRecords);

// Test for StringValue
numChannels = 100;
toTaskIndex = numChannels / 5;
fromTaskIndex = toTaskIndex + 2 * numChannels;
extraRecords = numChannels * 2/9;
numRecords = 10000 + extraRecords;

final ChannelSelector<SerializationDelegate<Record>> oe2 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_FORCED_REBALANCE, fromTaskIndex);

hit = new int[numChannels];

for (int i = 0; i < numRecords; i++) {
StringValue k = new StringValue(i + "");
Record rec = new Record(k);
delegate.setInstance(rec);

int[] chans = oe2.selectChannels(delegate, hit.length);
for(int j=0; j < chans.length; j++) {
hit[chans[j]]++;
Expand All @@ -206,11 +268,14 @@ public void testForward() {

cnt = 0;
for (int i = 0; i < hit.length; i++) {
assertTrue(hit[i] == (numRecords/numChannels) || hit[i] == (numRecords/numChannels)-1);
if (toTaskIndex <= i && i < toTaskIndex+extraRecords) {
assertTrue(hit[i] == (numRecords/numChannels)+1);
} else {
assertTrue(hit[i] == numRecords/numChannels);
}
cnt += hit[i];
}
assertTrue(cnt == numRecords);

}

@Test
Expand Down

0 comments on commit 22ac65b

Please sign in to comment.