Skip to content

Commit

Permalink
Added support for extra inputs to RegularPactTask and TaskConfig.
Browse files Browse the repository at this point in the history
  • Loading branch information
aalexandrov authored and StephanEwen committed Feb 13, 2014
1 parent f45357c commit 3992fc6
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -99,7 +100,7 @@ public class RegularPactTask<S extends Function, OT> extends AbstractTask implem
* The instantiated user code of this task's main driver.
*/
protected S stub;

/**
* The collector that forwards the user code's results. May forward to a channel or to chained drivers within
* this task.
Expand All @@ -117,11 +118,22 @@ public class RegularPactTask<S extends Function, OT> extends AbstractTask implem
*/
protected MutableReader<?>[] inputReaders;

/**
 * The input readers for the configured broadcast variables for this task, one entry per
 * broadcast input. The array is assigned by {@code initBroadcastInputReaders()}; an entry is a
 * union reader when the configured group size of that broadcast input is greater than one.
 */
protected MutableReader<?>[] broadcastInputReaders;

/**
 * The input readers, each wrapped in an iterator. These are the inputs prior to the local
 * strategies, etc. The array is allocated by {@code initInputsSerializersAndComparators(int)}
 * with one slot per input.
 */
protected MutableObjectIterator<?>[] inputIterators;

/**
 * The input readers for the configured broadcast variables, each wrapped in an iterator.
 * These are the broadcast inputs prior to the local strategies, etc. The array is filled by
 * {@code initBroadcastInputsSerializers(int)} and read in {@code run()}, where every broadcast
 * input is drained into a collection.
 */
protected MutableObjectIterator<?>[] broadcastInputIterators;

/**
* The local strategies that are applied on the inputs.
*/
Expand All @@ -143,12 +155,18 @@ public class RegularPactTask<S extends Function, OT> extends AbstractTask implem
* and the temp-table barrier.
*/
protected MutableObjectIterator<?>[] inputs;


/**
 * The serializers for the input data types, one slot per input. The array is allocated by
 * {@code initInputsSerializersAndComparators(int)}.
 */
protected TypeSerializer<?>[] inputSerializers;

/**
 * The serializers for the broadcast input data types, one slot per broadcast input. The array is
 * filled by {@code initBroadcastInputsSerializers(int)}; {@code run()} uses the serializers to
 * create the record instances that the broadcast inputs are read into.
 */
protected TypeSerializer<?>[] broadcastInputSerializers;

/**
* The comparators for the central driver.
*/
Expand Down Expand Up @@ -234,6 +252,7 @@ public void registerInputOutput() {
// however, this does not trigger any local processing.
try {
initInputReaders();
initBroadcastInputReaders();
} catch (Exception e) {
// The conditional must be parenthesized: '+' binds tighter than '==' and '?:', so without the
// parentheses the prefix is concatenated into the null check (which then is never true) and
// the resulting message loses the "Initializing the input streams failed" part.
throw new RuntimeException("Initializing the input streams failed" +
    (e.getMessage() == null ? "." : ": " + e.getMessage()), e);
Expand Down Expand Up @@ -274,7 +293,9 @@ public void invoke() throws Exception {
// the local processing includes building the dams / caches
try {
int numInputs = driver.getNumberOfInputs();
int numBroadcastInputs = this.config.getNumBroadcastInputs();
initInputsSerializersAndComparators(numInputs);
initBroadcastInputsSerializers(numBroadcastInputs);
initLocalStrategies(numInputs);
} catch (Exception e) {
throw new RuntimeException("Initializing the input processing failed" +
Expand Down Expand Up @@ -384,6 +405,23 @@ protected void run() throws Exception {
return;
}

// read every broadcast input to its end and hand the collected records to the stub's runtime
// context under the configured broadcast variable name
for (int bcInput = 0, bcInputCount = this.config.getNumBroadcastInputs(); bcInput < bcInputCount; bcInput++) {
    final String variableName = this.config.getBroadcastInputName(bcInput);
    @SuppressWarnings("unchecked")
    final MutableObjectIterator<Object> bcSource = (MutableObjectIterator<Object>) this.broadcastInputIterators[bcInput];
    @SuppressWarnings("unchecked")
    final TypeSerializer<Object> bcSerializer = (TypeSerializer<Object>) this.broadcastInputSerializers[bcInput];

    final Collection<Object> materialized = new ArrayList<Object>();
    // each record is read into a fresh instance, because the instance is kept in the collection;
    // reading stops as soon as the input is exhausted or the task is no longer running
    for (Object target = bcSerializer.createInstance();
            this.running && bcSource.next(target);
            target = bcSerializer.createInstance())
    {
        materialized.add(target);
    }
    this.stub.getRuntimeContext().setBroadcastVariable(variableName, materialized);
}

// start all chained tasks
RegularPactTask.openChainedTasks(this.chainedTasks, this);

Expand Down Expand Up @@ -417,9 +455,9 @@ protected void run() throws Exception {
// JobManager. close() has been called earlier for all involved UDFs
// (using this.stub.close() and closeChainedTasks()), so UDFs can no longer
// modify accumulators.
if (stub != null) {
if (this.stub != null) {
// collect the counters from the stub
Map<String, Accumulator<?,?>> accumulators = stub.getRuntimeContext().getAllAccumulators();
Map<String, Accumulator<?,?>> accumulators = this.stub.getRuntimeContext().getAllAccumulators();
RegularPactTask.reportAndClearAccumulators(getEnvironment(), accumulators, this.chainedTasks);
}
}
Expand Down Expand Up @@ -621,14 +659,45 @@ protected void initInputReaders() throws Exception {
}
}

/**
 * Sets up one reader per extra broadcast input, where the number of broadcast inputs is taken from
 * {@link TaskConfig#getNumBroadcastInputs()}. A broadcast input with a group size of one gets a plain
 * record reader; a broadcast input with a larger group size gets a union reader over that many
 * record readers. The resulting array is stored in {@code this.broadcastInputReaders}.
 *
 * This method requires that the task configuration, the driver, and the user-code class loader are set.
 *
 * @throws Exception Thrown, for example, if the configured group size of a broadcast input is smaller than one.
 */
@SuppressWarnings("unchecked")
protected void initBroadcastInputReaders() throws Exception {
    final MutableReader<?>[] readersPerInput = new MutableReader[this.config.getNumBroadcastInputs()];

    for (int inputNum = 0; inputNum < this.config.getNumBroadcastInputs(); inputNum++) {
        final int unionSize = this.config.getBroadcastGroupSize(inputNum);

        // reject anything that is neither a single input nor a union
        if (unionSize < 1) {
            throw new Exception("Illegal input group size in task configuration: " + unionSize);
        }

        // a single physical input needs no union reader
        if (unionSize == 1) {
            readersPerInput[inputNum] = new MutableRecordReader<IOReadableWritable>(this);
            continue;
        }

        // the logical input unions several physical inputs: one record reader each, behind a union reader
        final MutableRecordReader<IOReadableWritable>[] unionMembers = new MutableRecordReader[unionSize];
        int member = 0;
        while (member < unionSize) {
            unionMembers[member] = new MutableRecordReader<IOReadableWritable>(this);
            member++;
        }
        readersPerInput[inputNum] = new MutableUnionRecordReader<IOReadableWritable>(unionMembers);
    }

    this.broadcastInputReaders = readersPerInput;
}

/**
* Creates all the serializers and comparators.
*/
protected void initInputsSerializersAndComparators(int numInputs) throws Exception {
this.inputSerializers = new TypeSerializer[numInputs];
this.inputComparators = this.driver.requiresComparatorOnInput() ? new TypeComparator[numInputs] : null;
this.inputIterators = new MutableObjectIterator[numInputs];

for (int i = 0; i < numInputs; i++) {
// ---------------- create the serializer first ---------------------
final TypeSerializerFactory<?> serializerFactory = this.config.getInputSerializer(i, this.userCodeClassLoader);
Expand All @@ -644,6 +713,22 @@ protected void initInputsSerializersAndComparators(int numInputs) throws Excepti
}
}

/**
 * Creates all the serializers and iterators for the broadcast inputs.
 *
 * The iterators wrap the entries of {@code this.broadcastInputReaders}, so this method must be
 * invoked after {@code initBroadcastInputReaders()}.
 *
 * @param numBroadcastInputs The number of broadcast inputs to create serializers and iterators for.
 * @throws Exception Propagated from obtaining the serializer or creating the input iterator.
 */
protected void initBroadcastInputsSerializers(int numBroadcastInputs) throws Exception {
    this.broadcastInputSerializers = new TypeSerializer[numBroadcastInputs];
    this.broadcastInputIterators = new MutableObjectIterator[numBroadcastInputs];

    for (int i = 0; i < numBroadcastInputs; i++) {
        // ---------------- create the serializer first ---------------------
        // The broadcast inputs have their own serializer entries in the task configuration.
        // Looking them up through getInputSerializer(i, ...) would return the serializer
        // factory of the i-th regular input instead.
        final TypeSerializerFactory<?> serializerFactory = this.config.getBroadcastInputSerializer(i, this.userCodeClassLoader);
        this.broadcastInputSerializers[i] = serializerFactory.getSerializer();

        this.broadcastInputIterators[i] = createInputIterator(i, this.broadcastInputReaders[i], this.broadcastInputSerializers[i]);
    }
}

/**
*
* NOTE: This method must be invoked after the invocation of {@code #initInputReaders()} and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ public class TaskConfig {

private static final String NUM_INPUTS = "pact.in.num";

/** Configuration key under which the number of broadcast inputs is stored. */
private static final String NUM_BROADCAST_INPUTS = "pact.in.broadcast.num";

/*
* If one input has multiple predecessors (bag union), multiple
* inputs must be grouped together. For a map or reduce there is
Expand All @@ -86,10 +88,16 @@ public class TaskConfig {
*/
private static final String INPUT_GROUP_SIZE_PREFIX = "pact.in.groupsize.";

/** Key prefix for the group size of a broadcast input; the group index is appended. */
private static final String BROADCAST_INPUT_GROUP_SIZE_PREFIX = "pact.in.broadcast.groupsize.";

/** Key prefix for the serializer factory of an input; the input number is appended. */
private static final String INPUT_TYPE_SERIALIZER_FACTORY_PREFIX = "pact.in.serializer.";

/** Key prefix for the serializer factory of a broadcast input; the input number is appended. */
private static final String BROADCAST_INPUT_TYPE_SERIALIZER_FACTORY_PREFIX = "pact.in.broadcast.serializer.";

/** Key prefix for the serializer factory parameters of an input; the input number and the separator are appended. */
private static final String INPUT_TYPE_SERIALIZER_PARAMETERS_PREFIX = "pact.in.serializer.param.";

/** Key prefix for the serializer factory parameters of a broadcast input; the input number and the separator are appended. */
private static final String BROADCAST_INPUT_TYPE_SERIALIZER_PARAMETERS_PREFIX = "pact.in.broadcast.serializer.param.";

private static final String INPUT_LOCAL_STRATEGY_PREFIX = "pact.in.strategy.";

private static final String INPUT_STRATEGY_COMPARATOR_FACTORY_PREFIX = "pact.in.comparator.";
Expand All @@ -102,6 +110,8 @@ public class TaskConfig {

private static final String INPUT_DAM_MEMORY_PREFIX = "pact.in.dam.mem.";

/** Key prefix for the name of a broadcast input; the group index is appended. */
private static final String BROADCAST_INPUT_NAME_PREFIX = "pact.in.broadcast.name.";


// -------------------------------------- Outputs ---------------------------------------------

Expand Down Expand Up @@ -384,11 +394,21 @@ public void setInputSerializer(TypeSerializerFactory<?> factory, int inputNum) {
INPUT_TYPE_SERIALIZER_PARAMETERS_PREFIX + inputNum + SEPARATOR);
}

/**
 * Sets the serializer factory for the broadcast input with the given number. The factory is
 * passed to {@code setTypeSerializerFactory} together with the broadcast-specific keys.
 *
 * @param factory The serializer factory for the broadcast input.
 * @param inputNum The number of the broadcast input.
 */
public void setBroadcastInputSerializer(TypeSerializerFactory<?> factory, int inputNum) {
    final String factoryKey = BROADCAST_INPUT_TYPE_SERIALIZER_FACTORY_PREFIX + inputNum;
    final String parametersPrefix = BROADCAST_INPUT_TYPE_SERIALIZER_PARAMETERS_PREFIX + inputNum + SEPARATOR;
    setTypeSerializerFactory(factory, factoryKey, parametersPrefix);
}

/**
 * Gets the serializer factory for the input with the given number. The lookup is delegated to
 * {@code getTypeSerializerFactory} with the keys of the regular inputs.
 *
 * @param inputNum The number of the input.
 * @param cl The class loader handed to the lookup.
 * @return The serializer factory as returned by {@code getTypeSerializerFactory}.
 */
public <T> TypeSerializerFactory<T> getInputSerializer(int inputNum, ClassLoader cl) {
    final String factoryKey = INPUT_TYPE_SERIALIZER_FACTORY_PREFIX + inputNum;
    final String parametersPrefix = INPUT_TYPE_SERIALIZER_PARAMETERS_PREFIX + inputNum + SEPARATOR;
    return getTypeSerializerFactory(factoryKey, parametersPrefix, cl);
}

/**
 * Gets the serializer factory for the broadcast input with the given number. The lookup is
 * delegated to {@code getTypeSerializerFactory} with the broadcast-specific keys.
 *
 * @param inputNum The number of the broadcast input.
 * @param cl The class loader handed to the lookup.
 * @return The serializer factory as returned by {@code getTypeSerializerFactory}.
 */
public <T> TypeSerializerFactory<T> getBroadcastInputSerializer(int inputNum, ClassLoader cl) {
    final String factoryKey = BROADCAST_INPUT_TYPE_SERIALIZER_FACTORY_PREFIX + inputNum;
    final String parametersPrefix = BROADCAST_INPUT_TYPE_SERIALIZER_PARAMETERS_PREFIX + inputNum + SEPARATOR;
    return getTypeSerializerFactory(factoryKey, parametersPrefix, cl);
}

public void setInputComparator(TypeComparatorFactory<?> factory, int inputNum) {
setTypeComparatorFactory(factory, INPUT_STRATEGY_COMPARATOR_FACTORY_PREFIX + inputNum,
INPUT_STRATEGY_COMPARATOR_PARAMETERS_PREFIX + inputNum + SEPARATOR);
Expand All @@ -403,16 +423,30 @@ public int getNumInputs() {
return this.config.getInteger(NUM_INPUTS, 0);
}

/**
 * Gets the number of broadcast inputs stored in the configuration.
 *
 * @return The stored number of broadcast inputs; the lookup passes 0 as the default value.
 */
public int getNumBroadcastInputs() {
    final int numBroadcastInputs = this.config.getInteger(NUM_BROADCAST_INPUTS, 0);
    return numBroadcastInputs;
}

/**
 * Gets the size of the input group with the given index.
 *
 * @param groupIndex The index of the input group.
 * @return The stored group size; the lookup passes -1 as the default value.
 */
public int getGroupSize(int groupIndex) {
    final String groupSizeKey = INPUT_GROUP_SIZE_PREFIX + groupIndex;
    return this.config.getInteger(groupSizeKey, -1);
}

/**
 * Gets the size of the broadcast input group with the given index.
 *
 * @param groupIndex The index of the broadcast input group.
 * @return The stored group size; the lookup passes -1 as the default value.
 */
public int getBroadcastGroupSize(int groupIndex) {
    final String groupSizeKey = BROADCAST_INPUT_GROUP_SIZE_PREFIX + groupIndex;
    return this.config.getInteger(groupSizeKey, -1);
}

/**
 * Adds one input to the input group with the given index: the stored size of that group and the
 * stored number of inputs are both incremented by one.
 *
 * @param groupIndex The index of the input group that receives the additional input.
 */
public void addInputToGroup(int groupIndex) {
    final String groupSizeKey = INPUT_GROUP_SIZE_PREFIX + groupIndex;
    final int enlargedGroupSize = this.config.getInteger(groupSizeKey, 0) + 1;
    this.config.setInteger(groupSizeKey, enlargedGroupSize);

    final int enlargedNumInputs = this.config.getInteger(NUM_INPUTS, 0) + 1;
    this.config.setInteger(NUM_INPUTS, enlargedNumInputs);
}

/**
 * Adds one input to the broadcast input group with the given index.
 *
 * The stored size of the group is incremented on every call. The stored number of broadcast
 * inputs is incremented only when the group receives its first input. RegularPactTask uses
 * {@link #getNumBroadcastInputs()} as the number of groups and looks up
 * {@link #getBroadcastGroupSize(int)} for every index below it, so counting each member of a
 * union group separately would make it look up groups that do not exist.
 *
 * @param groupIndex The index of the broadcast input group that receives the additional input.
 */
public void addBroadcastInputToGroup(int groupIndex) {
    final String grp = BROADCAST_INPUT_GROUP_SIZE_PREFIX + groupIndex;
    final int currentGroupSize = this.config.getInteger(grp, 0);

    if (currentGroupSize == 0) {
        // first input of this group: a new broadcast input (group) comes into existence
        this.config.setInteger(NUM_BROADCAST_INPUTS, this.config.getInteger(NUM_BROADCAST_INPUTS, 0) + 1);
    }
    this.config.setInteger(grp, currentGroupSize + 1);
}

/**
 * Sets the flag that marks the input with the given number as asynchronously materialized.
 *
 * @param inputNum The number of the input.
 * @param temp The flag value to store for that input.
 */
public void setInputAsynchronouslyMaterialized(int inputNum, boolean temp) {
    final String damKey = INPUT_DAM_PREFIX + inputNum;
    this.config.setBoolean(damKey, temp);
}
Expand All @@ -437,6 +471,14 @@ public long getInputMaterializationMemory(int inputNum) {
return this.config.getLong(INPUT_DAM_MEMORY_PREFIX + inputNum, -1);
}

/**
 * Sets the name of the broadcast input group with the given index.
 *
 * @param name The name to store for the broadcast input.
 * @param groupIndex The index of the broadcast input group.
 */
public void setBroadcastInputName(String name, int groupIndex) {
    final String nameKey = BROADCAST_INPUT_NAME_PREFIX + groupIndex;
    this.config.setString(nameKey, name);
}

/**
 * Gets the name of the broadcast input group with the given index.
 *
 * @param groupIndex The index of the broadcast input group.
 * @return The stored name; the lookup passes {@code "broadcastVar"} followed by the
 *         zero-padded, four-digit group index as the default value.
 */
public String getBroadcastInputName(int groupIndex) {
    // Format with a fixed locale: with the JVM's default locale, %04d may be rendered with
    // locale-specific digits, which would make the default name differ between machines.
    final String defaultName = String.format(java.util.Locale.ROOT, "broadcastVar%04d", groupIndex);
    return this.config.getString(BROADCAST_INPUT_NAME_PREFIX + groupIndex, defaultName);
}

// --------------------------------------------------------------------------------------------
// Outputs
// --------------------------------------------------------------------------------------------
Expand Down

0 comments on commit 3992fc6

Please sign in to comment.