Skip to content

Commit

Permalink
[FLINK-15897][python] Defer the deserialization of the Python UDF exe…
Browse files Browse the repository at this point in the history
…cution results
  • Loading branch information
dianfu authored and hequn8128 committed Feb 6, 2020
1 parent c0268ea commit 5c1d7e4
Show file tree
Hide file tree
Showing 16 changed files with 168 additions and 149 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.python.env.PythonEnvironmentManager;
import org.apache.flink.util.Preconditions;
Expand Down Expand Up @@ -52,10 +50,9 @@
* An base class for {@link PythonFunctionRunner}.
*
* @param <IN> Type of the input elements.
* @param <OUT> Type of the execution results.
*/
@Internal
public abstract class AbstractPythonFunctionRunner<IN, OUT> implements PythonFunctionRunner<IN> {
public abstract class AbstractPythonFunctionRunner<IN> implements PythonFunctionRunner<IN> {

private static final String MAIN_INPUT_ID = "input";

Expand All @@ -64,7 +61,7 @@ public abstract class AbstractPythonFunctionRunner<IN, OUT> implements PythonFun
/**
* The Python function execution result receiver.
*/
private final FnDataReceiver<OUT> resultReceiver;
private final FnDataReceiver<byte[]> resultReceiver;

/**
* The Python execution environment manager.
Expand Down Expand Up @@ -110,21 +107,6 @@ public abstract class AbstractPythonFunctionRunner<IN, OUT> implements PythonFun
*/
private transient TypeSerializer<IN> inputTypeSerializer;

/**
* The TypeSerializer for execution results.
*/
private transient TypeSerializer<OUT> outputTypeSerializer;

/**
* Reusable InputStream used to holding the execution results to be deserialized.
*/
private transient ByteArrayInputStreamWithPos bais;

/**
* InputStream Wrapper.
*/
private transient DataInputViewStreamWrapper baisWrapper;

/**
* Reusable OutputStream used to holding the serialized input elements.
*/
Expand All @@ -143,7 +125,7 @@ public abstract class AbstractPythonFunctionRunner<IN, OUT> implements PythonFun

public AbstractPythonFunctionRunner(
String taskName,
FnDataReceiver<OUT> resultReceiver,
FnDataReceiver<byte[]> resultReceiver,
PythonEnvironmentManager environmentManager,
StateRequestHandler stateRequestHandler) {
this.taskName = Preconditions.checkNotNull(taskName);
Expand All @@ -154,12 +136,9 @@ public AbstractPythonFunctionRunner(

@Override
public void open() throws Exception {
bais = new ByteArrayInputStreamWithPos();
baisWrapper = new DataInputViewStreamWrapper(bais);
baos = new ByteArrayOutputStreamWithPos();
baosWrapper = new DataOutputViewStreamWrapper(baos);
inputTypeSerializer = getInputTypeSerializer();
outputTypeSerializer = getOutputTypeSerializer();

// The creation of stageBundleFactory depends on the initialized environment manager.
environmentManager.open();
Expand Down Expand Up @@ -217,10 +196,7 @@ public void startBundle() {
@SuppressWarnings("unchecked")
@Override
public FnDataReceiver<WindowedValue<byte[]>> create(String pCollectionId) {
return input -> {
bais.setBuffer(input.getValue(), 0, input.getValue().length);
resultReceiver.accept(outputTypeSerializer.deserialize(baisWrapper));
};
return input -> resultReceiver.accept(input.getValue());
}
};

Expand Down Expand Up @@ -284,10 +260,4 @@ protected RunnerApi.Environment createPythonExecutionEnvironment() throws Except
* Returns the TypeSerializer for input elements.
*/
public abstract TypeSerializer<IN> getInputTypeSerializer();

/**
* Returns the TypeSerializer for execution results.
*/
public abstract TypeSerializer<OUT> getOutputTypeSerializer();

}
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ public void processWatermark(Watermark mark) throws Exception {
/**
* Sends the execution results to the downstream operator.
*/
public abstract void emitResults();
public abstract void emitResults() throws IOException;

/**
* Reserves the memory used by the Python worker from the MemoryManager. This makes sure that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ConfigurationUtils;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.python.PythonConfig;
import org.apache.flink.python.PythonFunctionRunner;
import org.apache.flink.python.PythonOptions;
Expand All @@ -36,6 +38,7 @@
import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.runtime.runners.python.PythonScalarFunctionRunner;
import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.utils.LegacyTypeInfoDataTypeConverter;
import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
Expand Down Expand Up @@ -106,6 +109,11 @@ public final class PythonScalarFunctionFlatMap
*/
private transient RowType udfOutputType;

/**
* The TypeSerializer for udf execution results.
*/
private transient TypeSerializer<Row> udfOutputTypeSerializer;

/**
* The queue holding the input elements for which the execution results have not been received.
*/
Expand All @@ -115,7 +123,17 @@ public final class PythonScalarFunctionFlatMap
* The queue holding the user-defined function execution results. The execution results are in
* the same order as the input elements.
*/
private transient LinkedBlockingQueue<Row> udfResultQueue;
private transient LinkedBlockingQueue<byte[]> udfResultQueue;

/**
* Reusable InputStream used to holding the execution results to be deserialized.
*/
private transient ByteArrayInputStreamWithPos bais;

/**
* InputStream Wrapper.
*/
private transient DataInputViewStreamWrapper baisWrapper;

/**
* The python config.
Expand Down Expand Up @@ -163,6 +181,7 @@ public PythonScalarFunctionFlatMap(
}

@Override
@SuppressWarnings("unchecked")
public void open(Configuration parameters) throws Exception {
super.open(parameters);

Expand Down Expand Up @@ -190,6 +209,7 @@ public void open(Configuration parameters) throws Exception {
.mapToObj(i -> inputType.getFields().get(i))
.collect(Collectors.toList()));
udfOutputType = new RowType(outputType.getFields().subList(forwardedFields.length, outputType.getFieldCount()));
udfOutputTypeSerializer = PythonTypeUtils.toFlinkTypeSerializer(udfOutputType);

RowTypeInfo forwardedInputTypeInfo = new RowTypeInfo(
Arrays.stream(forwardedFields)
Expand All @@ -200,6 +220,9 @@ public void open(Configuration parameters) throws Exception {
.toArray(TypeInformation[]::new));
forwardedInputSerializer = forwardedInputTypeInfo.createSerializer(getRuntimeContext().getExecutionConfig());

bais = new ByteArrayInputStreamWithPos();
baisWrapper = new DataInputViewStreamWrapper(bais);

this.pythonFunctionRunner = createPythonFunctionRunner();
this.pythonFunctionRunner.open();
}
Expand Down Expand Up @@ -251,7 +274,7 @@ private PythonEnv getPythonEnv() {
}

private PythonFunctionRunner<Row> createPythonFunctionRunner() throws IOException {
FnDataReceiver<Row> udfResultReceiver = input -> {
FnDataReceiver<byte[]> udfResultReceiver = input -> {
// handover to queue, do not block the result receiver thread
udfResultQueue.put(input);
};
Expand Down Expand Up @@ -288,10 +311,12 @@ private void bufferInput(Row input) {
forwardedInputQueue.add(forwardedFieldsRow);
}

private void emitResults() {
Row udfResult;
while ((udfResult = udfResultQueue.poll()) != null) {
private void emitResults() throws IOException {
byte[] rawUdfResult;
while ((rawUdfResult = udfResultQueue.poll()) != null) {
Row input = forwardedInputQueue.poll();
bais.setBuffer(rawUdfResult, 0, rawUdfResult.length);
Row udfResult = udfOutputTypeSerializer.deserialize(baisWrapper);
this.resultCollector.collect(Row.join(input, udfResult));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import org.apache.flink.annotation.Internal;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.python.PythonFunctionRunner;
import org.apache.flink.python.env.PythonEnvironmentManager;
import org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
Expand Down Expand Up @@ -62,10 +64,9 @@
* @param <IN> Type of the input elements.
* @param <OUT> Type of the output elements.
* @param <UDFIN> Type of the UDF input type.
* @param <UDFOUT> Type of the UDF input type.
*/
@Internal
public abstract class AbstractPythonScalarFunctionOperator<IN, OUT, UDFIN, UDFOUT>
public abstract class AbstractPythonScalarFunctionOperator<IN, OUT, UDFIN>
extends AbstractPythonFunctionOperator<IN, OUT> {

private static final long serialVersionUID = 1L;
Expand Down Expand Up @@ -114,7 +115,17 @@ public abstract class AbstractPythonScalarFunctionOperator<IN, OUT, UDFIN, UDFOU
* The queue holding the user-defined function execution results. The execution results are in
* the same order as the input elements.
*/
protected transient LinkedBlockingQueue<UDFOUT> udfResultQueue;
protected transient LinkedBlockingQueue<byte[]> udfResultQueue;

/**
* Reusable InputStream used to holding the execution results to be deserialized.
*/
protected transient ByteArrayInputStreamWithPos bais;

/**
* InputStream Wrapper.
*/
protected transient DataInputViewStreamWrapper baisWrapper;

AbstractPythonScalarFunctionOperator(
Configuration config,
Expand All @@ -140,6 +151,8 @@ public void open() throws Exception {
.mapToObj(i -> inputType.getFields().get(i))
.collect(Collectors.toList()));
udfOutputType = new RowType(outputType.getFields().subList(forwardedFields.length, outputType.getFieldCount()));
bais = new ByteArrayInputStreamWithPos();
baisWrapper = new DataInputViewStreamWrapper(bais);
super.open();
}

Expand All @@ -157,7 +170,7 @@ public PythonEnv getPythonEnv() {

@Override
public PythonFunctionRunner<IN> createPythonFunctionRunner() throws IOException {
final FnDataReceiver<UDFOUT> udfResultReceiver = input -> {
final FnDataReceiver<byte[]> udfResultReceiver = input -> {
// handover to queue, do not block the result receiver thread
udfResultQueue.put(input);
};
Expand All @@ -177,7 +190,7 @@ public PythonFunctionRunner<IN> createPythonFunctionRunner() throws IOException
public abstract UDFIN getUdfInput(IN element);

public abstract PythonFunctionRunner<UDFIN> createPythonFunctionRunner(
FnDataReceiver<UDFOUT> resultReceiver,
FnDataReceiver<byte[]> resultReceiver,
PythonEnvironmentManager pythonEnvironmentManager);

private class ProjectUdfInputPythonScalarFunctionRunner implements PythonFunctionRunner<IN> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.table.runtime.operators.python;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.python.PythonFunctionRunner;
import org.apache.flink.python.env.PythonEnvironmentManager;
Expand All @@ -34,11 +35,13 @@
import org.apache.flink.table.runtime.generated.GeneratedProjection;
import org.apache.flink.table.runtime.generated.Projection;
import org.apache.flink.table.runtime.runners.python.BaseRowPythonScalarFunctionRunner;
import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Collector;

import org.apache.beam.sdk.fn.data.FnDataReceiver;

import java.io.IOException;
import java.util.Arrays;
import java.util.stream.Collectors;

Expand All @@ -47,7 +50,7 @@
*/
@Internal
public class BaseRowPythonScalarFunctionOperator
extends AbstractPythonScalarFunctionOperator<BaseRow, BaseRow, BaseRow, BaseRow> {
extends AbstractPythonScalarFunctionOperator<BaseRow, BaseRow, BaseRow> {

private static final long serialVersionUID = 1L;

Expand All @@ -71,6 +74,11 @@ public class BaseRowPythonScalarFunctionOperator
*/
private transient Projection<BaseRow, BinaryRow> udfInputProjection;

/**
* The TypeSerializer for udf execution results.
*/
private transient TypeSerializer<BaseRow> udfOutputTypeSerializer;

public BaseRowPythonScalarFunctionOperator(
Configuration config,
PythonFunctionInfo[] scalarFunctions,
Expand All @@ -82,13 +90,15 @@ public BaseRowPythonScalarFunctionOperator(
}

@Override
@SuppressWarnings("unchecked")
public void open() throws Exception {
super.open();
baseRowWrapper = new StreamRecordBaseRowWrappingCollector(output);
reuseJoinedRow = new JoinedRow();

udfInputProjection = createUdfInputProjection();
forwardedFieldProjection = createForwardedFieldProjection();
udfOutputTypeSerializer = PythonTypeUtils.toBlinkTypeSerializer(udfOutputType);
}

@Override
Expand All @@ -106,18 +116,20 @@ public BaseRow getUdfInput(BaseRow element) {

@Override
@SuppressWarnings("ConstantConditions")
public void emitResults() {
BaseRow udfResult;
while ((udfResult = udfResultQueue.poll()) != null) {
public void emitResults() throws IOException {
byte[] rawUdfResult;
while ((rawUdfResult = udfResultQueue.poll()) != null) {
BaseRow input = forwardedInputQueue.poll();
reuseJoinedRow.setHeader(input.getHeader());
bais.setBuffer(rawUdfResult, 0, rawUdfResult.length);
BaseRow udfResult = udfOutputTypeSerializer.deserialize(baisWrapper);
baseRowWrapper.collect(reuseJoinedRow.replace(input, udfResult));
}
}

@Override
public PythonFunctionRunner<BaseRow> createPythonFunctionRunner(
FnDataReceiver<BaseRow> resultReceiver,
FnDataReceiver<byte[]> resultReceiver,
PythonEnvironmentManager pythonEnvironmentManager) {
return new BaseRowPythonScalarFunctionRunner(
getRuntimeContext().getTaskName(),
Expand Down
Loading

0 comments on commit 5c1d7e4

Please sign in to comment.