From ab8470714715cb02b32e3299f36ea64ebe85695e Mon Sep 17 00:00:00 2001
From: zentol
Date: Wed, 20 Jan 2016 14:50:53 +0100
Subject: [PATCH] [FLINK-2501] [py] Remove the need to specify types for
transformations
Full changelog:
Major changes
===============
- Users no longer have to supply type information for transformations
- Values are now stored as byte arrays on the Java side, in
  * a plain byte[] most of the time,
  * a Tuple2<byte[], byte[]> within a join/cross,
  * a Tuple2<Tuple, byte[]> within keyed operations
    (see the sketch at the end of this changelog).
- Every value contains information about its type at the beginning of
each byte array.
- Implemented KeySelectors
Minor
===============
- improved error messages in several places
- defaultable operations now use a "usesUDF" flag
- reshuffled type IDs; tuple types are encoded as 1-25
- broadcast variables are now sent via the TCP socket
- ProjectJoin/-Cross now execute the projection on the Python side
Java
---------------
- Sort field now stored as String, continuation of FLINK-2431
- object->byte[] serializer code moved into separate utility class
Python
---------------
- Fixed NullSerializer not taking a read method argument
- Serializer/Deserializer interface added
- Refactored DataSet structure
* Set and ReduceSet merged into DataSet
- configure() now takes an OperationInfo argument
- Simplified GroupReduce tests
- removed unused Function._open()
- simplified chaining setup
- most functions now use super.configure()
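
For illustration, a minimal sketch of how a single record of a keyed
operation is represented under the Tuple2<Tuple, byte[]> layout described
above (variable names are illustrative; SerializationUtils.getSerializer
is the new utility introduced by this patch):

    byte[] key = SerializationUtils.getSerializer(groupKey).serialize(groupKey);
    byte[] value = SerializationUtils.getSerializer(element).serialize(element);
    // key fields travel as serialized byte[]s inside a Tuple, the value as a plain byte[]
    Tuple2<Tuple, byte[]> record = new Tuple2<Tuple, byte[]>(new Tuple1<byte[]>(key), value);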
---
docs/apis/batch/dataset_transformations.md | 20 +-
docs/apis/batch/python.md | 82 ++--
.../flink/python/api/PythonOperationInfo.java | 51 +-
.../flink/python/api/PythonPlanBinder.java | 77 +--
.../python/api/functions/PythonCoGroup.java | 2 +-
.../api/functions/PythonMapPartition.java | 3 +-
.../{ => util}/IdentityGroupReduce.java | 5 +-
.../api/functions/util/KeyDiscarder.java | 29 ++
.../functions/util/NestedKeyDiscarder.java | 30 ++
.../api/functions/util/SerializerMap.java | 32 ++
.../functions/util/StringDeserializerMap.java | 26 +
.../util/StringTupleDeserializerMap.java | 27 ++
.../api/streaming/data/PythonReceiver.java | 185 ++-----
.../api/streaming/data/PythonSender.java | 264 ++--------
.../api/streaming/data/PythonStreamer.java | 76 +--
.../streaming/plan/PythonPlanReceiver.java | 199 ++++++--
.../api/streaming/plan/PythonPlanSender.java | 87 +---
.../streaming/util/SerializationUtils.java | 283 +++++++++++
.../python/api/flink/connection/Collector.py | 181 +++----
.../python/api/flink/connection/Connection.py | 8 +
.../python/api/flink/connection/Constants.py | 23 +-
.../python/api/flink/connection/Iterator.py | 194 +++++---
.../python/api/flink/example/TPCHQuery10.py | 2 +-
.../python/api/flink/example/TPCHQuery3.py | 4 +-
.../api/flink/example/TriangleEnumeration.py | 14 +-
.../api/flink/example/WebLogAnalysis.py | 10 +-
.../python/api/flink/example/WordCount.py | 5 +-
.../api/flink/functions/CoGroupFunction.py | 7 +-
.../python/api/flink/functions/Function.py | 57 +--
.../flink/functions/GroupReduceFunction.py | 15 +-
.../flink/functions/KeySelectorFunction.py | 28 ++
.../api/flink/functions/ReduceFunction.py | 15 +-
.../flink/python/api/flink/plan/Constants.py | 15 +-
.../flink/python/api/flink/plan/DataSet.py | 451 ++++++++++--------
.../python/api/flink/plan/Environment.py | 22 +-
.../python/api/flink/plan/OperationInfo.py | 4 +
.../org/apache/flink/python/api/test_main.py | 65 ++-
.../org/apache/flink/python/api/test_main2.py | 38 +-
.../flink/python/api/test_type_deduction.py | 73 ---
39 files changed, 1426 insertions(+), 1283 deletions(-)
rename flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/{ => util}/IdentityGroupReduce.java (91%)
create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/KeyDiscarder.java
create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/NestedKeyDiscarder.java
create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/SerializerMap.java
create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringDeserializerMap.java
create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringTupleDeserializerMap.java
create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/util/SerializationUtils.java
create mode 100644 flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/KeySelectorFunction.py
delete mode 100644 flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_type_deduction.py
diff --git a/docs/apis/batch/dataset_transformations.md b/docs/apis/batch/dataset_transformations.md
index b9d7f0c3929da..7315ec2eaf5af 100644
--- a/docs/apis/batch/dataset_transformations.md
+++ b/docs/apis/batch/dataset_transformations.md
@@ -71,7 +71,7 @@ val intSums = intPairs.map { pair => pair._1 + pair._2 }
~~~python
- intSums = intPairs.map(lambda x: sum(x), INT)
+ intSums = intPairs.map(lambda x: sum(x))
~~~
@@ -115,7 +115,7 @@ val words = textLines.flatMap { _.split(" ") }
~~~python
- words = lines.flat_map(lambda x,c: [line.split() for line in x], STRING)
+ words = lines.flat_map(lambda x,c: [line.split() for line in x])
~~~
@@ -163,7 +163,7 @@ val counts = texLines.mapPartition { in => Some(in.size) }
~~~python
- counts = lines.map_partition(lambda x,c: [sum(1 for _ in x)], INT)
+ counts = lines.map_partition(lambda x,c: [sum(1 for _ in x)])
~~~
@@ -459,7 +459,7 @@ Works analogous to grouping by Case Class fields in *Reduce* transformations.
for key in dic.keys():
collector.collect(key)
- output = data.group_by(0).reduce_group(DistinctReduce(), STRING)
+ output = data.group_by(0).reduce_group(DistinctReduce())
~~~
@@ -539,7 +539,7 @@ val output = input.groupBy(0).sortGroup(1, Order.ASCENDING).reduceGroup {
for key in dic.keys():
collector.collect(key)
- output = data.group_by(0).sort_group(1, Order.ASCENDING).reduce_group(DistinctReduce(), STRING)
+ output = data.group_by(0).sort_group(1, Order.ASCENDING).reduce_group(DistinctReduce())
~~~
@@ -644,7 +644,7 @@ class MyCombinableGroupReducer
def combine(self, iterator, collector):
return self.reduce(iterator, collector)
-data.reduce_group(GroupReduce(), (STRING, INT, FLOAT), combinable=True)
+data.reduce_group(GroupReduce(), combinable=True)
~~~
@@ -864,7 +864,7 @@ val output = input.reduceGroup(new MyGroupReducer())
~~~python
- output = data.reduce_group(MyGroupReducer(), ... )
+ output = data.reduce_group(MyGroupReducer())
~~~
@@ -1223,7 +1223,7 @@ val weightedRatings = ratings.join(weights).where("category").equalTo(0) {
weightedRatings =
ratings.join(weights).where(0).equal_to(0). \
- with(new PointWeighter(), (STRING, FLOAT));
+ with(new PointWeighter());
~~~
@@ -1719,7 +1719,7 @@ val distances = coords1.cross(coords2) {
def cross(self, c1, c2):
return (c1[0], c2[0], sqrt(pow(c1[1] - c2.[1], 2) + pow(c1[2] - c2[2], 2)))
- distances = coords1.cross(coords2).using(Euclid(), (INT,INT,FLOAT))
+ distances = coords1.cross(coords2).using(Euclid())
~~~
#### Cross with Projection
@@ -1878,7 +1878,7 @@ val output = iVals.coGroup(dVals).where(0).equalTo(0) {
collector.collect(value[1] * i)
- output = ivals.co_group(dvals).where(0).equal_to(0).using(CoGroup(), DOUBLE)
+ output = ivals.co_group(dvals).where(0).equal_to(0).using(CoGroup())
~~~
diff --git a/docs/apis/batch/python.md b/docs/apis/batch/python.md
index 74da97e933774..7e2060bc910d6 100644
--- a/docs/apis/batch/python.md
+++ b/docs/apis/batch/python.md
@@ -50,7 +50,6 @@ to run it locally.
{% highlight python %}
from flink.plan.Environment import get_environment
-from flink.plan.Constants import INT, STRING
from flink.functions.GroupReduceFunction import GroupReduceFunction
class Adder(GroupReduceFunction):
@@ -59,18 +58,17 @@ class Adder(GroupReduceFunction):
count += sum([x[0] for x in iterator])
collector.collect((count, word))
-if __name__ == "__main__":
- env = get_environment()
- data = env.from_elements("Who's there?",
- "I think I hear them. Stand, ho! Who's there?")
+env = get_environment()
+data = env.from_elements("Who's there?",
+ "I think I hear them. Stand, ho! Who's there?")
- data \
- .flat_map(lambda x, c: [(1, word) for word in x.lower().split()], (INT, STRING)) \
- .group_by(1) \
- .reduce_group(Adder(), (INT, STRING), combinable=True) \
- .output()
+data \
+ .flat_map(lambda x, c: [(1, word) for word in x.lower().split()]) \
+ .group_by(1) \
+ .reduce_group(Adder(), combinable=True) \
+ .output()
- env.execute(local=True)
+env.execute(local=True)
{% endhighlight %}
{% top %}
@@ -78,8 +76,8 @@ if __name__ == "__main__":
Program Skeleton
----------------
-As we already saw in the example, Flink programs look like regular python
-programs with a `if __name__ == "__main__":` block. Each program consists of the same basic parts:
+As we already saw in the example, Flink programs look like regular python programs.
+Each program consists of the same basic parts:
1. Obtain an `Environment`,
2. Load/create the initial data,
@@ -117,7 +115,7 @@ methods on DataSet with your own custom transformation function. For example,
a map transformation looks like this:
{% highlight python %}
-data.map(lambda x: x*2, INT)
+data.map(lambda x: x*2)
{% endhighlight %}
This will create a new DataSet by doubling every value in the original DataSet.
@@ -197,7 +195,7 @@ examples.
Takes one element and produces one element.
{% highlight python %}
-data.map(lambda x: x * 2, INT)
+data.map(lambda x: x * 2)
{% endhighlight %}
@@ -208,8 +206,7 @@ data.map(lambda x: x * 2, INT)
Takes one element and produces zero, one, or more elements.
{% highlight python %}
data.flat_map(
- lambda x,c: [(1,word) for word in line.lower().split() for line in x],
- (INT, STRING))
+ lambda x,c: [(1,word) for word in line.lower().split() for line in x])
{% endhighlight %}
@@ -221,7 +218,7 @@ data.flat_map(
as an `Iterator` and can produce an arbitrary number of result values. The number of
elements in each partition depends on the degree-of-parallelism and previous operations.
{% highlight python %}
-data.map_partition(lambda x,c: [value * 2 for value in x], INT)
+data.map_partition(lambda x,c: [value * 2 for value in x])
{% endhighlight %}
@@ -260,7 +257,7 @@ class Adder(GroupReduceFunction):
count += sum([x[0] for x in iterator)
collector.collect((count, word))
-data.reduce_group(Adder(), (INT, STRING))
+data.reduce_group(Adder())
{% endhighlight %}
@@ -392,24 +389,33 @@ They are also the only way to define an optional `combine` function for a reduce
Lambda functions allow the easy insertion of one-liners. Note that a lambda function has to return
an iterable, if the operation can return multiple values. (All functions receiving a collector argument)
-Flink requires type information at the time when it prepares the program for execution
-(when the main method of the program is called). This is done by passing an exemplary
-object that has the desired type. This holds also for tuples.
+{% top %}
+
+Data Types
+----------
+
+Flink's Python API currently only offers native support for primitive python types (int, float, bool, string) and byte arrays.
+The type support can be extended by passing a serializer, deserializer and type class to the environment.
{% highlight python %}
-(INT, STRING)
-{% endhighlight %}
+class MyObj(object):
+ def __init__(self, i):
+ self.value = i
-Would denote a tuple containing an int and a string. Note that for Operations that work strictly on tuples (like cross), no braces are required.
-There are a few Constants defined in flink.plan.Constants that allow this in a more readable fashion.
+class MySerializer(object):
+ def serialize(self, value):
+ return struct.pack(">i", value.value)
-{% top %}
-Data Types
-----------
+class MyDeserializer(object):
+ def _deserialize(self, read):
+ i = struct.unpack(">i", read(4))[0]
+ return MyObj(i)
+
-Flink's Python API currently only supports primitive python types (int, float, bool, string) and byte arrays.
+env.register_custom_type(MyObj, MySerializer(), MyDeserializer())
+{% endhighlight %}
#### Tuples/Lists
@@ -419,7 +425,7 @@ a fix number of fields of various types (up to 25). Every field of a tuple can b
{% highlight python %}
word_counts = env.from_elements(("hello", 1), ("world",2))
-counts = word_counts.map(lambda x: x[1], INT)
+counts = word_counts.map(lambda x: x[1])
{% endhighlight %}
When working with operators that require a Key for grouping or matching records,
@@ -455,16 +461,16 @@ Collection-based:
{% highlight python %}
env = get_environment
-# read text file from local files system
+\# read text file from local files system
localLiens = env.read_text("file:#/path/to/my/textfile")
- read text file from a HDFS running at nnHost:nnPort
+\# read text file from a HDFS running at nnHost:nnPort
hdfsLines = env.read_text("hdfs://nnHost:nnPort/path/to/my/textfile")
- read a CSV file with three fields
+\# read a CSV file with three fields, schema defined using constants defined in flink.plan.Constants
csvInput = env.read_csv("hdfs:///the/CSV/file", (INT, STRING, DOUBLE))
- create a set from some given elements
+\# create a set from some given elements
values = env.from_elements("Foo", "bar", "foobar", "fubar")
{% endhighlight %}
@@ -530,7 +536,7 @@ toBroadcast = env.from_elements(1, 2, 3)
data = env.from_elements("a", "b")
# 2. Broadcast the DataSet
-data.map(MapperBcv(), INT).with_broadcast_set("bcv", toBroadcast)
+data.map(MapperBcv()).with_broadcast_set("bcv", toBroadcast)
{% endhighlight %}
Make sure that the names (`bcv` in the previous example) match when registering and
@@ -568,9 +574,9 @@ execution environment as follows:
env = get_environment()
env.set_degree_of_parallelism(3)
-text.flat_map(lambda x,c: x.lower().split(), (INT, STRING)) \
+text.flat_map(lambda x,c: x.lower().split()) \
.group_by(1) \
- .reduce_group(Adder(), (INT, STRING), combinable=True) \
+ .reduce_group(Adder(), combinable=True) \
.output()
env.execute()
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonOperationInfo.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonOperationInfo.java
index 6ecd6837e0bfd..30a7133f2a2d7 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonOperationInfo.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonOperationInfo.java
@@ -32,10 +32,9 @@ public class PythonOperationInfo {
public String[] keys2; //join/cogroup keys
	public TypeInformation<?> types; //typeinformation about output type
public AggregationEntry[] aggregates;
- public ProjectionEntry[] projections; //projectFirst/projectSecond
public Object[] values;
public int count;
- public int field;
+ public String field;
public int[] fields;
public Order order;
public String path;
@@ -46,6 +45,7 @@ public class PythonOperationInfo {
public WriteMode writeMode;
public boolean toError;
public String name;
+ public boolean usesUDF;
public PythonOperationInfo(PythonPlanStreamer streamer, Operation identifier) throws IOException {
Object tmpType;
@@ -127,7 +127,7 @@ public PythonOperationInfo(PythonPlanStreamer streamer, Operation identifier) th
case REBALANCE:
return;
case SORT:
- field = (Integer) streamer.getRecord(true);
+ field = "f0.f" + (Integer) streamer.getRecord(true);
int encodedOrder = (Integer) streamer.getRecord(true);
switch (encodedOrder) {
case 0:
@@ -162,15 +162,9 @@ public PythonOperationInfo(PythonPlanStreamer streamer, Operation identifier) th
case CROSS_H:
case CROSS_T:
otherID = (Integer) streamer.getRecord(true);
+ usesUDF = (Boolean) streamer.getRecord();
tmpType = streamer.getRecord();
types = tmpType == null ? null : getForObject(tmpType);
- int cProjectCount = (Integer) streamer.getRecord(true);
- projections = new ProjectionEntry[cProjectCount];
- for (int x = 0; x < cProjectCount; x++) {
- String side = (String) streamer.getRecord();
- int[] keys = toIntArray((Tuple) streamer.getRecord(true));
- projections[x] = new ProjectionEntry(ProjectionSide.valueOf(side.toUpperCase()), keys);
- }
name = (String) streamer.getRecord();
return;
case REDUCE:
@@ -185,15 +179,9 @@ public PythonOperationInfo(PythonPlanStreamer streamer, Operation identifier) th
keys1 = normalizeKeys(streamer.getRecord(true));
keys2 = normalizeKeys(streamer.getRecord(true));
otherID = (Integer) streamer.getRecord(true);
+ usesUDF = (Boolean) streamer.getRecord();
tmpType = streamer.getRecord();
types = tmpType == null ? null : getForObject(tmpType);
- int jProjectCount = (Integer) streamer.getRecord(true);
- projections = new ProjectionEntry[jProjectCount];
- for (int x = 0; x < jProjectCount; x++) {
- String side = (String) streamer.getRecord();
- int[] keys = toIntArray((Tuple) streamer.getRecord(true));
- projections[x] = new ProjectionEntry(ProjectionSide.valueOf(side.toUpperCase()), keys);
- }
name = (String) streamer.getRecord();
return;
case MAPPARTITION:
@@ -221,7 +209,6 @@ public String toString() {
sb.append("Keys2: ").append(Arrays.toString(keys2)).append("\n");
sb.append("Keys: ").append(Arrays.toString(keys)).append("\n");
sb.append("Aggregates: ").append(Arrays.toString(aggregates)).append("\n");
- sb.append("Projections: ").append(Arrays.toString(projections)).append("\n");
sb.append("Count: ").append(count).append("\n");
sb.append("Field: ").append(field).append("\n");
sb.append("Order: ").append(order.toString()).append("\n");
@@ -260,26 +247,6 @@ public String toString() {
}
}
- public static class ProjectionEntry {
- public ProjectionSide side;
- public int[] keys;
-
- public ProjectionEntry(ProjectionSide side, int[] keys) {
- this.side = side;
- this.keys = keys;
- }
-
- @Override
- public String toString() {
- return side + " - " + Arrays.toString(keys);
- }
- }
-
- public enum ProjectionSide {
- FIRST,
- SECOND
- }
-
public enum DatasizeHint {
NONE,
TINY,
@@ -296,24 +263,24 @@ private static String[] normalizeKeys(Object keys) {
if (tupleKeys.getField(0) instanceof Integer) {
String[] stringKeys = new String[tupleKeys.getArity()];
for (int x = 0; x < stringKeys.length; x++) {
- stringKeys[x] = "f" + (Integer) tupleKeys.getField(x);
+ stringKeys[x] = "f0.f" + (Integer) tupleKeys.getField(x);
}
return stringKeys;
}
if (tupleKeys.getField(0) instanceof String) {
return tupleToStringArray(tupleKeys);
}
- throw new RuntimeException("Key argument contains field that is neither an int nor a String.");
+ throw new RuntimeException("Key argument contains field that is neither an int nor a String: " + tupleKeys);
}
if (keys instanceof int[]) {
int[] intKeys = (int[]) keys;
String[] stringKeys = new String[intKeys.length];
for (int x = 0; x < stringKeys.length; x++) {
- stringKeys[x] = "f" + intKeys[x];
+ stringKeys[x] = "f0.f" + intKeys[x];
}
return stringKeys;
}
- throw new RuntimeException("Key argument is neither an int[] nor a Tuple.");
+ throw new RuntimeException("Key argument is neither an int[] nor a Tuple: " + keys.toString());
}
private static int[] toIntArray(Object key) {
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java
index 2e64a56349f7f..f07a975a319e1 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java
@@ -28,10 +28,7 @@
import org.apache.flink.api.java.operators.AggregateOperator;
import org.apache.flink.api.java.operators.CoGroupRawOperator;
import org.apache.flink.api.java.operators.CrossOperator.DefaultCross;
-import org.apache.flink.api.java.operators.CrossOperator.ProjectCross;
import org.apache.flink.api.java.operators.Grouping;
-import org.apache.flink.api.java.operators.JoinOperator.DefaultJoin;
-import org.apache.flink.api.java.operators.JoinOperator.ProjectJoin;
import org.apache.flink.api.java.operators.Keys;
import org.apache.flink.api.java.operators.SortedGrouping;
import org.apache.flink.api.java.operators.UdfOperator;
@@ -42,14 +39,18 @@
import org.apache.flink.configuration.GlobalConfiguration;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;
+import org.apache.flink.python.api.functions.util.NestedKeyDiscarder;
+import org.apache.flink.python.api.functions.util.StringTupleDeserializerMap;
import org.apache.flink.python.api.PythonOperationInfo.DatasizeHint;
import static org.apache.flink.python.api.PythonOperationInfo.DatasizeHint.HUGE;
import static org.apache.flink.python.api.PythonOperationInfo.DatasizeHint.NONE;
import static org.apache.flink.python.api.PythonOperationInfo.DatasizeHint.TINY;
-import org.apache.flink.python.api.PythonOperationInfo.ProjectionEntry;
import org.apache.flink.python.api.functions.PythonCoGroup;
-import org.apache.flink.python.api.functions.IdentityGroupReduce;
+import org.apache.flink.python.api.functions.util.IdentityGroupReduce;
import org.apache.flink.python.api.functions.PythonMapPartition;
+import org.apache.flink.python.api.functions.util.KeyDiscarder;
+import org.apache.flink.python.api.functions.util.SerializerMap;
+import org.apache.flink.python.api.functions.util.StringDeserializerMap;
import org.apache.flink.python.api.streaming.plan.PythonPlanStreamer;
import org.apache.flink.runtime.filecache.FileCache;
import org.slf4j.Logger;
@@ -410,34 +411,35 @@ private void createCsvSource(PythonOperationInfo info) throws IOException {
sets.put(info.setID, env.createInput(new TupleCsvInputFormat(new Path(info.path),
info.lineDelimiter, info.fieldDelimiter, (TupleTypeInfo) info.types), info.types)
- .name("CsvSource"));
+ .name("CsvSource").map(new SerializerMap()).name("CsvSourcePostStep"));
}
private void createTextSource(PythonOperationInfo info) throws IOException {
- sets.put(info.setID, env.readTextFile(info.path).name("TextSource"));
+ sets.put(info.setID, env.readTextFile(info.path).name("TextSource").map(new SerializerMap()).name("TextSourcePostStep"));
}
private void createValueSource(PythonOperationInfo info) throws IOException {
- sets.put(info.setID, env.fromElements(info.values).name("ValueSource"));
+ sets.put(info.setID, env.fromElements(info.values).name("ValueSource").map(new SerializerMap()).name("ValueSourcePostStep"));
}
private void createSequenceSource(PythonOperationInfo info) throws IOException {
- sets.put(info.setID, env.generateSequence(info.from, info.to).name("SequenceSource"));
+ sets.put(info.setID, env.generateSequence(info.from, info.to).name("SequenceSource").map(new SerializerMap()).name("SequenceSourcePostStep"));
}
private void createCsvSink(PythonOperationInfo info) throws IOException {
DataSet parent = (DataSet) sets.get(info.parentID);
- parent.writeAsCsv(info.path, info.lineDelimiter, info.fieldDelimiter, info.writeMode).name("CsvSink");
+ parent.map(new StringTupleDeserializerMap()).name("CsvSinkPreStep")
+ .writeAsCsv(info.path, info.lineDelimiter, info.fieldDelimiter, info.writeMode).name("CsvSink");
}
private void createTextSink(PythonOperationInfo info) throws IOException {
DataSet parent = (DataSet) sets.get(info.parentID);
- parent.writeAsText(info.path, info.writeMode).name("TextSink");
+ parent.map(new StringDeserializerMap()).writeAsText(info.path, info.writeMode).name("TextSink");
}
private void createPrintSink(PythonOperationInfo info) throws IOException {
DataSet parent = (DataSet) sets.get(info.parentID);
- parent.output(new PrintingOutputFormat(info.toError));
+ parent.map(new StringDeserializerMap()).name("PrintSinkPreStep").output(new PrintingOutputFormat(info.toError));
}
private void createBroadcastVariable(PythonOperationInfo info) throws IOException {
@@ -471,7 +473,7 @@ private void createAggregationOperation(PythonOperationInfo info) throws IOExcep
private void createDistinctOperation(PythonOperationInfo info) throws IOException {
DataSet op = (DataSet) sets.get(info.parentID);
- sets.put(info.setID, info.keys.length == 0 ? op.distinct() : op.distinct(info.keys).name("Distinct"));
+ sets.put(info.setID, info.keys.length == 0 ? op.distinct() : op.distinct(info.keys).name("Distinct").map(new KeyDiscarder()).name("DistinctPostStep"));
}
private void createFirstOperation(PythonOperationInfo info) throws IOException {
@@ -486,7 +488,7 @@ private void createGroupOperation(PythonOperationInfo info) throws IOException {
private void createHashPartitionOperation(PythonOperationInfo info) throws IOException {
DataSet op1 = (DataSet) sets.get(info.parentID);
- sets.put(info.setID, op1.partitionByHash(info.keys));
+ sets.put(info.setID, op1.partitionByHash(info.keys).map(new KeyDiscarder()).name("HashPartitionPostStep"));
}
private void createProjectOperation(PythonOperationInfo info) throws IOException {
@@ -546,23 +548,10 @@ private void createCrossOperation(DatasizeHint mode, PythonOperationInfo info) {
default:
throw new IllegalArgumentException("Invalid Cross mode specified: " + mode);
}
- if (info.types != null && (info.projections == null || info.projections.length == 0)) {
+ if (info.usesUDF) {
sets.put(info.setID, defaultResult.mapPartition(new PythonMapPartition(info.setID, info.types)).name(info.name));
- } else if (info.projections.length == 0) {
- sets.put(info.setID, defaultResult.name("DefaultCross"));
} else {
- ProjectCross project = null;
- for (ProjectionEntry pe : info.projections) {
- switch (pe.side) {
- case FIRST:
- project = project == null ? defaultResult.projectFirst(pe.keys) : project.projectFirst(pe.keys);
- break;
- case SECOND:
- project = project == null ? defaultResult.projectSecond(pe.keys) : project.projectSecond(pe.keys);
- break;
- }
- }
- sets.put(info.setID, project.name("ProjectCross"));
+ sets.put(info.setID, defaultResult.name("DefaultCross"));
}
}
@@ -616,38 +605,22 @@ private void createJoinOperation(DatasizeHint mode, PythonOperationInfo info) {
DataSet op1 = (DataSet) sets.get(info.parentID);
DataSet op2 = (DataSet) sets.get(info.otherID);
- if (info.types != null && (info.projections == null || info.projections.length == 0)) {
- sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode).name("PythonJoinPreStep")
+ if (info.usesUDF) {
+ sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode)
.mapPartition(new PythonMapPartition(info.setID, info.types)).name(info.name));
} else {
- DefaultJoin defaultResult = createDefaultJoin(op1, op2, info.keys1, info.keys2, mode);
- if (info.projections.length == 0) {
- sets.put(info.setID, defaultResult.name("DefaultJoin"));
- } else {
- ProjectJoin project = null;
- for (ProjectionEntry pe : info.projections) {
- switch (pe.side) {
- case FIRST:
- project = project == null ? defaultResult.projectFirst(pe.keys) : project.projectFirst(pe.keys);
- break;
- case SECOND:
- project = project == null ? defaultResult.projectSecond(pe.keys) : project.projectSecond(pe.keys);
- break;
- }
- }
- sets.put(info.setID, project.name("ProjectJoin"));
- }
+ sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode));
}
}
- private DefaultJoin createDefaultJoin(DataSet op1, DataSet op2, String[] firstKeys, String[] secondKeys, DatasizeHint mode) {
+ private DataSet createDefaultJoin(DataSet op1, DataSet op2, String[] firstKeys, String[] secondKeys, DatasizeHint mode) {
switch (mode) {
case NONE:
- return op1.join(op2).where(firstKeys).equalTo(secondKeys);
+ return op1.join(op2).where(firstKeys).equalTo(secondKeys).map(new NestedKeyDiscarder()).name("DefaultJoinPostStep");
case HUGE:
- return op1.joinWithHuge(op2).where(firstKeys).equalTo(secondKeys);
+ return op1.joinWithHuge(op2).where(firstKeys).equalTo(secondKeys).map(new NestedKeyDiscarder()).name("DefaultJoinPostStep");
case TINY:
- return op1.joinWithTiny(op2).where(firstKeys).equalTo(secondKeys);
+ return op1.joinWithTiny(op2).where(firstKeys).equalTo(secondKeys).map(new NestedKeyDiscarder()).name("DefaultJoinPostStep");
default:
throw new IllegalArgumentException("Invalid join mode specified.");
}
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonCoGroup.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonCoGroup.java
index 2349aa919f838..33d88c3182d76 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonCoGroup.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonCoGroup.java
@@ -33,7 +33,7 @@ public class PythonCoGroup<IN1, IN2, OUT> extends RichCoGroupFunction<IN1, IN2, OUT> implements ResultTypeQueryable<OUT> {
	public PythonCoGroup(int id, TypeInformation<OUT> typeInformation) {
this.typeInformation = typeInformation;
- streamer = new PythonStreamer(this, id);
+ streamer = new PythonStreamer(this, id, true);
}
/**
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonMapPartition.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonMapPartition.java
index 50b2cf4667e2c..6282210c4863d 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonMapPartition.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonMapPartition.java
@@ -14,6 +14,7 @@
import java.io.IOException;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.configuration.Configuration;
@@ -33,7 +34,7 @@ public class PythonMapPartition<IN, OUT> extends RichMapPartitionFunction<IN, OUT> implements ResultTypeQueryable<OUT> {
	public PythonMapPartition(int id, TypeInformation<OUT> typeInformation) {
this.typeInformation = typeInformation;
- streamer = new PythonStreamer(this, id);
+ streamer = new PythonStreamer(this, id, typeInformation instanceof PrimitiveArrayTypeInfo);
}
/**
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/IdentityGroupReduce.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/IdentityGroupReduce.java
similarity index 91%
rename from flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/IdentityGroupReduce.java
rename to flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/IdentityGroupReduce.java
index a4201532526d1..1e7bbe6e260b9 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/IdentityGroupReduce.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/IdentityGroupReduce.java
@@ -10,11 +10,14 @@
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
-package org.apache.flink.python.api.functions;
+package org.apache.flink.python.api.functions.util;
import org.apache.flink.util.Collector;
import org.apache.flink.api.common.functions.GroupReduceFunction;
+/*
+Utility function to group and sort data.
+*/
public class IdentityGroupReduce implements GroupReduceFunction {
@Override
public final void reduce(Iterable values, Collector out) throws Exception {
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/KeyDiscarder.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/KeyDiscarder.java
new file mode 100644
index 0000000000000..118bc9f7976a7
--- /dev/null
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/KeyDiscarder.java
@@ -0,0 +1,29 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
+ * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+ * License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.apache.flink.python.api.functions.util;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple2;
+
+/*
+Utility function to extract the value from a Key-Value Tuple.
+*/
+@ForwardedFields("f1->*")
+public class KeyDiscarder implements MapFunction<Tuple2<Tuple, byte[]>, byte[]> {
+	@Override
+	public byte[] map(Tuple2<Tuple, byte[]> value) throws Exception {
+ return value.f1;
+ }
+}
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/NestedKeyDiscarder.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/NestedKeyDiscarder.java
new file mode 100644
index 0000000000000..d59eb73c0b90f
--- /dev/null
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/NestedKeyDiscarder.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
+ * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+ * License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.apache.flink.python.api.functions.util;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple2;
+
+/*
+Utility function to extract values from 2 Key-Value Tuples after a DefaultJoin.
+*/
+@ForwardedFields("f0.f1->f0; f1.f1->f1")
+public class NestedKeyDiscarder<IN> implements MapFunction<IN, Tuple2<byte[], byte[]>> {
+	@Override
+	public Tuple2<byte[], byte[]> map(IN value) throws Exception {
+		Tuple2<Tuple2<Tuple, byte[]>, Tuple2<Tuple, byte[]>> x = (Tuple2<Tuple2<Tuple, byte[]>, Tuple2<Tuple, byte[]>>) value;
+ return new Tuple2<>(x.f0.f1, x.f1.f1);
+ }
+}
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/SerializerMap.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/SerializerMap.java
new file mode 100644
index 0000000000000..fba83f9dfe763
--- /dev/null
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/SerializerMap.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
+ * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+ * License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.apache.flink.python.api.functions.util;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.python.api.streaming.util.SerializationUtils;
+import org.apache.flink.python.api.streaming.util.SerializationUtils.Serializer;
+
+/*
+Utility function to serialize values, usually directly from data sources.
+*/
+public class SerializerMap<IN> implements MapFunction<IN, byte[]> {
+	private Serializer<IN> serializer = null;
+
+ @Override
+ public byte[] map(IN value) throws Exception {
+ if (serializer == null) {
+ serializer = SerializationUtils.getSerializer(value);
+ }
+ return serializer.serialize(value);
+ }
+}
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringDeserializerMap.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringDeserializerMap.java
new file mode 100644
index 0000000000000..d89fc41322eb0
--- /dev/null
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringDeserializerMap.java
@@ -0,0 +1,26 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
+ * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+ * License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.apache.flink.python.api.functions.util;
+
+import org.apache.flink.api.common.functions.MapFunction;
+
+/*
+Utility function to deserialize strings, used for non-CSV sinks.
+*/
+public class StringDeserializerMap implements MapFunction<byte[], String> {
+ @Override
+ public String map(byte[] value) throws Exception {
+ //discard type byte and size
+ return new String(value, 5, value.length - 5);
+ }
+}
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringTupleDeserializerMap.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringTupleDeserializerMap.java
new file mode 100644
index 0000000000000..b6d60e1e7a3b1
--- /dev/null
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringTupleDeserializerMap.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
+ * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+ * License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.apache.flink.python.api.functions.util;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.tuple.Tuple1;
+
+/*
+Utility function to deserialize strings, used for CSV sinks.
+*/
+public class StringTupleDeserializerMap implements MapFunction<byte[], Tuple1<String>> {
+	@Override
+	public Tuple1<String> map(byte[] value) throws Exception {
+ //5 = string type byte + string size
+ return new Tuple1<>(new String(value, 5, value.length - 5));
+ }
+}
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonReceiver.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonReceiver.java
index 9ed047cb6e349..3ee0fde616d19 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonReceiver.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonReceiver.java
@@ -20,24 +20,13 @@
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple2;
import static org.apache.flink.python.api.PythonPlanBinder.FLINK_TMP_DATA_DIR;
import static org.apache.flink.python.api.PythonPlanBinder.MAPPED_FILE_SIZE;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BOOLEAN;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTE;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTES;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_DOUBLE;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_FLOAT;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_INTEGER;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_LONG;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_NULL;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_SHORT;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_STRING;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_TUPLE;
-import org.apache.flink.python.api.types.CustomTypeWrapper;
import org.apache.flink.util.Collector;
/**
- * General-purpose class to read data from memory-mapped files.
+ * This class is used to read data from memory-mapped files.
*/
public class PythonReceiver implements Serializable {
private static final long serialVersionUID = -2474088929850009968L;
@@ -47,11 +36,18 @@ public class PythonReceiver implements Serializable {
private FileChannel inputChannel;
private MappedByteBuffer fileBuffer;
+ private final boolean readAsByteArray;
+
	private Deserializer<?> deserializer = null;
+ public PythonReceiver(boolean usesByteArray) {
+ readAsByteArray = usesByteArray;
+ }
+
//=====Setup========================================================================================================
public void open(String path) throws IOException {
setupMappedFile(path);
+ deserializer = readAsByteArray ? new ByteArrayDeserializer() : new TupleDeserializer();
}
private void setupMappedFile(String inputFilePath) throws FileNotFoundException, IOException {
@@ -81,6 +77,8 @@ private void closeMappedFile() throws IOException {
inputRAF.close();
}
+
+ //=====IO===========================================================================================================
/**
* Reads a buffer of the given size from the memory-mapped file, and collects all records contained. This method
* assumes that all values in the buffer are of the same type. This method does NOT take care of synchronization.
@@ -94,173 +92,46 @@ private void closeMappedFile() throws IOException {
public void collectBuffer(Collector c, int bufferSize) throws IOException {
fileBuffer.position(0);
- if (deserializer == null) {
- byte type = fileBuffer.get();
- deserializer = getDeserializer(type);
- }
while (fileBuffer.position() < bufferSize) {
c.collect(deserializer.deserialize());
}
}
//=====Deserializer=================================================================================================
-	private Deserializer<?> getDeserializer(byte type) {
- switch (type) {
- case TYPE_TUPLE:
- return new TupleDeserializer();
- case TYPE_BOOLEAN:
- return new BooleanDeserializer();
- case TYPE_BYTE:
- return new ByteDeserializer();
- case TYPE_BYTES:
- return new BytesDeserializer();
- case TYPE_SHORT:
- return new ShortDeserializer();
- case TYPE_INTEGER:
- return new IntDeserializer();
- case TYPE_LONG:
- return new LongDeserializer();
- case TYPE_STRING:
- return new StringDeserializer();
- case TYPE_FLOAT:
- return new FloatDeserializer();
- case TYPE_DOUBLE:
- return new DoubleDeserializer();
- case TYPE_NULL:
- return new NullDeserializer();
- default:
- return new CustomTypeDeserializer(type);
-
- }
- }
-
	private interface Deserializer<T> {
public T deserialize();
}
-	private class CustomTypeDeserializer implements Deserializer<CustomTypeWrapper> {
- private final byte type;
-
- public CustomTypeDeserializer(byte type) {
- this.type = type;
- }
-
- @Override
- public CustomTypeWrapper deserialize() {
- int size = fileBuffer.getInt();
- byte[] data = new byte[size];
- fileBuffer.get(data);
- return new CustomTypeWrapper(type, data);
- }
- }
-
-	private class BooleanDeserializer implements Deserializer<Boolean> {
- @Override
- public Boolean deserialize() {
- return fileBuffer.get() == 1;
- }
- }
-
-	private class ByteDeserializer implements Deserializer<Byte> {
- @Override
- public Byte deserialize() {
- return fileBuffer.get();
- }
- }
-
-	private class ShortDeserializer implements Deserializer<Short> {
- @Override
- public Short deserialize() {
- return fileBuffer.getShort();
- }
- }
-
-	private class IntDeserializer implements Deserializer<Integer> {
- @Override
- public Integer deserialize() {
- return fileBuffer.getInt();
- }
- }
-
-	private class LongDeserializer implements Deserializer<Long> {
- @Override
- public Long deserialize() {
- return fileBuffer.getLong();
- }
- }
-
-	private class FloatDeserializer implements Deserializer<Float> {
- @Override
- public Float deserialize() {
- return fileBuffer.getFloat();
- }
- }
-
-	private class DoubleDeserializer implements Deserializer<Double> {
- @Override
- public Double deserialize() {
- return fileBuffer.getDouble();
- }
- }
-
-	private class StringDeserializer implements Deserializer<String> {
- private int size;
-
- @Override
- public String deserialize() {
- size = fileBuffer.getInt();
- byte[] buffer = new byte[size];
- fileBuffer.get(buffer);
- return new String(buffer);
- }
- }
-
-	private class NullDeserializer implements Deserializer<Object> {
- @Override
- public Object deserialize() {
- return null;
- }
- }
-
-	private class BytesDeserializer implements Deserializer<byte[]> {
+	private class ByteArrayDeserializer implements Deserializer<byte[]> {
@Override
public byte[] deserialize() {
- int length = fileBuffer.getInt();
- byte[] result = new byte[length];
- fileBuffer.get(result);
- return result;
- }
-
- }
-
-	private class TupleDeserializer implements Deserializer<Tuple> {
-		Deserializer<?>[] deserializer = null;
- Tuple reuse;
-
- public TupleDeserializer() {
int size = fileBuffer.getInt();
- reuse = createTuple(size);
- deserializer = new Deserializer[size];
- for (int x = 0; x < deserializer.length; x++) {
- deserializer[x] = getDeserializer(fileBuffer.get());
- }
+ byte[] value = new byte[size];
+ fileBuffer.get(value);
+ return value;
}
+ }
+	private class TupleDeserializer implements Deserializer<Tuple2<Tuple, byte[]>> {
@Override
- public Tuple deserialize() {
- for (int x = 0; x < deserializer.length; x++) {
- reuse.setField(deserializer[x].deserialize(), x);
+		public Tuple2<Tuple, byte[]> deserialize() {
+ int keyTupleSize = fileBuffer.get();
+ Tuple keys = createTuple(keyTupleSize);
+ for (int x = 0; x < keyTupleSize; x++) {
+ byte[] data = new byte[fileBuffer.getInt()];
+ fileBuffer.get(data);
+ keys.setField(data, x);
}
- return reuse;
+ byte[] value = new byte[fileBuffer.getInt()];
+ fileBuffer.get(value);
+ return new Tuple2(keys, value);
}
}
public static Tuple createTuple(int size) {
try {
return Tuple.getTupleClass(size).newInstance();
- } catch (InstantiationException e) {
- throw new RuntimeException(e);
- } catch (IllegalAccessException e) {
+ } catch (InstantiationException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonSender.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonSender.java
index 1d17243c0ac7e..3cd5f4db16d36 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonSender.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonSender.java
@@ -20,30 +20,19 @@
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
-import java.util.ArrayList;
import java.util.Iterator;
-import java.util.List;
import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple2;
import static org.apache.flink.python.api.PythonPlanBinder.FLINK_TMP_DATA_DIR;
import static org.apache.flink.python.api.PythonPlanBinder.MAPPED_FILE_SIZE;
-import org.apache.flink.python.api.types.CustomTypeWrapper;
/**
* General-purpose class to write data to memory-mapped files.
*/
-public class PythonSender implements Serializable {
- public static final byte TYPE_TUPLE = (byte) 11;
- public static final byte TYPE_BOOLEAN = (byte) 10;
- public static final byte TYPE_BYTE = (byte) 9;
- public static final byte TYPE_SHORT = (byte) 8;
- public static final byte TYPE_INTEGER = (byte) 7;
- public static final byte TYPE_LONG = (byte) 6;
- public static final byte TYPE_DOUBLE = (byte) 4;
- public static final byte TYPE_FLOAT = (byte) 5;
- public static final byte TYPE_CHAR = (byte) 3;
- public static final byte TYPE_STRING = (byte) 2;
- public static final byte TYPE_BYTES = (byte) 1;
- public static final byte TYPE_NULL = (byte) 0;
+public class PythonSender implements Serializable {
+ public static final byte TYPE_ARRAY = (byte) 63;
+ public static final byte TYPE_KEY_VALUE = (byte) 62;
+ public static final byte TYPE_VALUE_VALUE = (byte) 61;
private File outputFile;
private RandomAccessFile outputRAF;
@@ -95,7 +84,7 @@ public void reset() {
fileBuffer.clear();
}
- //=====Serialization================================================================================================
+ //=====IO===========================================================================================================
/**
* Writes a single record to the memory-mapped file. This method does NOT take care of synchronization. The user
* must guarantee that the file may be written to before calling this method. This method essentially reserves the
@@ -173,75 +162,24 @@ public int sendBuffer(Iterator i, int group) throws IOException {
return size;
}
- private enum SupportedTypes {
- TUPLE, BOOLEAN, BYTE, BYTES, CHARACTER, SHORT, INTEGER, LONG, FLOAT, DOUBLE, STRING, OTHER, NULL, CUSTOMTYPEWRAPPER
- }
-
//=====Serializer===================================================================================================
- private Serializer getSerializer(Object value) throws IOException {
- String className = value.getClass().getSimpleName().toUpperCase();
- if (className.startsWith("TUPLE")) {
- className = "TUPLE";
+ private Serializer getSerializer(Object value) {
+ if (value instanceof byte[]) {
+ return new ArraySerializer();
}
- if (className.startsWith("BYTE[]")) {
- className = "BYTES";
+ if (((Tuple2) value).f0 instanceof byte[]) {
+ return new ValuePairSerializer();
}
- SupportedTypes type = SupportedTypes.valueOf(className);
- switch (type) {
- case TUPLE:
- fileBuffer.put(TYPE_TUPLE);
- fileBuffer.putInt(((Tuple) value).getArity());
- return new TupleSerializer((Tuple) value);
- case BOOLEAN:
- fileBuffer.put(TYPE_BOOLEAN);
- return new BooleanSerializer();
- case BYTE:
- fileBuffer.put(TYPE_BYTE);
- return new ByteSerializer();
- case BYTES:
- fileBuffer.put(TYPE_BYTES);
- return new BytesSerializer();
- case CHARACTER:
- fileBuffer.put(TYPE_CHAR);
- return new CharSerializer();
- case SHORT:
- fileBuffer.put(TYPE_SHORT);
- return new ShortSerializer();
- case INTEGER:
- fileBuffer.put(TYPE_INTEGER);
- return new IntSerializer();
- case LONG:
- fileBuffer.put(TYPE_LONG);
- return new LongSerializer();
- case STRING:
- fileBuffer.put(TYPE_STRING);
- return new StringSerializer();
- case FLOAT:
- fileBuffer.put(TYPE_FLOAT);
- return new FloatSerializer();
- case DOUBLE:
- fileBuffer.put(TYPE_DOUBLE);
- return new DoubleSerializer();
- case NULL:
- fileBuffer.put(TYPE_NULL);
- return new NullSerializer();
- case CUSTOMTYPEWRAPPER:
- fileBuffer.put(((CustomTypeWrapper) value).getType());
- return new CustomTypeSerializer();
- default:
- throw new IllegalArgumentException("Unknown Type encountered: " + type);
+ if (((Tuple2) value).f0 instanceof Tuple) {
+ return new KeyValuePairSerializer();
}
+ throw new IllegalArgumentException("This object can't be serialized: " + value.toString());
}
	private abstract class Serializer<T> {
protected ByteBuffer buffer;
- public Serializer(int capacity) {
- buffer = ByteBuffer.allocate(capacity);
- }
-
public ByteBuffer serialize(T value) {
- buffer.clear();
serializeInternal(value);
buffer.flip();
return buffer;
@@ -250,171 +188,39 @@ public ByteBuffer serialize(T value) {
public abstract void serializeInternal(T value);
}
-	private class CustomTypeSerializer extends Serializer<CustomTypeWrapper> {
- public CustomTypeSerializer() {
- super(0);
- }
+	private class ArraySerializer extends Serializer<byte[]> {
@Override
- public void serializeInternal(CustomTypeWrapper value) {
- byte[] bytes = value.getData();
- buffer = ByteBuffer.wrap(bytes);
- buffer.position(bytes.length);
- }
- }
-
-	private class ByteSerializer extends Serializer<Byte> {
- public ByteSerializer() {
- super(1);
- }
-
- @Override
- public void serializeInternal(Byte value) {
+ public void serializeInternal(byte[] value) {
+ buffer = ByteBuffer.allocate(value.length + 1);
+ buffer.put(TYPE_ARRAY);
buffer.put(value);
}
}
-	private class BooleanSerializer extends Serializer<Boolean> {
- public BooleanSerializer() {
- super(1);
- }
-
- @Override
- public void serializeInternal(Boolean value) {
- buffer.put(value ? (byte) 1 : (byte) 0);
- }
- }
-
-	private class CharSerializer extends Serializer<Character> {
- public CharSerializer() {
- super(4);
- }
-
- @Override
- public void serializeInternal(Character value) {
- buffer.put((value + "").getBytes());
- }
- }
-
-	private class ShortSerializer extends Serializer<Short> {
- public ShortSerializer() {
- super(2);
- }
-
- @Override
- public void serializeInternal(Short value) {
- buffer.putShort(value);
- }
- }
-
-	private class IntSerializer extends Serializer<Integer> {
- public IntSerializer() {
- super(4);
- }
-
- @Override
- public void serializeInternal(Integer value) {
- buffer.putInt(value);
- }
- }
-
-	private class LongSerializer extends Serializer<Long> {
- public LongSerializer() {
- super(8);
- }
-
- @Override
- public void serializeInternal(Long value) {
- buffer.putLong(value);
- }
- }
-
-	private class StringSerializer extends Serializer<String> {
- public StringSerializer() {
- super(0);
- }
-
- @Override
- public void serializeInternal(String value) {
- byte[] bytes = value.getBytes();
- buffer = ByteBuffer.allocate(bytes.length + 4);
- buffer.putInt(bytes.length);
- buffer.put(bytes);
- }
- }
-
-	private class FloatSerializer extends Serializer<Float> {
- public FloatSerializer() {
- super(4);
- }
-
- @Override
- public void serializeInternal(Float value) {
- buffer.putFloat(value);
- }
- }
-
-	private class DoubleSerializer extends Serializer<Double> {
- public DoubleSerializer() {
- super(8);
- }
-
+	private class ValuePairSerializer extends Serializer<Tuple2<byte[], byte[]>> {
@Override
- public void serializeInternal(Double value) {
- buffer.putDouble(value);
+		public void serializeInternal(Tuple2<byte[], byte[]> value) {
+ buffer = ByteBuffer.allocate(1 + value.f0.length + value.f1.length);
+ buffer.put(TYPE_VALUE_VALUE);
+ buffer.put(value.f0);
+ buffer.put(value.f1);
}
}
-	private class NullSerializer extends Serializer<Object> {
- public NullSerializer() {
- super(0);
- }
-
- @Override
- public void serializeInternal(Object value) {
- }
- }
-
-	private class BytesSerializer extends Serializer<byte[]> {
- public BytesSerializer() {
- super(0);
- }
-
- @Override
- public void serializeInternal(byte[] value) {
- buffer = ByteBuffer.allocate(4 + value.length);
- buffer.putInt(value.length);
- buffer.put(value);
- }
- }
-
-	private class TupleSerializer extends Serializer<Tuple> {
- private final Serializer[] serializer;
-		private final List<ByteBuffer> buffers;
-
- public TupleSerializer(Tuple value) throws IOException {
- super(0);
- serializer = new Serializer[value.getArity()];
- buffers = new ArrayList();
- for (int x = 0; x < serializer.length; x++) {
- serializer[x] = getSerializer(value.getField(x));
- }
- }
-
+	private class KeyValuePairSerializer extends Serializer<Tuple2<Tuple, byte[]>> {
@Override
- public void serializeInternal(Tuple value) {
- int length = 0;
- for (int x = 0; x < serializer.length; x++) {
- serializer[x].buffer.clear();
- serializer[x].serializeInternal(value.getField(x));
- length += serializer[x].buffer.position();
- buffers.add(serializer[x].buffer);
+ public void serializeInternal(Tuple2<Tuple, byte[]> value) {
+ int keySize = 0;
+ for (int x = 0; x < value.f0.getArity(); x++) {
+ keySize += ((byte[]) value.f0.getField(x)).length;
}
- buffer = ByteBuffer.allocate(length);
- for (ByteBuffer b : buffers) {
- b.flip();
- buffer.put(b);
+ buffer = ByteBuffer.allocate(5 + keySize + value.f1.length);
+ buffer.put(TYPE_KEY_VALUE);
+ buffer.put((byte) value.f0.getArity());
+ for (int x = 0; x < value.f0.getArity(); x++) {
+ buffer.put((byte[]) value.f0.getField(x));
}
- buffers.clear();
+ buffer.put(value.f1);
}
}
}
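
For reference, a minimal standalone sketch of the three frame layouts the new PythonSender serializers emit. The class and method names below (FrameLayoutSketch, arrayFrame, valuePairFrame, keyValueFrame) are illustrative only; the type markers mirror the values this patch assigns in Constants.py (0x3F, 0x3E, 0x3D).

import java.nio.ByteBuffer;

public class FrameLayoutSketch {
    // Type markers as defined by this patch: plain array, keyed record, join/cross pair.
    static final byte TYPE_ARRAY = (byte) 0x3F;
    static final byte TYPE_KEY_VALUE = (byte) 0x3E;
    static final byte TYPE_VALUE_VALUE = (byte) 0x3D;

    // [TYPE_ARRAY][payload] -- the plain byte[] case (ArraySerializer).
    static byte[] arrayFrame(byte[] payload) {
        return ByteBuffer.allocate(1 + payload.length)
                .put(TYPE_ARRAY).put(payload).array();
    }

    // [TYPE_VALUE_VALUE][left][right] -- join/cross pairs (ValuePairSerializer).
    static byte[] valuePairFrame(byte[] left, byte[] right) {
        return ByteBuffer.allocate(1 + left.length + right.length)
                .put(TYPE_VALUE_VALUE).put(left).put(right).array();
    }

    // [TYPE_KEY_VALUE][key arity][key fields...][value] -- keyed operations (KeyValuePairSerializer).
    static byte[] keyValueFrame(byte[][] keys, byte[] value) {
        int keySize = 0;
        for (byte[] key : keys) {
            keySize += key.length;
        }
        ByteBuffer buffer = ByteBuffer.allocate(2 + keySize + value.length);
        buffer.put(TYPE_KEY_VALUE).put((byte) keys.length);
        for (byte[] key : keys) {
            buffer.put(key);
        }
        return buffer.put(value).array();
    }
}
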
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonStreamer.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonStreamer.java
index 1e369629c09fe..556805178618a 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonStreamer.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonStreamer.java
@@ -34,6 +34,8 @@
import static org.apache.flink.python.api.PythonPlanBinder.FLINK_TMP_DATA_DIR;
import static org.apache.flink.python.api.PythonPlanBinder.PLANBINDER_CONFIG_BCVAR_COUNT;
import static org.apache.flink.python.api.PythonPlanBinder.PLANBINDER_CONFIG_BCVAR_NAME_PREFIX;
+import org.apache.flink.python.api.streaming.util.SerializationUtils.IntSerializer;
+import org.apache.flink.python.api.streaming.util.SerializationUtils.StringSerializer;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -58,8 +60,6 @@ public class PythonStreamer implements Serializable {
private String inputFilePath;
private String outputFilePath;
- private final byte[] buffer = new byte[4];
-
private Process process;
private Thread shutdownThread;
protected ServerSocket server;
@@ -75,13 +75,13 @@ public class PythonStreamer implements Serializable {
protected final AbstractRichFunction function;
- public PythonStreamer(AbstractRichFunction function, int id) {
+ public PythonStreamer(AbstractRichFunction function, int id, boolean usesByteArray) {
this.id = id;
this.usePython3 = PythonPlanBinder.usePython3;
this.debug = DEBUG;
planArguments = PythonPlanBinder.arguments.toString();
sender = new PythonSender();
- receiver = new PythonReceiver();
+ receiver = new PythonReceiver(usesByteArray);
this.function = function;
}
@@ -167,9 +167,9 @@ public void run() {
*/
public void close() throws IOException {
try {
- socket.close();
- sender.close();
- receiver.close();
+ socket.close();
+ sender.close();
+ receiver.close();
} catch (Exception e) {
LOG.error("Exception occurred while closing Streamer. :" + e.getMessage());
}
@@ -202,31 +202,18 @@ private void destroyProcess() throws IOException {
}
}
}
-
- private void sendWriteNotification(int size, boolean hasNext) throws IOException {
- byte[] tmp = new byte[5];
- putInt(tmp, 0, size);
- tmp[4] = hasNext ? 0 : SIGNAL_LAST;
- out.write(tmp, 0, 5);
+
+ private void sendWriteNotification(int size, boolean hasNext) throws IOException {
+ out.writeInt(size);
+ out.writeByte(hasNext ? 0 : SIGNAL_LAST);
out.flush();
}
private void sendReadConfirmation() throws IOException {
- out.write(new byte[1], 0, 1);
+ out.writeByte(1);
out.flush();
}
- private void checkForError() {
- if (getInt(buffer, 0) == -2) {
- try { //wait before terminating to ensure that the complete error message is printed
- Thread.sleep(2000);
- } catch (InterruptedException ex) {
- }
- throw new RuntimeException(
- "External process for task " + function.getRuntimeContext().getTaskName() + " terminated prematurely." + msg);
- }
- }
-
/**
* Sends all broadcast-variables encoded in the configuration to the external process.
*
@@ -243,26 +230,19 @@ public final void sendBroadCastVariables(Configuration config) throws IOExceptio
names[x] = config.getString(PLANBINDER_CONFIG_BCVAR_NAME_PREFIX + x, null);
}
- in.readFully(buffer, 0, 4);
- checkForError();
- int size = sender.sendRecord(broadcastCount);
- sendWriteNotification(size, false);
+ out.write(new IntSerializer().serializeWithoutTypeInfo(broadcastCount));
+ StringSerializer stringSerializer = new StringSerializer();
for (String name : names) {
Iterator bcv = function.getRuntimeContext().getBroadcastVariable(name).iterator();
- in.readFully(buffer, 0, 4);
- checkForError();
- size = sender.sendRecord(name);
- sendWriteNotification(size, false);
+ out.write(stringSerializer.serializeWithoutTypeInfo(name));
- while (bcv.hasNext() || sender.hasRemaining(0)) {
- in.readFully(buffer, 0, 4);
- checkForError();
- size = sender.sendBuffer(bcv, 0);
- sendWriteNotification(size, bcv.hasNext() || sender.hasRemaining(0));
+ while (bcv.hasNext()) {
+ out.writeByte(1);
+ out.write((byte[]) bcv.next());
}
- sender.reset();
+ out.writeByte(0);
}
} catch (SocketTimeoutException ste) {
throw new RuntimeException("External process for task " + function.getRuntimeContext().getTaskName() + " stopped responding." + msg);
@@ -281,8 +261,7 @@ public final void streamBufferWithoutGroups(Iterator i, Collector c) throws IOEx
int size;
if (i.hasNext()) {
while (true) {
- in.readFully(buffer, 0, 4);
- int sig = getInt(buffer, 0);
+ int sig = in.readInt();
switch (sig) {
case SIGNAL_BUFFER_REQUEST:
if (i.hasNext() || sender.hasRemaining(0)) {
@@ -326,8 +305,7 @@ public final void streamBufferWithGroups(Iterator i1, Iterator i2, Collector c)
int size;
if (i1.hasNext() || i2.hasNext()) {
while (true) {
- in.readFully(buffer, 0, 4);
- int sig = getInt(buffer, 0);
+ int sig = in.readInt();
switch (sig) {
case SIGNAL_BUFFER_REQUEST_G0:
if (i1.hasNext() || sender.hasRemaining(0)) {
@@ -361,16 +339,4 @@ public final void streamBufferWithGroups(Iterator i1, Iterator i2, Collector c)
throw new RuntimeException("External process for task " + function.getRuntimeContext().getTaskName() + " stopped responding." + msg);
}
}
-
- protected final static int getInt(byte[] array, int offset) {
- return (array[offset] << 24) | (array[offset + 1] & 0xff) << 16 | (array[offset + 2] & 0xff) << 8 | (array[offset + 3] & 0xff);
- }
-
- protected final static void putInt(byte[] array, int offset, int value) {
- array[offset] = (byte) (value >> 24);
- array[offset + 1] = (byte) (value >> 16);
- array[offset + 2] = (byte) (value >> 8);
- array[offset + 3] = (byte) (value);
- }
-
}
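
With the read-confirmation handshake above removed, broadcast variables are written straight to the TCP socket. A minimal sketch of the resulting framing follows; the names BroadcastFramingSketch and writeBroadcast are illustrative and not part of the patch, and each record is assumed to already carry its leading type info, since the Python side only reads it for the first record of a variable.

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

public class BroadcastFramingSketch {
    // [int count] then, per variable: [int name length][name bytes] and, per record,
    // a 0x01 marker followed by the serialized record; a single 0x00 ends the variable.
    static void writeBroadcast(DataOutputStream out, String name, List<byte[]> records) throws IOException {
        out.writeInt(1);                 // number of broadcast variables (one in this sketch)
        byte[] nameBytes = name.getBytes();
        out.writeInt(nameBytes.length);  // same layout as StringSerializer.serializeWithoutTypeInfo
        out.write(nameBytes);
        for (byte[] record : records) {
            out.writeByte(1);            // "another record follows"
            out.write(record);
        }
        out.writeByte(0);                // end of this variable
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        // One int record: type byte 32 (TYPE_INTEGER) followed by the big-endian value 7.
        writeBroadcast(new DataOutputStream(bos), "lookup", Arrays.asList(new byte[]{32, 0, 0, 0, 7}));
        System.out.println(bos.size() + " bytes written");
    }
}
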
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanReceiver.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanReceiver.java
index ed02ce4e4468d..5095e3bd32b3a 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanReceiver.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanReceiver.java
@@ -18,17 +18,15 @@
import java.io.Serializable;
import org.apache.flink.api.java.tuple.Tuple;
import static org.apache.flink.python.api.streaming.data.PythonReceiver.createTuple;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BOOLEAN;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTE;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTES;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_DOUBLE;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_FLOAT;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_INTEGER;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_LONG;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_NULL;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_SHORT;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_STRING;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_TUPLE;
+import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_BOOLEAN;
+import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_BYTE;
+import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_BYTES;
+import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_DOUBLE;
+import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_FLOAT;
+import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_INTEGER;
+import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_LONG;
+import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_NULL;
+import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_STRING;
import org.apache.flink.python.api.types.CustomTypeWrapper;
/**
@@ -46,62 +44,157 @@ public Object getRecord() throws IOException {
}
public Object getRecord(boolean normalized) throws IOException {
- return receiveField(normalized);
+ return getDeserializer().deserialize(normalized);
}
- private Object receiveField(boolean normalized) throws IOException {
+ private Deserializer getDeserializer() throws IOException {
byte type = (byte) input.readByte();
- switch (type) {
- case TYPE_TUPLE:
- int tupleSize = input.readByte();
- Tuple tuple = createTuple(tupleSize);
- for (int x = 0; x < tupleSize; x++) {
- tuple.setField(receiveField(normalized), x);
+ if (type > 0 && type < 26) {
+ Deserializer[] d = new Deserializer[type];
+ for (int x = 0; x < d.length; x++) {
+ d[x] = getDeserializer();
}
- return tuple;
+ return new TupleDeserializer(d);
+ }
+ switch (type) {
case TYPE_BOOLEAN:
- return input.readByte() == 1;
+ return new BooleanDeserializer();
case TYPE_BYTE:
- return (byte) input.readByte();
- case TYPE_SHORT:
- if (normalized) {
- return (int) input.readShort();
- } else {
- return input.readShort();
- }
+ return new ByteDeserializer();
case TYPE_INTEGER:
- return input.readInt();
+ return new IntDeserializer();
case TYPE_LONG:
- if (normalized) {
- return new Long(input.readLong()).intValue();
- } else {
- return input.readLong();
- }
+ return new LongDeserializer();
case TYPE_FLOAT:
- if (normalized) {
- return (double) input.readFloat();
- } else {
- return input.readFloat();
- }
+ return new FloatDeserializer();
case TYPE_DOUBLE:
- return input.readDouble();
+ return new DoubleDeserializer();
case TYPE_STRING:
- int stringSize = input.readInt();
- byte[] string = new byte[stringSize];
- input.readFully(string);
- return new String(string);
+ return new StringDeserializer();
case TYPE_BYTES:
- int bytessize = input.readInt();
- byte[] bytes = new byte[bytessize];
- input.readFully(bytes);
- return bytes;
+ return new BytesDeserializer();
case TYPE_NULL:
- return null;
+ return new NullDeserializer();
default:
- int size = input.readInt();
- byte[] data = new byte[size];
- input.readFully(data);
- return new CustomTypeWrapper(type, data);
+ return new CustomTypeDeserializer(type);
+ }
+ }
+
+ private abstract class Deserializer<T> {
+ public T deserialize() throws IOException {
+ return deserialize(false);
+ }
+
+ public abstract T deserialize(boolean normalized) throws IOException;
+ }
+
+ private class TupleDeserializer extends Deserializer<Tuple> {
+ Deserializer[] deserializer;
+
+ public TupleDeserializer(Deserializer[] deserializer) {
+ this.deserializer = deserializer;
+ }
+
+ @Override
+ public Tuple deserialize(boolean normalized) throws IOException {
+ Tuple result = createTuple(deserializer.length);
+ for (int x = 0; x < result.getArity(); x++) {
+ result.setField(deserializer[x].deserialize(normalized), x);
+ }
+ return result;
+ }
+ }
+
+ private class CustomTypeDeserializer extends Deserializer<CustomTypeWrapper> {
+ private final byte type;
+
+ public CustomTypeDeserializer(byte type) {
+ this.type = type;
+ }
+
+ @Override
+ public CustomTypeWrapper deserialize(boolean normalized) throws IOException {
+ int size = input.readInt();
+ byte[] data = new byte[size];
+ input.read(data);
+ return new CustomTypeWrapper(type, data);
+ }
+ }
+
+ private class BooleanDeserializer extends Deserializer<Boolean> {
+ @Override
+ public Boolean deserialize(boolean normalized) throws IOException {
+ return input.readBoolean();
+ }
+ }
+
+ private class ByteDeserializer extends Deserializer<Byte> {
+ @Override
+ public Byte deserialize(boolean normalized) throws IOException {
+ return input.readByte();
+ }
+ }
+
+ private class IntDeserializer extends Deserializer<Integer> {
+ @Override
+ public Integer deserialize(boolean normalized) throws IOException {
+ return input.readInt();
+ }
+ }
+
+ private class LongDeserializer extends Deserializer<Object> {
+ @Override
+ public Object deserialize(boolean normalized) throws IOException {
+ if (normalized) {
+ return new Long(input.readLong()).intValue();
+ } else {
+ return input.readLong();
+ }
+ }
+ }
+
+ private class FloatDeserializer extends Deserializer<Object> {
+ @Override
+ public Object deserialize(boolean normalized) throws IOException {
+ if (normalized) {
+ return (double) input.readFloat();
+ } else {
+ return input.readFloat();
+ }
+ }
+ }
+
+ private class DoubleDeserializer extends Deserializer<Double> {
+ @Override
+ public Double deserialize(boolean normalized) throws IOException {
+ return input.readDouble();
+ }
+ }
+
+ private class StringDeserializer extends Deserializer<String> {
+ @Override
+ public String deserialize(boolean normalized) throws IOException {
+ int size = input.readInt();
+ byte[] buffer = new byte[size];
+ input.read(buffer);
+ return new String(buffer);
+ }
+ }
+
+ private class NullDeserializer extends Deserializer<Object> {
+ @Override
+ public Object deserialize(boolean normalized) throws IOException {
+ return null;
+ }
+ }
+
+ private class BytesDeserializer extends Deserializer<byte[]> {
+ @Override
+ public byte[] deserialize(boolean normalized) throws IOException {
+ int size = input.readInt();
+ byte[] buffer = new byte[size];
+ input.read(buffer);
+ return buffer;
}
}
}
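
A minimal sketch of how the leading type bytes read above compose deserializers: a value between 1 and 25 encodes a tuple whose arity is the byte itself, with the type info of each field following recursively, while any other value selects a scalar deserializer. The TypeNode class is an illustrative stand-in for the Deserializer hierarchy and not part of the patch.

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class TypeInfoSketch {
    static class TypeNode {
        final byte type;               // raw type byte, e.g. 32 = TYPE_INTEGER, 28 = TYPE_STRING
        final List<TypeNode> fields;   // non-empty only for tuples

        TypeNode(byte type, List<TypeNode> fields) {
            this.type = type;
            this.fields = fields;
        }
    }

    static TypeNode readTypeInfo(DataInputStream in) throws IOException {
        byte type = in.readByte();
        if (type > 0 && type < 26) {            // tuple: the byte itself is the arity
            List<TypeNode> fields = new ArrayList<TypeNode>();
            for (int x = 0; x < type; x++) {
                fields.add(readTypeInfo(in));   // nested tuples recurse naturally
            }
            return new TypeNode(type, fields);
        }
        return new TypeNode(type, new ArrayList<TypeNode>());
    }

    public static void main(String[] args) throws IOException {
        // Type info for a Tuple2 of (Integer, String): arity 2, then 32 and 28.
        byte[] typeInfo = {2, 32, 28};
        TypeNode root = readTypeInfo(new DataInputStream(new ByteArrayInputStream(typeInfo)));
        System.out.println("tuple arity: " + root.fields.size());
    }
}
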
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanSender.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanSender.java
index 16a1eba8d5a09..7b6b68ad28bd0 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanSender.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanSender.java
@@ -16,19 +16,7 @@
import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
-import org.apache.flink.api.java.tuple.Tuple;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BOOLEAN;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTE;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTES;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_DOUBLE;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_FLOAT;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_INTEGER;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_LONG;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_NULL;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_SHORT;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_STRING;
-import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_TUPLE;
-import org.apache.flink.python.api.types.CustomTypeWrapper;
+import org.apache.flink.python.api.streaming.util.SerializationUtils;
/**
* Instances of this class can be used to send data to the plan process.
@@ -41,76 +29,7 @@ public PythonPlanSender(OutputStream output) {
}
public void sendRecord(Object record) throws IOException {
- String className = record.getClass().getSimpleName().toUpperCase();
- if (className.startsWith("TUPLE")) {
- className = "TUPLE";
- }
- if (className.startsWith("BYTE[]")) {
- className = "BYTES";
- }
- SupportedTypes type = SupportedTypes.valueOf(className);
- switch (type) {
- case TUPLE:
- output.write(TYPE_TUPLE);
- int arity = ((Tuple) record).getArity();
- output.writeInt(arity);
- for (int x = 0; x < arity; x++) {
- sendRecord(((Tuple) record).getField(x));
- }
- return;
- case BOOLEAN:
- output.write(TYPE_BOOLEAN);
- output.write(((Boolean) record) ? (byte) 1 : (byte) 0);
- return;
- case BYTE:
- output.write(TYPE_BYTE);
- output.write((Byte) record);
- return;
- case BYTES:
- output.write(TYPE_BYTES);
- output.write((byte[]) record, 0, ((byte[]) record).length);
- return;
- case CHARACTER:
- output.write(TYPE_STRING);
- output.writeChars(((Character) record) + "");
- return;
- case SHORT:
- output.write(TYPE_SHORT);
- output.writeShort((Short) record);
- return;
- case INTEGER:
- output.write(TYPE_INTEGER);
- output.writeInt((Integer) record);
- return;
- case LONG:
- output.write(TYPE_LONG);
- output.writeLong((Long) record);
- return;
- case STRING:
- output.write(TYPE_STRING);
- output.writeBytes((String) record);
- return;
- case FLOAT:
- output.write(TYPE_FLOAT);
- output.writeFloat((Float) record);
- return;
- case DOUBLE:
- output.write(TYPE_DOUBLE);
- output.writeDouble((Double) record);
- return;
- case NULL:
- output.write(TYPE_NULL);
- return;
- case CUSTOMTYPEWRAPPER:
- output.write(((CustomTypeWrapper) record).getType());
- output.write(((CustomTypeWrapper) record).getData());
- return;
- default:
- throw new IllegalArgumentException("Unknown Type encountered: " + type);
- }
- }
-
- private enum SupportedTypes {
- TUPLE, BOOLEAN, BYTE, BYTES, CHARACTER, SHORT, INTEGER, LONG, FLOAT, DOUBLE, STRING, OTHER, NULL, CUSTOMTYPEWRAPPER
+ byte[] data = SerializationUtils.getSerializer(record).serialize(record);
+ output.write(data);
}
}
\ No newline at end of file
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/util/SerializationUtils.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/util/SerializationUtils.java
new file mode 100644
index 0000000000000..6c83a61c7b7f3
--- /dev/null
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/util/SerializationUtils.java
@@ -0,0 +1,283 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
+ * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+ * License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.apache.flink.python.api.streaming.util;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.python.api.types.CustomTypeWrapper;
+
+public class SerializationUtils {
+ public static final byte TYPE_BOOLEAN = (byte) 34;
+ public static final byte TYPE_BYTE = (byte) 33;
+ public static final byte TYPE_INTEGER = (byte) 32;
+ public static final byte TYPE_LONG = (byte) 31;
+ public static final byte TYPE_DOUBLE = (byte) 30;
+ public static final byte TYPE_FLOAT = (byte) 29;
+ public static final byte TYPE_STRING = (byte) 28;
+ public static final byte TYPE_BYTES = (byte) 27;
+ public static final byte TYPE_NULL = (byte) 26;
+
+ private enum SupportedTypes {
+ TUPLE, BOOLEAN, BYTE, BYTES, INTEGER, LONG, FLOAT, DOUBLE, STRING, NULL, CUSTOMTYPEWRAPPER
+ }
+
+ public static Serializer getSerializer(Object value) {
+ String className = value.getClass().getSimpleName().toUpperCase();
+ if (className.startsWith("TUPLE")) {
+ className = "TUPLE";
+ }
+ if (className.startsWith("BYTE[]")) {
+ className = "BYTES";
+ }
+ SupportedTypes type = SupportedTypes.valueOf(className);
+ switch (type) {
+ case TUPLE:
+ return new TupleSerializer((Tuple) value);
+ case BOOLEAN:
+ return new BooleanSerializer();
+ case BYTE:
+ return new ByteSerializer();
+ case BYTES:
+ return new BytesSerializer();
+ case INTEGER:
+ return new IntSerializer();
+ case LONG:
+ return new LongSerializer();
+ case STRING:
+ return new StringSerializer();
+ case FLOAT:
+ return new FloatSerializer();
+ case DOUBLE:
+ return new DoubleSerializer();
+ case NULL:
+ return new NullSerializer();
+ case CUSTOMTYPEWRAPPER:
+ return new CustomTypeWrapperSerializer((CustomTypeWrapper) value);
+ default:
+ throw new IllegalArgumentException("Unsupported Type encountered: " + type);
+ }
+ }
+
+ public static abstract class Serializer<IN> {
+ private byte[] typeInfo = null;
+
+ public byte[] serialize(IN value) {
+ if (typeInfo == null) {
+ typeInfo = new byte[getTypeInfoSize()];
+ ByteBuffer typeBuffer = ByteBuffer.wrap(typeInfo);
+ putTypeInfo(typeBuffer);
+ }
+ byte[] bytes = serializeWithoutTypeInfo(value);
+ byte[] total = new byte[typeInfo.length + bytes.length];
+ ByteBuffer.wrap(total).put(typeInfo).put(bytes);
+ return total;
+ }
+
+ public abstract byte[] serializeWithoutTypeInfo(IN value);
+
+ protected abstract void putTypeInfo(ByteBuffer buffer);
+
+ protected int getTypeInfoSize() {
+ return 1;
+ }
+ }
+
+ public static class CustomTypeWrapperSerializer extends Serializer<CustomTypeWrapper> {
+ private final byte type;
+
+ public CustomTypeWrapperSerializer(CustomTypeWrapper value) {
+ this.type = value.getType();
+ }
+
+ @Override
+ public byte[] serializeWithoutTypeInfo(CustomTypeWrapper value) {
+ byte[] result = new byte[4 + value.getData().length];
+ ByteBuffer.wrap(result).putInt(value.getData().length).put(value.getData());
+ return result;
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(type);
+ }
+ }
+
+ public static class ByteSerializer extends Serializer<Byte> {
+ @Override
+ public byte[] serializeWithoutTypeInfo(Byte value) {
+ return new byte[]{value};
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(TYPE_BYTE);
+ }
+ }
+
+ public static class BooleanSerializer extends Serializer<Boolean> {
+ @Override
+ public byte[] serializeWithoutTypeInfo(Boolean value) {
+ return new byte[]{value ? (byte) 1 : (byte) 0};
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(TYPE_BOOLEAN);
+ }
+ }
+
+ public static class IntSerializer extends Serializer<Integer> {
+ @Override
+ public byte[] serializeWithoutTypeInfo(Integer value) {
+ byte[] data = new byte[4];
+ ByteBuffer.wrap(data).putInt(value);
+ return data;
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(TYPE_INTEGER);
+ }
+ }
+
+ public static class LongSerializer extends Serializer<Long> {
+ @Override
+ public byte[] serializeWithoutTypeInfo(Long value) {
+ byte[] data = new byte[8];
+ ByteBuffer.wrap(data).putLong(value);
+ return data;
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(TYPE_LONG);
+ }
+ }
+
+ public static class StringSerializer extends Serializer<String> {
+ @Override
+ public byte[] serializeWithoutTypeInfo(String value) {
+ byte[] string = value.getBytes();
+ byte[] data = new byte[4 + string.length];
+ ByteBuffer.wrap(data).putInt(string.length).put(string);
+ return data;
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(TYPE_STRING);
+ }
+ }
+
+ public static class FloatSerializer extends Serializer<Float> {
+ @Override
+ public byte[] serializeWithoutTypeInfo(Float value) {
+ byte[] data = new byte[4];
+ ByteBuffer.wrap(data).putFloat(value);
+ return data;
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(TYPE_FLOAT);
+ }
+ }
+
+ public static class DoubleSerializer extends Serializer<Double> {
+ @Override
+ public byte[] serializeWithoutTypeInfo(Double value) {
+ byte[] data = new byte[8];
+ ByteBuffer.wrap(data).putDouble(value);
+ return data;
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(TYPE_DOUBLE);
+ }
+ }
+
+ public static class NullSerializer extends Serializer<Object> {
+ @Override
+ public byte[] serializeWithoutTypeInfo(Object value) {
+ return new byte[0];
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(TYPE_NULL);
+ }
+ }
+
+ public static class BytesSerializer extends Serializer<byte[]> {
+ @Override
+ public byte[] serializeWithoutTypeInfo(byte[] value) {
+ byte[] data = new byte[4 + value.length];
+ ByteBuffer.wrap(data).putInt(value.length).put(value);
+ return data;
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put(TYPE_BYTES);
+ }
+ }
+
+ public static class TupleSerializer extends Serializer<Tuple> {
+ private final Serializer[] serializer;
+
+ public TupleSerializer(Tuple value) {
+ serializer = new Serializer[value.getArity()];
+ for (int x = 0; x < serializer.length; x++) {
+ serializer[x] = getSerializer(value.getField(x));
+ }
+ }
+
+ @Override
+ public byte[] serializeWithoutTypeInfo(Tuple value) {
+ ArrayList<byte[]> bits = new ArrayList<byte[]>();
+
+ int totalSize = 0;
+ for (int x = 0; x < serializer.length; x++) {
+ byte[] bit = serializer[x].serializeWithoutTypeInfo(value.getField(x));
+ bits.add(bit);
+ totalSize += bit.length;
+ }
+ int pointer = 0;
+ byte[] data = new byte[totalSize];
+ for (byte[] bit : bits) {
+ System.arraycopy(bit, 0, data, pointer, bit.length);
+ pointer += bit.length;
+ }
+ return data;
+ }
+
+ @Override
+ public void putTypeInfo(ByteBuffer buffer) {
+ buffer.put((byte) serializer.length);
+ for (Serializer s : serializer) {
+ s.putTypeInfo(buffer);
+ }
+ }
+
+ @Override
+ public int getTypeInfoSize() {
+ int size = 1;
+ for (Serializer s : serializer) {
+ size += s.getTypeInfoSize();
+ }
+ return size;
+ }
+ }
+}
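
A short usage sketch for the new utility class, assuming it is on the classpath as added by this patch: serialize() prefixes the type info (tuple arity plus one type byte per field) before the payload, while serializeWithoutTypeInfo() emits only the payload, which is how the streamer sends the broadcast-variable count and names.

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.python.api.streaming.util.SerializationUtils;

public class SerializationUtilsUsageSketch {
    public static void main(String[] args) {
        Tuple2<Integer, String> record = new Tuple2<Integer, String>(42, "hello");

        // Type info (arity 2, TYPE_INTEGER, TYPE_STRING) followed by the field payloads.
        byte[] withTypeInfo = SerializationUtils.getSerializer(record).serialize(record);

        // Payload only: a 4-byte big-endian integer, no type byte.
        byte[] payloadOnly = new SerializationUtils.IntSerializer().serializeWithoutTypeInfo(42);

        System.out.println(withTypeInfo.length + " / " + payloadOnly.length);
    }
}
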
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py
index b5674b98cb7f8..4c0db328ccb25 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py
@@ -16,11 +16,11 @@
# limitations under the License.
################################################################################
from struct import pack
-import sys
-
from flink.connection.Constants import Types
-from flink.plan.Constants import _Dummy
+
+#=====Compatibility====================================================================================================
+import sys
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
@@ -30,57 +30,128 @@
stringtype = str
+#=====Collector========================================================================================================
class Collector(object):
- def __init__(self, con, env):
+ def __init__(self, con, env, info):
self._connection = con
self._serializer = None
self._env = env
+ self._as_array = isinstance(info.types, bytearray)
def _close(self):
self._connection.send_end_signal()
def collect(self, value):
- self._serializer = _get_serializer(self._connection.write, value, self._env._types)
+ self._serializer = ArraySerializer(value, self._env._types) if self._as_array else KeyValuePairSerializer(value, self._env._types)
self.collect = self._collect
self.collect(value)
def _collect(self, value):
- self._connection.write(self._serializer.serialize(value))
+ serialized_value = self._serializer.serialize(value)
+ self._connection.write(serialized_value)
+
+
+class PlanCollector(object):
+ def __init__(self, con, env):
+ self._connection = con
+ self._env = env
+
+ def _close(self):
+ self._connection.send_end_signal()
+
+ def collect(self, value):
+ type = _get_type_info(value, self._env._types)
+ serializer = _get_serializer(value, self._env._types)
+ self._connection.write(b"".join([type, serializer.serialize(value)]))
+
+
+#=====Serializer=======================================================================================================
+class Serializer(object):
+ def serialize(self, value):
+ pass
+
+
+class KeyValuePairSerializer(Serializer):
+ def __init__(self, value, custom_types):
+ self._typeK = [_get_type_info(key, custom_types) for key in value[0]]
+ self._typeV = _get_type_info(value[1], custom_types)
+ self._typeK_length = [len(type) for type in self._typeK]
+ self._typeV_length = len(self._typeV)
+ self._serializerK = [_get_serializer(key, custom_types) for key in value[0]]
+ self._serializerV = _get_serializer(value[1], custom_types)
+
+ def serialize(self, value):
+ bits = [pack(">i", len(value[0]))[3:4]]
+ for i in range(len(value[0])):
+ x = self._serializerK[i].serialize(value[0][i])
+ bits.append(pack(">i", len(x) + self._typeK_length[i]))
+ bits.append(self._typeK[i])
+ bits.append(x)
+ v = self._serializerV.serialize(value[1])
+ bits.append(pack(">i", len(v) + self._typeV_length))
+ bits.append(self._typeV)
+ bits.append(v)
+ return b"".join(bits)
+
+class ArraySerializer(Serializer):
+ def __init__(self, value, custom_types):
+ self._type = _get_type_info(value, custom_types)
+ self._type_length = len(self._type)
+ self._serializer = _get_serializer(value, custom_types)
-def _get_serializer(write, value, custom_types):
+ def serialize(self, value):
+ serialized_value = self._serializer.serialize(value)
+ return b"".join([pack(">i", len(serialized_value) + self._type_length), self._type, serialized_value])
+
+
+def _get_type_info(value, custom_types):
if isinstance(value, (list, tuple)):
- write(Types.TYPE_TUPLE)
- write(pack(">I", len(value)))
- return TupleSerializer(write, value, custom_types)
+ return b"".join([pack(">i", len(value))[3:4], b"".join([_get_type_info(field, custom_types) for field in value])])
+ elif value is None:
+ return Types.TYPE_NULL
+ elif isinstance(value, stringtype):
+ return Types.TYPE_STRING
+ elif isinstance(value, bool):
+ return Types.TYPE_BOOLEAN
+ elif isinstance(value, int) or PY2 and isinstance(value, long):
+ return Types.TYPE_LONG
+ elif isinstance(value, bytearray):
+ return Types.TYPE_BYTES
+ elif isinstance(value, float):
+ return Types.TYPE_DOUBLE
+ else:
+ for entry in custom_types:
+ if isinstance(value, entry[1]):
+ return entry[0]
+ raise Exception("Unsupported Type encountered.")
+
+
+def _get_serializer(value, custom_types):
+ if isinstance(value, (list, tuple)):
+ return TupleSerializer(value, custom_types)
elif value is None:
- write(Types.TYPE_NULL)
return NullSerializer()
elif isinstance(value, stringtype):
- write(Types.TYPE_STRING)
return StringSerializer()
elif isinstance(value, bool):
- write(Types.TYPE_BOOLEAN)
return BooleanSerializer()
elif isinstance(value, int) or PY2 and isinstance(value, long):
- write(Types.TYPE_LONG)
return LongSerializer()
elif isinstance(value, bytearray):
- write(Types.TYPE_BYTES)
return ByteArraySerializer()
elif isinstance(value, float):
- write(Types.TYPE_DOUBLE)
return FloatSerializer()
else:
for entry in custom_types:
if isinstance(value, entry[1]):
- write(entry[0])
- return CustomTypeSerializer(entry[2])
+ return CustomTypeSerializer(entry[0], entry[2])
raise Exception("Unsupported Type encountered.")
-class CustomTypeSerializer(object):
- def __init__(self, serializer):
+class CustomTypeSerializer(Serializer):
+ def __init__(self, id, serializer):
+ self._id = id
self._serializer = serializer
def serialize(self, value):
@@ -88,9 +159,9 @@ def serialize(self, value):
return b"".join([pack(">i",len(msg)), msg])
-class TupleSerializer(object):
- def __init__(self, write, value, custom_types):
- self.serializer = [_get_serializer(write, field, custom_types) for field in value]
+class TupleSerializer(Serializer):
+ def __init__(self, value, custom_types):
+ self.serializer = [_get_serializer(field, custom_types) for field in value]
def serialize(self, value):
bits = []
@@ -99,83 +170,33 @@ def serialize(self, value):
return b"".join(bits)
-class BooleanSerializer(object):
+class BooleanSerializer(Serializer):
def serialize(self, value):
return pack(">?", value)
-class FloatSerializer(object):
+class FloatSerializer(Serializer):
def serialize(self, value):
return pack(">d", value)
-class LongSerializer(object):
+class LongSerializer(Serializer):
def serialize(self, value):
return pack(">q", value)
-class ByteArraySerializer(object):
+class ByteArraySerializer(Serializer):
def serialize(self, value):
value = bytes(value)
return pack(">I", len(value)) + value
-class StringSerializer(object):
+class StringSerializer(Serializer):
def serialize(self, value):
value = value.encode("utf-8")
return pack(">I", len(value)) + value
-class NullSerializer(object):
+class NullSerializer(Serializer):
def serialize(self, value):
- return b""
-
-
-class TypedCollector(object):
- def __init__(self, con, env):
- self._connection = con
- self._env = env
-
- def collect(self, value):
- if not isinstance(value, (list, tuple)):
- self._send_field(value)
- else:
- self._connection.write(Types.TYPE_TUPLE)
- meta = pack(">I", len(value))
- self._connection.write(bytes([meta[3]]) if PY3 else meta[3])
- for field in value:
- self.collect(field)
-
- def _send_field(self, value):
- if value is None:
- self._connection.write(Types.TYPE_NULL)
- elif isinstance(value, stringtype):
- value = value.encode("utf-8")
- size = pack(">I", len(value))
- self._connection.write(b"".join([Types.TYPE_STRING, size, value]))
- elif isinstance(value, bytes):
- size = pack(">I", len(value))
- self._connection.write(b"".join([Types.TYPE_BYTES, size, value]))
- elif isinstance(value, bool):
- data = pack(">?", value)
- self._connection.write(b"".join([Types.TYPE_BOOLEAN, data]))
- elif isinstance(value, int) or PY2 and isinstance(value, long):
- data = pack(">q", value)
- self._connection.write(b"".join([Types.TYPE_LONG, data]))
- elif isinstance(value, float):
- data = pack(">d", value)
- self._connection.write(b"".join([Types.TYPE_DOUBLE, data]))
- elif isinstance(value, bytearray):
- value = bytes(value)
- size = pack(">I", len(value))
- self._connection.write(b"".join([Types.TYPE_BYTES, size, value]))
- elif isinstance(value, _Dummy):
- self._connection.write(pack(">i", 127)[3:])
- self._connection.write(pack(">i", 0))
- else:
- for entry in self._env._types:
- if isinstance(value, entry[1]):
- self._connection.write(entry[0])
- self._connection.write(CustomTypeSerializer(entry[2]).serialize(value))
- return
- raise Exception("Unsupported Type encountered.")
\ No newline at end of file
+ return b""
\ No newline at end of file
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Connection.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Connection.py
index 680f49529a21a..293f5e9327021 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Connection.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Connection.py
@@ -115,6 +115,8 @@ def read(self, des_size, ignored=None):
self._read_buffer()
old_offset = self._input_offset
self._input_offset += des_size
+ if self._input_offset > self._input_size:
+ raise Exception("BufferUnderFlowException")
return self._input[old_offset:self._input_offset]
def _read_buffer(self):
@@ -140,6 +142,12 @@ def reset(self):
self._input_offset = 0
self._input = b""
+ def read_secondary(self, des_size):
+ return recv_all(self._socket, des_size)
+
+ def write_secondary(self, data):
+ self._socket.send(data)
+
class TwinBufferingTCPMappedFileConnection(BufferingTCPMappedFileConnection):
def __init__(self, input_file, output_file, port):
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Constants.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Constants.py
index 0ca2232629529..01a16fa3880c4 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Constants.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Constants.py
@@ -18,14 +18,15 @@
class Types(object):
- TYPE_TUPLE = b'\x0B'
- TYPE_BOOLEAN = b'\x0A'
- TYPE_BYTE = b'\x09'
- TYPE_SHORT = b'\x08'
- TYPE_INTEGER = b'\x07'
- TYPE_LONG = b'\x06'
- TYPE_DOUBLE = b'\x04'
- TYPE_FLOAT = b'\x05'
- TYPE_STRING = b'\x02'
- TYPE_NULL = b'\x00'
- TYPE_BYTES = b'\x01'
+ TYPE_ARRAY = b'\x3F'
+ TYPE_KEY_VALUE = b'\x3E'
+ TYPE_VALUE_VALUE = b'\x3D'
+ TYPE_BOOLEAN = b'\x22'
+ TYPE_BYTE = b'\x21'
+ TYPE_INTEGER = b'\x20'
+ TYPE_LONG = b'\x1F'
+ TYPE_DOUBLE = b'\x1E'
+ TYPE_FLOAT = b'\x1D'
+ TYPE_STRING = b'\x1C'
+ TYPE_BYTES = b'\x1B'
+ TYPE_NULL = b'\x1A'
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py
index 3425cfa1fd953..25321949455db 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py
@@ -26,6 +26,7 @@
from flink.connection.Constants import Types
+#=====Iterator==========================================================================================================
class ListIterator(defIter.Iterator):
def __init__(self, values):
super(ListIterator, self).__init__()
@@ -76,7 +77,7 @@ def next(self):
else:
self.cur = None
self.empty = True
- return tmp
+ return tmp[1]
else:
raise StopIteration
@@ -93,7 +94,7 @@ def next_group(self):
self.empty = False
def _extract_keys(self, x):
- return [x[k] for k in self.keys]
+ return [x[0][k] for k in self.keys]
def _extract_keys_id(self, x):
return x
@@ -175,6 +176,7 @@ def __init__(self, con, env, group=0):
self._group = group
self._deserializer = None
self._env = env
+ self._size = 0
def __next__(self):
return self.next()
@@ -184,8 +186,35 @@ def _read(self, des_size):
def next(self):
if self.has_next():
+ custom_types = self._env._types
+ read = self._read
if self._deserializer is None:
- self._deserializer = _get_deserializer(self._group, self._connection.read, self._env._types)
+ type = read(1)
+ if type == Types.TYPE_ARRAY:
+ key_des = _get_deserializer(read, custom_types)
+ self._deserializer = ArrayDeserializer(key_des)
+ return key_des.deserialize(read)
+ elif type == Types.TYPE_KEY_VALUE:
+ size = ord(read(1))
+ key_des = []
+ keys = []
+ for _ in range(size):
+ new_d = _get_deserializer(read, custom_types)
+ key_des.append(new_d)
+ keys.append(new_d.deserialize(read))
+ val_des = _get_deserializer(read, custom_types)
+ val = val_des.deserialize(read)
+ self._deserializer = KeyValueDeserializer(key_des, val_des)
+ return (tuple(keys), val)
+ elif type == Types.TYPE_VALUE_VALUE:
+ des1 = _get_deserializer(read, custom_types)
+ field1 = des1.deserialize(read)
+ des2 = _get_deserializer(read, custom_types)
+ field2 = des2.deserialize(read)
+ self._deserializer = ValueValueDeserializer(des1, des2)
+ return (field1, field2)
+ else:
+ raise Exception("Invalid type ID encountered: " + str(ord(type)))
return self._deserializer.deserialize(self._read)
else:
raise StopIteration
@@ -197,6 +226,16 @@ def _reset(self):
self._deserializer = None
+class PlanIterator(object):
+ def __init__(self, con, env):
+ self._connection = con
+ self._env = env
+
+ def next(self):
+ deserializer = _get_deserializer(self._connection.read, self._env._types)
+ return deserializer.deserialize(self._connection.read)
+
+
class DummyIterator(Iterator):
def __init__(self):
super(Iterator, self).__init__()
@@ -211,12 +250,11 @@ def has_next(self):
return False
-def _get_deserializer(group, read, custom_types, type=None):
- if type is None:
- type = read(1, group)
- return _get_deserializer(group, read, custom_types, type)
- elif type == Types.TYPE_TUPLE:
- return TupleDeserializer(read, group, custom_types)
+#=====Deserializer======================================================================================================
+def _get_deserializer(read, custom_types):
+ type = read(1)
+ if 0 < ord(type) < 26:
+ return TupleDeserializer([_get_deserializer(read, custom_types) for _ in range(ord(type))])
elif type == Types.TYPE_BYTE:
return ByteDeserializer()
elif type == Types.TYPE_BYTES:
@@ -238,101 +276,125 @@ def _get_deserializer(group, read, custom_types, type=None):
else:
for entry in custom_types:
if type == entry[0]:
- return entry[3]
- raise Exception("Unable to find deserializer for type ID " + str(type))
+ return CustomTypeDeserializer(entry[3])
+ raise Exception("Unable to find deserializer for type ID " + str(ord(type)))
-class TupleDeserializer(object):
- def __init__(self, read, group, custom_types):
- size = unpack(">I", read(4, group))[0]
- self.deserializer = [_get_deserializer(group, read, custom_types) for _ in range(size)]
+class Deserializer(object):
+ def get_type_info_size(self):
+ return 1
def deserialize(self, read):
- return tuple([s.deserialize(read) for s in self.deserializer])
+ pass
+
+class ArrayDeserializer(Deserializer):
+ def __init__(self, deserializer):
+ self._deserializer = deserializer
+ self._d_skip = deserializer.get_type_info_size()
-class ByteDeserializer(object):
+ def deserialize(self, read):
+ read(1) #array type
+ read(self._d_skip)
+ return self._deserializer.deserialize(read)
+
+
+class KeyValueDeserializer(Deserializer):
+ def __init__(self, key_deserializer, value_deserializer):
+ self._key_deserializer = [(k, k.get_type_info_size()) for k in key_deserializer]
+ self._value_deserializer = value_deserializer
+ self._value_deserializer_skip = value_deserializer.get_type_info_size()
+
+ def deserialize(self, read):
+ fields = []
+ read(1) #key value type
+ read(1) #key count
+ for dk in self._key_deserializer:
+ read(dk[1])
+ fields.append(dk[0].deserialize(read))
+ dv = self._value_deserializer
+ read(self._value_deserializer_skip)
+ return (tuple(fields), dv.deserialize(read))
+
+
+class ValueValueDeserializer(Deserializer):
+ def __init__(self, d1, d2):
+ self._d1 = d1
+ self._d1_skip = self._d1.get_type_info_size()
+ self._d2 = d2
+ self._d2_skip = self._d2.get_type_info_size()
+
+ def deserialize(self, read):
+ read(1)
+ read(self._d1_skip)
+ f1 = self._d1.deserialize(read)
+ read(self._d2_skip)
+ f2 = self._d2.deserialize(read)
+ return (f1, f2)
+
+
+class CustomTypeDeserializer(Deserializer):
+ def __init__(self, deserializer):
+ self._deserializer = deserializer
+
+ def deserialize(self, read):
+ read(4) #discard binary size
+ return self._deserializer.deserialize(read)
+
+
+class TupleDeserializer(Deserializer):
+ def __init__(self, deserializer):
+ self._deserializer = deserializer
+
+ def get_type_info_size(self):
+ return 1 + sum([d.get_type_info_size() for d in self._deserializer])
+
+ def deserialize(self, read):
+ return tuple([s.deserialize(read) for s in self._deserializer])
+
+
+class ByteDeserializer(Deserializer):
def deserialize(self, read):
return unpack(">c", read(1))[0]
-class ByteArrayDeserializer(object):
+class ByteArrayDeserializer(Deserializer):
def deserialize(self, read):
size = unpack(">i", read(4))[0]
return bytearray(read(size)) if size else bytearray(b"")
-class BooleanDeserializer(object):
+class BooleanDeserializer(Deserializer):
def deserialize(self, read):
return unpack(">?", read(1))[0]
-class FloatDeserializer(object):
+class FloatDeserializer(Deserializer):
def deserialize(self, read):
return unpack(">f", read(4))[0]
-class DoubleDeserializer(object):
+class DoubleDeserializer(Deserializer):
def deserialize(self, read):
return unpack(">d", read(8))[0]
-class IntegerDeserializer(object):
+class IntegerDeserializer(Deserializer):
def deserialize(self, read):
return unpack(">i", read(4))[0]
-class LongDeserializer(object):
+class LongDeserializer(Deserializer):
def deserialize(self, read):
return unpack(">q", read(8))[0]
-class StringDeserializer(object):
+class StringDeserializer(Deserializer):
def deserialize(self, read):
length = unpack(">i", read(4))[0]
return read(length).decode("utf-8") if length else ""
-class NullDeserializer(object):
- def deserialize(self):
+class NullDeserializer(Deserializer):
+ def deserialize(self, read):
return None
-
-
-class TypedIterator(object):
- def __init__(self, con, env):
- self._connection = con
- self._env = env
-
- def next(self):
- read = self._connection.read
- type = read(1)
- if type == Types.TYPE_TUPLE:
- size = unpack(">i", read(4))[0]
- return tuple([self.next() for x in range(size)])
- elif type == Types.TYPE_BYTE:
- return unpack(">c", read(1))[0]
- elif type == Types.TYPE_BYTES:
- size = unpack(">i", read(4))[0]
- return bytearray(read(size)) if size else bytearray(b"")
- elif type == Types.TYPE_BOOLEAN:
- return unpack(">?", read(1))[0]
- elif type == Types.TYPE_FLOAT:
- return unpack(">f", read(4))[0]
- elif type == Types.TYPE_DOUBLE:
- return unpack(">d", read(8))[0]
- elif type == Types.TYPE_INTEGER:
- return unpack(">i", read(4))[0]
- elif type == Types.TYPE_LONG:
- return unpack(">q", read(8))[0]
- elif type == Types.TYPE_STRING:
- length = unpack(">i", read(4))[0]
- return read(length).decode("utf-8") if length else ""
- elif type == Types.TYPE_NULL:
- return None
- else:
- for entry in self._env._types:
- if type == entry[0]:
- return entry[3]()
- raise Exception("Unable to find deserializer for type ID " + str(type))
-
-
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery10.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery10.py
index cc9e7cf1963d9..c1079875d9bbd 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery10.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery10.py
@@ -76,7 +76,7 @@ def join(self, value1, value2):
STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING], '\n', '|') \
.project(0,5,6,8) \
.filter(LineitemFilter()) \
- .map(ComputeRevenue(), [INT, FLOAT])
+ .map(ComputeRevenue())
nation = env \
.read_csv(sys.argv[4], [INT, STRING, INT, STRING], '\n', '|') \
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery3.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery3.py
index 3eb72c9792f77..aaa4e55cf91c9 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery3.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery3.py
@@ -87,13 +87,13 @@ def join(self, value1, value2):
.join(order) \
.where(0) \
.equal_to(1) \
- .using(CustomerOrderJoin(),[INT, FLOAT, STRING, INT])
+ .using(CustomerOrderJoin())
result = customerWithOrder \
.join(lineitem) \
.where(0) \
.equal_to(0) \
- .using(CustomerOrderLineitemJoin(), [INT, FLOAT, STRING, INT]) \
+ .using(CustomerOrderLineitemJoin()) \
.group_by(0, 2, 3) \
.reduce(SumReducer())
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TriangleEnumeration.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TriangleEnumeration.py
index b1b3ef4325ab0..f13cc0418687c 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TriangleEnumeration.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TriangleEnumeration.py
@@ -16,7 +16,7 @@
# limitations under the License.
################################################################################
from flink.plan.Environment import get_environment
-from flink.plan.Constants import INT, Order
+from flink.plan.Constants import Order
from flink.functions.FlatMapFunction import FlatMapFunction
from flink.functions.GroupReduceFunction import GroupReduceFunction
from flink.functions.ReduceFunction import ReduceFunction
@@ -123,27 +123,27 @@ def join(self, value1, value2):
(1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 5), (3, 4), (3, 7), (3, 8), (5, 6), (7, 8))
edges_with_degrees = edges \
- .flat_map(EdgeDuplicator(), [INT, INT]) \
+ .flat_map(EdgeDuplicator()) \
.group_by(0) \
.sort_group(1, Order.ASCENDING) \
- .reduce_group(DegreeCounter(), [INT, INT, INT, INT]) \
+ .reduce_group(DegreeCounter()) \
.group_by(0, 2) \
.reduce(DegreeJoiner())
edges_by_degree = edges_with_degrees \
- .map(EdgeByDegreeProjector(), [INT, INT])
+ .map(EdgeByDegreeProjector())
edges_by_id = edges_by_degree \
- .map(EdgeByIdProjector(), [INT, INT])
+ .map(EdgeByIdProjector())
triangles = edges_by_degree \
.group_by(0) \
.sort_group(1, Order.ASCENDING) \
- .reduce_group(TriadBuilder(), [INT, INT, INT]) \
+ .reduce_group(TriadBuilder()) \
.join(edges_by_id) \
.where(1, 2) \
.equal_to(0, 1) \
- .using(TriadFilter(), [INT, INT, INT])
+ .using(TriadFilter())
triangles.output()
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WebLogAnalysis.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WebLogAnalysis.py
index 676043fdc873c..1ea3e78e1dbcf 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WebLogAnalysis.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WebLogAnalysis.py
@@ -19,7 +19,7 @@
from datetime import datetime
from flink.plan.Environment import get_environment
-from flink.plan.Constants import INT, STRING, FLOAT, WriteMode
+from flink.plan.Constants import WriteMode
from flink.functions.CoGroupFunction import CoGroupFunction
from flink.functions.FilterFunction import FilterFunction
@@ -54,16 +54,16 @@ def co_group(self, iterator1, iterator2, collector):
sys.exit("Usage: ./bin/pyflink.sh WebLogAnalysis ")
documents = env \
- .read_csv(sys.argv[1], [STRING, STRING], "\n", "|") \
+ .read_csv(sys.argv[1], "\n", "|") \
.filter(DocumentFilter()) \
.project(0)
ranks = env \
- .read_csv(sys.argv[2], [INT, STRING, INT], "\n", "|") \
+ .read_csv(sys.argv[2], "\n", "|") \
.filter(RankFilter())
visits = env \
- .read_csv(sys.argv[3], [STRING, STRING, STRING, FLOAT, STRING, STRING, STRING, STRING, INT], "\n", "|") \
+ .read_csv(sys.argv[3], "\n", "|") \
.project(1,2) \
.filter(VisitFilter()) \
.project(0)
@@ -78,7 +78,7 @@ def co_group(self, iterator1, iterator2, collector):
.co_group(visits) \
.where(1) \
.equal_to(0) \
- .using(AntiJoinVisits(), [INT, STRING, INT])
+ .using(AntiJoinVisits())
result.write_csv(sys.argv[4], '\n', '|', WriteMode.OVERWRITE)
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WordCount.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WordCount.py
index 71c2e28337183..2ab724c384b04 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WordCount.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WordCount.py
@@ -18,7 +18,6 @@
import sys
from flink.plan.Environment import get_environment
-from flink.plan.Constants import INT, STRING
from flink.functions.FlatMapFunction import FlatMapFunction
from flink.functions.GroupReduceFunction import GroupReduceFunction
@@ -47,9 +46,9 @@ def reduce(self, iterator, collector):
data = env.from_elements("hello","world","hello","car","tree","data","hello")
result = data \
- .flat_map(Tokenizer(), (INT, STRING)) \
+ .flat_map(Tokenizer()) \
.group_by(1) \
- .reduce_group(Adder(), (INT, STRING), combinable=True) \
+ .reduce_group(Adder(), combinable=True) \
if len(sys.argv) == 3:
result.write_csv(sys.argv[2])
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py
index 9c55787ae243a..4cb337ab4a2e2 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py
@@ -25,13 +25,16 @@ def __init__(self):
self._keys1 = None
self._keys2 = None
- def _configure(self, input_file, output_file, port, env):
+ def _configure(self, input_file, output_file, port, env, info):
self._connection = Connection.TwinBufferingTCPMappedFileConnection(input_file, output_file, port)
self._iterator = Iterator.Iterator(self._connection, env, 0)
self._iterator2 = Iterator.Iterator(self._connection, env, 1)
self._cgiter = Iterator.CoGroupIterator(self._iterator, self._iterator2, self._keys1, self._keys2)
+ self._collector = Collector.Collector(self._connection, env, info)
self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
- self._configure_chain(Collector.Collector(self._connection, env))
+ if info.chained_info is not None:
+ info.chained_info.operator._configure_chain(self.context, self._collector, info.chained_info)
+ self._collector = info.chained_info.operator
def _run(self):
collector = self._collector
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py
index f874a25261a80..dfe6a283acf4c 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py
@@ -16,9 +16,9 @@
# limitations under the License.
################################################################################
from abc import ABCMeta, abstractmethod
-import sys
from collections import deque
from flink.connection import Connection, Iterator, Collector
+from flink.connection.Iterator import IntegerDeserializer, StringDeserializer, _get_deserializer
from flink.functions import RuntimeContext
@@ -30,55 +30,56 @@ def __init__(self):
self._iterator = None
self._collector = None
self.context = None
- self._chain_operator = None
+ self._env = None
- def _configure(self, input_file, output_file, port, env):
+ def _configure(self, input_file, output_file, port, env, info):
self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port)
self._iterator = Iterator.Iterator(self._connection, env)
+ self._collector = Collector.Collector(self._connection, env, info)
self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
- self._configure_chain(Collector.Collector(self._connection, env))
+ self._env = env
+ if info.chained_info is not None:
+ info.chained_info.operator._configure_chain(self.context, self._collector, info.chained_info)
+ self._collector = info.chained_info.operator
- def _configure_chain(self, collector):
- if self._chain_operator is not None:
- self._collector = self._chain_operator
- self._collector.context = self.context
- self._collector._configure_chain(collector)
- self._collector._open()
- else:
+ def _configure_chain(self, context, collector, info):
+ self.context = context
+ if info.chained_info is None:
self._collector = collector
-
- def _chain(self, operator):
- self._chain_operator = operator
+ else:
+ self._collector = info.chained_info.operator
+ info.chained_info.operator._configure_chain(context, collector, info.chained_info)
@abstractmethod
def _run(self):
pass
- def _open(self):
- pass
-
def _close(self):
self._collector._close()
- self._connection.close()
+ if self._connection is not None:
+ self._connection.close()
def _go(self):
self._receive_broadcast_variables()
self._run()
def _receive_broadcast_variables(self):
- broadcast_count = self._iterator.next()
- self._iterator._reset()
- self._connection.reset()
+ con = self._connection
+ deserializer_int = IntegerDeserializer()
+ broadcast_count = deserializer_int.deserialize(con.read_secondary)
+ deserializer_string = StringDeserializer()
for _ in range(broadcast_count):
- name = self._iterator.next()
- self._iterator._reset()
- self._connection.reset()
+ name = deserializer_string.deserialize(con.read_secondary)
bc = deque()
- while(self._iterator.has_next()):
- bc.append(self._iterator.next())
+ if con.read_secondary(1) == b"\x01":
+ serializer_data = _get_deserializer(con.read_secondary, self._env._types)
+ value = serializer_data.deserialize(con.read_secondary)
+ bc.append(value)
+ while con.read_secondary(1) == b"\x01":
+ con.read_secondary(serializer_data.get_type_info_size()) #skip type info
+ value = serializer_data.deserialize(con.read_secondary)
+ bc.append(value)
self.context._add_broadcast_variable(name, bc)
- self._iterator._reset()
- self._connection.reset()
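Reviewer note, not part of the patch: the rewritten _receive_broadcast_variables above reads everything from the secondary channel as a count, then per variable a name followed by records that are each announced by a continuation byte b"\x01" (anything else ends the variable); only the first record carries type info, later ones skip get_type_info_size() bytes. A minimal, self-contained sketch of that framing, using plain struct calls in place of the real IntegerDeserializer/StringDeserializer and assuming 4-byte big-endian length prefixes and a b"\x00" terminator (both assumptions, not taken from the patch):

    import struct
    from io import BytesIO

    def read_broadcast_stream(buf):
        # simplified stand-in for Function._receive_broadcast_variables
        read = buf.read
        count = struct.unpack(">i", read(4))[0]          # number of broadcast variables
        variables = {}
        for _ in range(count):
            size = struct.unpack(">i", read(4))[0]       # length-prefixed variable name
            name = read(size).decode("utf-8")
            values = []
            while read(1) == b"\x01":                    # continuation byte per record
                size = struct.unpack(">i", read(4))[0]   # assumed length-prefixed payload
                values.append(read(size))
            variables[name] = values
        return variables

    payload = (struct.pack(">i", 1)
               + struct.pack(">i", 4) + b"test"
               + b"\x01" + struct.pack(">i", 2) + b"ab"
               + b"\x01" + struct.pack(">i", 1) + b"c"
               + b"\x00")
    print(read_broadcast_stream(BytesIO(payload)))       # {'test': [b'ab', b'c']}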
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py
index b758c19a1167b..340497da64c1e 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py
@@ -24,21 +24,14 @@
class GroupReduceFunction(Function.Function):
def __init__(self):
super(GroupReduceFunction, self).__init__()
- self._keys = None
- def _configure(self, input_file, output_file, port, env):
- self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port)
- self._iterator = Iterator.Iterator(self._connection, env)
- if self._keys is None:
+ def _configure(self, input_file, output_file, port, env, info):
+ super(GroupReduceFunction, self)._configure(input_file, output_file, port, env, info)
+ if info.key1 is None:
self._run = self._run_all_group_reduce
else:
self._run = self._run_grouped_group_reduce
- self._group_iterator = Iterator.GroupIterator(self._iterator, self._keys)
- self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
- self._collector = Collector.Collector(self._connection, env)
-
- def _set_grouping_keys(self, keys):
- self._keys = keys
+ self._group_iterator = Iterator.GroupIterator(self._iterator, info.key1)
def _run(self):
pass
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/KeySelectorFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/KeySelectorFunction.py
new file mode 100644
index 0000000000000..a75961f9341aa
--- /dev/null
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/KeySelectorFunction.py
@@ -0,0 +1,28 @@
+# ###############################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+
+class KeySelectorFunction:
+ def __call__(self, value):
+ return self.get_key(value)
+
+ def callable(self):
+ return True
+
+ def get_key(self, value):
+ pass
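A possible use of the new KeySelectorFunction, sketched outside the patch: subclass it, implement get_key, and pass an instance wherever a key is expected (group_by, where, equal_to); the plan code detects it via isinstance checks. The class name and data below are made up for illustration.

    from flink.functions.KeySelectorFunction import KeySelectorFunction

    class ThirdField(KeySelectorFunction):
        def get_key(self, value):
            # key a tuple record by its third field
            return value[2]

    selector = ThirdField()
    print(selector((1, 0.5, "hello", True)))   # "hello", __call__ delegates to get_key

    # hypothetical plan usage, mirroring the lambda-based tests added below:
    # d4.group_by(ThirdField()).reduce_group(GroupReduce())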
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py
index 45a22da75c0f7..95e8b8a1df69d 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py
@@ -23,21 +23,14 @@
class ReduceFunction(Function.Function):
def __init__(self):
super(ReduceFunction, self).__init__()
- self._keys = None
- def _configure(self, input_file, output_file, port, env):
- self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port)
- self._iterator = Iterator.Iterator(self._connection, env)
- if self._keys is None:
+ def _configure(self, input_file, output_file, port, env, info):
+ super(ReduceFunction, self)._configure(input_file, output_file, port, env, info)
+ if info.key1 is None:
self._run = self._run_all_reduce
else:
self._run = self._run_grouped_reduce
- self._group_iterator = Iterator.GroupIterator(self._iterator, self._keys)
- self._collector = Collector.Collector(self._connection, env)
- self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector)
-
- def _set_grouping_keys(self, keys):
- self._keys = keys
+ self._group_iterator = Iterator.GroupIterator(self._iterator, info.key1)
def _run(self):
pass
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py
index 0c9fe80efe8d4..888146348ddde 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py
@@ -63,11 +63,6 @@ class Order(object):
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
-
-class _Dummy(object):
- pass
-
-
if PY2:
BOOL = True
INT = 1
@@ -75,11 +70,17 @@ class _Dummy(object):
FLOAT = 2.5
STRING = "type"
BYTES = bytearray(b"byte")
- CUSTOM = _Dummy()
elif PY3:
BOOL = True
INT = 1
FLOAT = 2.5
STRING = "type"
BYTES = bytearray(b"byte")
- CUSTOM = _Dummy()
+
+
+def _createKeyValueTypeInfo(keyCount):
+ return (tuple([BYTES for _ in range(keyCount)]), BYTES)
+
+
+def _createArrayTypeInfo():
+ return BYTES
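Since users no longer pass type hints, these two helpers are what every operation now reports to the Java side: plain byte arrays, or a (key tuple, byte array) pair for keyed operations. A quick illustration of what they evaluate to, assuming the module is importable as flink.plan.Constants:

    from flink.plan.Constants import _createArrayTypeInfo, _createKeyValueTypeInfo

    print(_createArrayTypeInfo())        # bytearray(b'byte')
    print(_createKeyValueTypeInfo(2))    # ((bytearray(b'byte'), bytearray(b'byte')), bytearray(b'byte'))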
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/DataSet.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/DataSet.py
index 25ec8b8a6c3e1..eda8d025da9b1 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/DataSet.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/DataSet.py
@@ -15,11 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-import inspect
import copy
import types as TYPES
-from flink.plan.Constants import _Identifier, WriteMode, STRING
+from flink.plan.Constants import _Identifier, WriteMode, _createKeyValueTypeInfo, _createArrayTypeInfo
from flink.plan.OperationInfo import OperationInfo
from flink.functions.CoGroupFunction import CoGroupFunction
from flink.functions.FilterFunction import FilterFunction
@@ -30,52 +29,33 @@
from flink.functions.MapFunction import MapFunction
from flink.functions.MapPartitionFunction import MapPartitionFunction
from flink.functions.ReduceFunction import ReduceFunction
+from flink.functions.KeySelectorFunction import KeySelectorFunction
-def deduct_output_type(dataset):
- skip = set([_Identifier.GROUP, _Identifier.SORT, _Identifier.UNION])
- source = set([_Identifier.SOURCE_CSV, _Identifier.SOURCE_TEXT, _Identifier.SOURCE_VALUE])
- default = set([_Identifier.CROSS, _Identifier.CROSSH, _Identifier.CROSST, _Identifier.JOINT, _Identifier.JOINH, _Identifier.JOIN])
-
- while True:
- dataset_type = dataset.identifier
- if dataset_type in skip:
- dataset = dataset.parent
- continue
- if dataset_type in source:
- if dataset_type == _Identifier.SOURCE_TEXT:
- return STRING
- if dataset_type == _Identifier.SOURCE_VALUE:
- return dataset.values[0]
- if dataset_type == _Identifier.SOURCE_CSV:
- return dataset.types
- if dataset_type == _Identifier.PROJECTION:
- return tuple([deduct_output_type(dataset.parent)[k] for k in dataset.keys])
- if dataset_type in default:
- if dataset.operator is not None: #udf-join/cross
- return dataset.types
- if len(dataset.projections) == 0: #defaultjoin/-cross
- return (deduct_output_type(dataset.parent), deduct_output_type(dataset.other))
- else: #projectjoin/-cross
- t1 = deduct_output_type(dataset.parent)
- t2 = deduct_output_type(dataset.other)
- out_type = []
- for prj in dataset.projections:
- if len(prj[1]) == 0: #projection on non-tuple dataset
- if prj[0] == "first":
- out_type.append(t1)
- else:
- out_type.append(t2)
- else: #projection on tuple dataset
- for key in prj[1]:
- if prj[0] == "first":
- out_type.append(t1[key])
- else:
- out_type.append(t2[key])
- return tuple(out_type)
- return dataset.types
-
-
-class Set(object):
+
+class Stringify(MapFunction):
+ def map(self, value):
+ if isinstance(value, (tuple, list)):
+ return "(" + b", ".join([self.map(x) for x in value]) + ")"
+ else:
+ return str(value)
+
+
+class CsvStringify(MapFunction):
+ def __init__(self, f_delim):
+ super(CsvStringify, self).__init__()
+ self.delim = f_delim
+
+ def map(self, value):
+ return self.delim.join([self._map(field) for field in value])
+
+ def _map(self, value):
+ if isinstance(value, (tuple, list)):
+ return "(" + b", ".join([self.map(x) for x in value]) + ")"
+ else:
+ return str(value)
+
+
+class DataSet(object):
def __init__(self, env, info):
self._env = env
self._info = info
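Reviewer note, not part of the patch: with declared types gone, the plain sinks can no longer rely on a known output type, so output()/write_text()/write_csv() now prepend a Stringify or CsvStringify map before the actual sink. A small sketch of the formatting these mappers produce, assuming they stay importable from flink.plan.DataSet and the string-join fix above:

    from flink.plan.DataSet import Stringify, CsvStringify

    record = (1, 0.5, ("hello", True))
    print(Stringify().map(record))        # (1, 0.5, (hello, True))
    print(CsvStringify("|").map(record))  # 1|0.5|(hello, True)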
@@ -86,6 +66,9 @@ def output(self, to_error=False):
"""
Writes a DataSet to the standard output stream (stdout).
"""
+ self.map(Stringify())._output(to_error)
+
+ def _output(self, to_error):
child = OperationInfo()
child.identifier = _Identifier.SINK_PRINT
child.parent = self._info
@@ -100,6 +83,9 @@ def write_text(self, path, write_mode=WriteMode.NO_OVERWRITE):
:param path: The path pointing to the location the text file is written to.
:param write_mode: OutputFormat.WriteMode value, indicating whether files should be overwritten
"""
+ return self.map(Stringify())._write_text(path, write_mode)
+
+ def _write_text(self, path, write_mode):
child = OperationInfo()
child.identifier = _Identifier.SINK_TEXT
child.parent = self._info
@@ -116,6 +102,9 @@ def write_csv(self, path, line_delimiter="\n", field_delimiter=',', write_mode=W
:param path: The path pointing to the location the CSV file is written to.
:param write_mode: OutputFormat.WriteMode value, indicating whether files should be overwritten
"""
+ return self.map(CsvStringify(field_delimiter))._write_csv(path, line_delimiter, field_delimiter, write_mode)
+
+ def _write_csv(self, path, line_delimiter, field_delimiter, write_mode):
child = OperationInfo()
child.identifier = _Identifier.SINK_CSV
child.path = path
@@ -126,7 +115,7 @@ def write_csv(self, path, line_delimiter="\n", field_delimiter=',', write_mode=W
self._info.sinks.append(child)
self._env._sinks.append(child)
- def reduce_group(self, operator, types, combinable=False):
+ def reduce_group(self, operator, combinable=False):
"""
Applies a GroupReduce transformation.
@@ -136,7 +125,6 @@ def reduce_group(self, operator, types, combinable=False):
emit any number of output elements including none.
:param operator: The GroupReduceFunction that is applied on the DataSet.
- :param types: The type of the resulting DataSet.
:return:A GroupReduceOperator that represents the reduced DataSet.
"""
if isinstance(operator, TYPES.FunctionType):
@@ -148,17 +136,12 @@ def reduce_group(self, operator, types, combinable=False):
child.identifier = _Identifier.GROUPREDUCE
child.parent = self._info
child.operator = operator
- child.types = types
+ child.types = _createArrayTypeInfo()
child.name = "PythonGroupReduce"
self._info.children.append(child)
self._env._sets.append(child)
return child_set
-
-class ReduceSet(Set):
- def __init__(self, env, info):
- super(ReduceSet, self).__init__(env, info)
-
def reduce(self, operator):
"""
Applies a Reduce transformation on a non-grouped DataSet.
@@ -179,16 +162,11 @@ def reduce(self, operator):
child.parent = self._info
child.operator = operator
child.name = "PythonReduce"
- child.types = deduct_output_type(self._info)
+ child.types = _createArrayTypeInfo()
self._info.children.append(child)
self._env._sets.append(child)
return child_set
-
-class DataSet(ReduceSet):
- def __init__(self, env, info):
- super(DataSet, self).__init__(env, info)
-
def project(self, *fields):
"""
Applies a Project transformation on a Tuple DataSet.
@@ -201,14 +179,7 @@ def project(self, *fields):
:return: The projected DataSet.
"""
- child = OperationInfo()
- child_set = DataSet(self._env, child)
- child.identifier = _Identifier.PROJECTION
- child.parent = self._info
- child.keys = fields
- self._info.children.append(child)
- self._env._sets.append(child)
- return child_set
+ return self.map(lambda x: tuple([x[key] for key in fields]))
def group_by(self, *keys):
"""
@@ -223,6 +194,9 @@ def group_by(self, *keys):
:param keys: One or more field positions on which the DataSet will be grouped.
:return:A Grouping on which a transformation needs to be applied to obtain a transformed DataSet.
"""
+ return self.map(lambda x: x)._group_by(keys)
+
+ def _group_by(self, keys):
child = OperationInfo()
child_chain = []
child_set = UnsortedGrouping(self._env, child, child_chain)
@@ -251,9 +225,8 @@ def co_group(self, other_set):
other_set._info.children.append(child)
child_set = CoGroupOperatorWhere(self._env, child)
child.identifier = _Identifier.COGROUP
- child.parent = self._info
- child.other = other_set._info
- self._info.children.append(child)
+ child.parent_set = self
+ child.other_set = other_set
return child_set
def cross(self, other_set):
@@ -323,14 +296,13 @@ def filter(self, operator):
child.identifier = _Identifier.FILTER
child.parent = self._info
child.operator = operator
- child.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__)
child.name = "PythonFilter"
- child.types = deduct_output_type(self._info)
+ child.types = _createArrayTypeInfo()
self._info.children.append(child)
self._env._sets.append(child)
return child_set
- def flat_map(self, operator, types):
+ def flat_map(self, operator):
"""
Applies a FlatMap transformation on a DataSet.
@@ -338,7 +310,6 @@ def flat_map(self, operator, types):
Each FlatMapFunction call can return any number of elements including none.
:param operator: The FlatMapFunction that is called for each element of the DataSet.
- :param types: The type of the resulting DataSet.
:return:A FlatMapOperator that represents the transformed DataSet
"""
if isinstance(operator, TYPES.FunctionType):
@@ -350,8 +321,7 @@ def flat_map(self, operator, types):
child.identifier = _Identifier.FLATMAP
child.parent = self._info
child.operator = operator
- child.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__)
- child.types = types
+ child.types = _createArrayTypeInfo()
child.name = "PythonFlatMap"
self._info.children.append(child)
self._env._sets.append(child)
@@ -398,14 +368,11 @@ def _join(self, other_set, identifier):
child = OperationInfo()
child_set = JoinOperatorWhere(self._env, child)
child.identifier = identifier
- child.parent = self._info
- child.other = other_set._info
- self._info.children.append(child)
- other_set._info.children.append(child)
- self._env._sets.append(child)
+ child.parent_set = self
+ child.other_set = other_set
return child_set
- def map(self, operator, types):
+ def map(self, operator):
"""
Applies a Map transformation on a DataSet.
@@ -413,7 +380,6 @@ def map(self, operator, types):
Each MapFunction call returns exactly one element.
:param operator: The MapFunction that is called for each element of the DataSet.
- :param types: The type of the resulting DataSet
:return:A MapOperator that represents the transformed DataSet
"""
if isinstance(operator, TYPES.FunctionType):
@@ -425,14 +391,13 @@ def map(self, operator, types):
child.identifier = _Identifier.MAP
child.parent = self._info
child.operator = operator
- child.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__)
- child.types = types
+ child.types = _createArrayTypeInfo()
child.name = "PythonMap"
self._info.children.append(child)
self._env._sets.append(child)
return child_set
- def map_partition(self, operator, types):
+ def map_partition(self, operator):
"""
Applies a MapPartition transformation on a DataSet.
@@ -444,7 +409,6 @@ def map_partition(self, operator, types):
sees is non-deterministic and depends on the degree of parallelism of the operation.
:param operator: The MapPartitionFunction that is called for each partition of the DataSet.
- :param types: The type of the resulting DataSet
:return:A MapPartitionOperator that represents the transformed DataSet
"""
if isinstance(operator, TYPES.FunctionType):
@@ -456,8 +420,7 @@ def map_partition(self, operator, types):
child.identifier = _Identifier.MAPPARTITION
child.parent = self._info
child.operator = operator
- child.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__)
- child.types = types
+ child.types = _createArrayTypeInfo()
child.name = "PythonMapPartition"
self._info.children.append(child)
self._env._sets.append(child)
@@ -506,7 +469,10 @@ def __init__(self, env, info, child_chain):
info.id = env._counter
env._counter += 1
- def reduce_group(self, operator, types, combinable=False):
+ def _finalize(self):
+ pass
+
+ def reduce_group(self, operator, combinable=False):
"""
Applies a GroupReduce transformation.
@@ -516,21 +482,21 @@ def reduce_group(self, operator, types, combinable=False):
emit any number of output elements including none.
:param operator: The GroupReduceFunction that is applied on the DataSet.
- :param types: The type of the resulting DataSet.
:return:A GroupReduceOperator that represents the reduced DataSet.
"""
+ self._finalize()
if isinstance(operator, TYPES.FunctionType):
f = operator
operator = GroupReduceFunction()
operator.reduce = f
- operator._set_grouping_keys(self._child_chain[0].keys)
child = OperationInfo()
child_set = OperatorSet(self._env, child)
child.identifier = _Identifier.GROUPREDUCE
child.parent = self._info
child.operator = operator
- child.types = types
+ child.types = _createArrayTypeInfo()
child.name = "PythonGroupReduce"
+ child.key1 = self._child_chain[0].keys
self._info.children.append(child)
self._env._sets.append(child)
@@ -573,26 +539,96 @@ def reduce(self, operator):
:param operator:The ReduceFunction that is applied on the DataSet.
:return:A ReduceOperator that represents the reduced DataSet.
"""
- operator._set_grouping_keys(self._child_chain[0].keys)
- for i in self._child_chain:
- self._env._sets.append(i)
+ self._finalize()
+ if isinstance(operator, TYPES.FunctionType):
+ f = operator
+ operator = ReduceFunction()
+ operator.reduce = f
child = OperationInfo()
child_set = OperatorSet(self._env, child)
child.identifier = _Identifier.REDUCE
child.parent = self._info
child.operator = operator
child.name = "PythonReduce"
- child.types = deduct_output_type(self._info)
+ child.types = _createArrayTypeInfo()
+ child.key1 = self._child_chain[0].keys
self._info.children.append(child)
self._env._sets.append(child)
return child_set
+ def _finalize(self):
+ grouping = self._child_chain[0]
+ keys = grouping.keys
+ f = None
+ if isinstance(keys[0], TYPES.FunctionType):
+ f = lambda x: (keys[0](x),)
+ if isinstance(keys[0], KeySelectorFunction):
+ f = lambda x: (keys[0].get_key(x),)
+ if f is None:
+ f = lambda x: tuple([x[key] for key in keys])
+
+ grouping.parent.operator.map = lambda x: (f(x), x)
+ grouping.parent.types = _createKeyValueTypeInfo(len(keys))
+ grouping.keys = tuple([i for i in range(len(grouping.keys))])
+
class SortedGrouping(Grouping):
def __init__(self, env, info, child_chain):
super(SortedGrouping, self).__init__(env, info, child_chain)
+ def _finalize(self):
+ grouping = self._child_chain[0]
+ sortings = self._child_chain[1:]
+
+ #set of used index keys to prevent duplicates and determine the final key index
+ index_keys = set()
+
+ if not isinstance(grouping.keys[0], (TYPES.FunctionType, KeySelectorFunction)):
+ index_keys = index_keys.union(set(grouping.keys))
+
+ #list of sorts using indices
+ index_sorts = []
+ #list of sorts using functions
+ ksl_sorts = []
+ for s in sortings:
+ if not isinstance(s.field, (TYPES.FunctionType, KeySelectorFunction)):
+ index_keys.add(s.field)
+ index_sorts.append(s)
+ else:
+ ksl_sorts.append(s)
+
+ used_keys = sorted(index_keys)
+ #all data gathered
+
+ #construct list of extractor lambdas
+ lambdas = []
+ i = 0
+ for key in used_keys:
+ lambdas.append(lambda x, k=key: x[k])
+ i += 1
+ if isinstance(grouping.keys[0], (TYPES.FunctionType, KeySelectorFunction)):
+ lambdas.append(grouping.keys[0])
+ for ksl_op in ksl_sorts:
+ lambdas.append(ksl_op.field)
+
+ grouping.parent.operator.map = lambda x: (tuple([l(x) for l in lambdas]), x)
+ grouping.parent.types = _createKeyValueTypeInfo(len(lambdas))
+ #modify keys
+ ksl_offset = len(used_keys)
+ if not isinstance(grouping.keys[0], (TYPES.FunctionType, KeySelectorFunction)):
+ grouping.keys = tuple([used_keys.index(key) for key in grouping.keys])
+ else:
+ grouping.keys = (ksl_offset,)
+ ksl_offset += 1
+
+ for iop in index_sorts:
+ iop.field = used_keys.index(iop.field)
+
+ for kop in ksl_sorts:
+ kop.field = ksl_offset
+ ksl_offset += 1
+
class CoGroupOperatorWhere(object):
def __init__(self, env, info):
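Reviewer note, not part of the patch: both _finalize methods above rewrite the identity map that group_by inserted, so that every record reaches the Java side as a (key tuple, record) pair; the grouping (and sorting) keys are then rewritten to positions inside that key tuple. A plain-Python sketch of the unsorted case, with made-up data:

    keys = (2, 0)                                        # group_by(2, 0)
    extract = lambda x: tuple([x[k] for k in keys])
    to_key_value = lambda x: (extract(x), x)             # what the rewritten map does

    record = (1, 0.5, "hello", True)
    print(to_key_value(record))                          # (('hello', 1), (1, 0.5, 'hello', True))
    # grouping.keys is afterwards rewritten to positions in the key tuple: (0, 1)

The sorted case builds the key tuple from the union of index keys and key-selector functions, which is why SortedGrouping._finalize remaps both grouping.keys and each sort operation's field to an offset within that tuple.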
@@ -609,6 +645,18 @@ def where(self, *fields):
:param fields: The indexes of the Tuple fields of the first co-grouped DataSets that should be used as keys.
:return: An incomplete CoGroup transformation.
"""
+ f = None
+ if isinstance(fields[0], TYPES.FunctionType):
+ f = lambda x: (fields[0](x),)
+ if isinstance(fields[0], KeySelectorFunction):
+ f = lambda x: (fields[0].get_key(x),)
+ if f is None:
+ f = lambda x: tuple([x[key] for key in fields])
+
+ new_parent_set = self._info.parent_set.map(lambda x: (f(x), x))
+ new_parent_set._info.types = _createKeyValueTypeInfo(len(fields))
+ self._info.parent = new_parent_set._info
+ self._info.parent.children.append(self._info)
self._info.key1 = fields
return CoGroupOperatorTo(self._env, self._info)
@@ -628,6 +676,18 @@ def equal_to(self, *fields):
:param fields: The indexes of the Tuple fields of the second co-grouped DataSet that should be used as keys.
:return: An incomplete CoGroup transformation.
"""
+ f = None
+ if isinstance(fields[0], TYPES.FunctionType):
+ f = lambda x: (fields[0](x),)
+ if isinstance(fields[0], KeySelectorFunction):
+ f = lambda x: (fields[0].get_key(x),)
+ if f is None:
+ f = lambda x: tuple([x[key] for key in fields])
+
+ new_other_set = self._info.other_set.map(lambda x: (f(x), x))
+ new_other_set._info.types = _createKeyValueTypeInfo(len(fields))
+ self._info.other = new_other_set._info
+ self._info.other.children.append(self._info)
self._info.key2 = fields
return CoGroupOperatorUsing(self._env, self._info)
@@ -637,7 +697,7 @@ def __init__(self, env, info):
self._env = env
self._info = info
- def using(self, operator, types):
+ def using(self, operator):
"""
Finalizes a CoGroup transformation.
@@ -645,7 +705,6 @@ def using(self, operator, types):
Each CoGroupFunction call returns an arbitrary number of elements.
:param operator: The CoGroupFunction that is called for all groups of elements with identical keys.
- :param types: The type of the resulting DataSet.
:return:An CoGroupOperator that represents the co-grouped result DataSet.
"""
if isinstance(operator, TYPES.FunctionType):
@@ -653,11 +712,12 @@ def using(self, operator, types):
operator = CoGroupFunction()
operator.co_group = f
new_set = OperatorSet(self._env, self._info)
+ self._info.key1 = tuple([x for x in range(len(self._info.key1))])
+ self._info.key2 = tuple([x for x in range(len(self._info.key2))])
operator._keys1 = self._info.key1
operator._keys2 = self._info.key2
self._info.operator = operator
- self._info.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__)
- self._info.types = types
+ self._info.types = _createArrayTypeInfo()
self._info.name = "PythonCoGroup"
self._env._sets.append(self._info)
return new_set
@@ -679,7 +739,19 @@ def where(self, *fields):
:return:An incomplete Join transformation.
"""
- self._info.key1 = fields
+ f = None
+ if isinstance(fields[0], TYPES.FunctionType):
+ f = lambda x: (fields[0](x),)
+ if isinstance(fields[0], KeySelectorFunction):
+ f = lambda x: (fields[0].get_key(x),)
+ if f is None:
+ f = lambda x: tuple([x[key] for key in fields])
+
+ new_parent_set = self._info.parent_set.map(lambda x: (f(x), x))
+ new_parent_set._info.types = _createKeyValueTypeInfo(len(fields))
+ self._info.parent = new_parent_set._info
+ self._info.parent.children.append(self._info)
+ self._info.key1 = tuple([x for x in range(len(fields))])
return JoinOperatorTo(self._env, self._info)
@@ -698,81 +770,115 @@ def equal_to(self, *fields):
:param fields:The indexes of the Tuple fields of the second join DataSet that should be used as keys.
:return:An incomplete Join Transformation.
"""
- self._info.key2 = fields
+ f = None
+ if isinstance(fields[0], TYPES.FunctionType):
+ f = lambda x: (fields[0](x),)
+ if isinstance(fields[0], KeySelectorFunction):
+ f = lambda x: (fields[0].get_key(x),)
+ if f is None:
+ f = lambda x: tuple([x[key] for key in fields])
+
+ new_other_set = self._info.other_set.map(lambda x: (f(x), x))
+ new_other_set._info.types = _createKeyValueTypeInfo(len(fields))
+ self._info.other = new_other_set._info
+ self._info.other.children.append(self._info)
+ self._info.key2 = tuple([x for x in range(len(fields))])
+ self._env._sets.append(self._info)
return JoinOperator(self._env, self._info)
-class JoinOperatorProjection(DataSet):
+class Projector(DataSet):
def __init__(self, env, info):
- super(JoinOperatorProjection, self).__init__(env, info)
+ super(Projector, self).__init__(env, info)
def project_first(self, *fields):
"""
- Initiates a ProjectJoin transformation.
+ Initiates a Project transformation.
- Projects the first join input.
- If the first join input is a Tuple DataSet, fields can be selected by their index.
- If the first join input is not a Tuple DataSet, no parameters should be passed.
+ Projects the first input.
+ If the first input is a Tuple DataSet, fields can be selected by their index.
+ If the first input is not a Tuple DataSet, no parameters should be passed.
:param fields: The indexes of the selected fields.
- :return: An incomplete JoinProjection.
+ :return: An incomplete Projection.
"""
- self._info.projections.append(("first", fields))
+ for field in fields:
+ self._info.projections.append((0, field))
+ self._info.operator.map = lambda x: tuple([x[side][index] for side, index in self._info.projections])
return self
def project_second(self, *fields):
"""
- Initiates a ProjectJoin transformation.
+ Initiates a Project transformation.
- Projects the second join input.
- If the second join input is a Tuple DataSet, fields can be selected by their index.
- If the second join input is not a Tuple DataSet, no parameters should be passed.
+ Projects the second input.
+ If the second input is a Tuple DataSet, fields can be selected by their index.
+ If the second input is not a Tuple DataSet, no parameters should be passed.
:param fields: The indexes of the selected fields.
- :return: An incomplete JoinProjection.
+ :return: An incomplete Projection.
"""
- self._info.projections.append(("second", fields))
+ for field in fields:
+ self._info.projections.append((1, field))
+ self._info.operator.map = lambda x: tuple([x[side][index] for side, index in self._info.projections])
return self
-class JoinOperator(DataSet):
- def __init__(self, env, info):
- super(JoinOperator, self).__init__(env, info)
+class Projectable:
+ def __init__(self):
+ pass
def project_first(self, *fields):
"""
- Initiates a ProjectJoin transformation.
+ Initiates a Project transformation.
- Projects the first join input.
- If the first join input is a Tuple DataSet, fields can be selected by their index.
- If the first join input is not a Tuple DataSet, no parameters should be passed.
+ Projects the first input.
+ If the first input is a Tuple DataSet, fields can be selected by their index.
+ If the first input is not a Tuple DataSet, no parameters should be passed.
:param fields: The indexes of the selected fields.
- :return: An incomplete JoinProjection.
+ :return: An incomplete Projection.
"""
- return JoinOperatorProjection(self._env, self._info).project_first(*fields)
+ return Projectable._createProjector(self._env, self._info).project_first(*fields)
def project_second(self, *fields):
"""
- Initiates a ProjectJoin transformation.
+ Initiates a Project transformation.
- Projects the second join input.
- If the second join input is a Tuple DataSet, fields can be selected by their index.
- If the second join input is not a Tuple DataSet, no parameters should be passed.
+ Projects the second input.
+ If the second input is a Tuple DataSet, fields can be selected by their index.
+ If the second input is not a Tuple DataSet, no parameters should be passed.
:param fields: The indexes of the selected fields.
- :return: An incomplete JoinProjection.
+ :return: An incomplete Projection.
"""
- return JoinOperatorProjection(self._env, self._info).project_second(*fields)
+ return Projectable._createProjector(self._env, self._info).project_second(*fields)
+
+ @staticmethod
+ def _createProjector(env, info):
+ child = OperationInfo()
+ child_set = Projector(env, child)
+ child.identifier = _Identifier.MAP
+ child.operator = MapFunction()
+ child.parent = info
+ child.types = _createArrayTypeInfo()
+ child.name = "Projector"
+ info.children.append(child)
+ env._sets.append(child)
+ return child_set
+
- def using(self, operator, types):
+class JoinOperator(DataSet, Projectable):
+ def __init__(self, env, info):
+ super(JoinOperator, self).__init__(env, info)
+
+ def using(self, operator):
"""
Finalizes a Join transformation.
Applies a JoinFunction to each pair of joined elements. Each JoinFunction call returns exactly one element.
:param operator:The JoinFunction that is called for each pair of joined elements.
- :param types:
:return:An Set that represents the joined result DataSet.
"""
if isinstance(operator, TYPES.FunctionType):
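Reviewer note, not part of the patch: project_first/project_second no longer record projections for the Java operator; the shared Projector/Projectable machinery above collects (input side, field index) pairs and turns them into an ordinary Python-side map over the defaulted (first, second) pair. A small sketch of what project_first(0, 3).project_second(0) compiles to, using the example data from the tests below:

    projections = [(0, 0), (0, 3), (1, 0)]           # (input side, field index) pairs
    project = lambda x: tuple([x[side][index] for side, index in projections])

    joined = ((1, 0.5, "hello", True), ("hello",))   # a default-joined pair
    print(project(joined))                           # (1, True, 'hello')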
@@ -780,84 +886,23 @@ def using(self, operator, types):
operator = JoinFunction()
operator.join = f
self._info.operator = operator
- self._info.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__)
- self._info.types = types
+ self._info.types = _createArrayTypeInfo()
self._info.name = "PythonJoin"
- self._env._sets.append(self._info)
+ self._info.uses_udf = True
return OperatorSet(self._env, self._info)
-class CrossOperatorProjection(DataSet):
- def __init__(self, env, info):
- super(CrossOperatorProjection, self).__init__(env, info)
-
- def project_first(self, *fields):
- """
- Initiates a ProjectCross transformation.
-
- Projects the first join input.
- If the first join input is a Tuple DataSet, fields can be selected by their index.
- If the first join input is not a Tuple DataSet, no parameters should be passed.
-
- :param fields: The indexes of the selected fields.
- :return: An incomplete CrossProjection.
- """
- self._info.projections.append(("first", fields))
- return self
-
- def project_second(self, *fields):
- """
- Initiates a ProjectCross transformation.
-
- Projects the second join input.
- If the second join input is a Tuple DataSet, fields can be selected by their index.
- If the second join input is not a Tuple DataSet, no parameters should be passed.
-
- :param fields: The indexes of the selected fields.
- :return: An incomplete CrossProjection.
- """
- self._info.projections.append(("second", fields))
- return self
-
-
-class CrossOperator(DataSet):
+class CrossOperator(DataSet, Projectable):
def __init__(self, env, info):
super(CrossOperator, self).__init__(env, info)
- def project_first(self, *fields):
- """
- Initiates a ProjectCross transformation.
-
- Projects the first join input.
- If the first join input is a Tuple DataSet, fields can be selected by their index.
- If the first join input is not a Tuple DataSet, no parameters should be passed.
-
- :param fields: The indexes of the selected fields.
- :return: An incomplete CrossProjection.
- """
- return CrossOperatorProjection(self._env, self._info).project_first(*fields)
-
- def project_second(self, *fields):
- """
- Initiates a ProjectCross transformation.
-
- Projects the second join input.
- If the second join input is a Tuple DataSet, fields can be selected by their index.
- If the second join input is not a Tuple DataSet, no parameters should be passed.
-
- :param fields: The indexes of the selected fields.
- :return: An incomplete CrossProjection.
- """
- return CrossOperatorProjection(self._env, self._info).project_second(*fields)
-
- def using(self, operator, types):
+ def using(self, operator):
"""
Finalizes a Cross transformation.
Applies a CrossFunction to each pair of crossed elements. Each CrossFunction call returns exactly one element.
:param operator:The CrossFunction that is called for each pair of crossed elements.
- :param types: The type of the resulting DataSet.
:return:A Set that represents the crossed result DataSet.
"""
if isinstance(operator, TYPES.FunctionType):
@@ -865,7 +910,7 @@ def using(self, operator, types):
operator = CrossFunction()
operator.cross = f
self._info.operator = operator
- self._info.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__)
- self._info.types = types
+ self._info.types = _createArrayTypeInfo()
self._info.name = "PythonCross"
+ self._info.uses_udf = True
return OperatorSet(self._env, self._info)
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py
index a2279c3e763e7..4f2c5e33de9dc 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py
@@ -159,8 +159,8 @@ def execute(self, local=False, debug=False):
if plan_mode:
port = int(sys.stdin.readline().rstrip('\n'))
self._connection = Connection.PureTCPConnection(port)
- self._iterator = Iterator.TypedIterator(self._connection, self)
- self._collector = Collector.TypedCollector(self._connection, self)
+ self._iterator = Iterator.PlanIterator(self._connection, self)
+ self._collector = Collector.PlanCollector(self._connection, self)
self._send_plan()
result = self._receive_result()
self._connection.close()
@@ -175,13 +175,13 @@ def execute(self, local=False, debug=False):
input_path = sys.stdin.readline().rstrip('\n')
output_path = sys.stdin.readline().rstrip('\n')
+ used_set = None
operator = None
for set in self._sets:
if set.id == id:
+ used_set = set
operator = set.operator
- if set.id == -id:
- operator = set.combineop
- operator._configure(input_path, output_path, port, self)
+ operator._configure(input_path, output_path, port, self, used_set)
operator._go()
operator._close()
sys.stdout.flush()
@@ -211,7 +211,7 @@ def _find_chains(self):
if child_type in chainable:
parent = child.parent
if parent.operator is not None and len(parent.children) == 1 and len(parent.sinks) == 0:
- parent.operator._chain(child.operator)
+ parent.chained_info = child
parent.name += " -> " + child.name
parent.types = child.types
for grand_child in child.children:
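Reviewer note, not part of the patch: _find_chains no longer calls operator._chain(); it only records parent.chained_info = child, and the actual wiring happens later in Function._configure/_configure_chain, where every operator in a chain uses its chained successor as its collector and the last one keeps the real collector. A stripped-down model of that recursion, with made-up class names:

    class FakeInfo:
        def __init__(self, operator, chained_info=None):
            self.operator = operator
            self.chained_info = chained_info

    class FakeOperator:
        def __init__(self, name):
            self.name = name
            self.collector = None

    def wire(op, real_collector, info):
        # mirrors Function._configure_chain: the last operator keeps the real
        # collector, every other one collects into its chained successor
        if info.chained_info is None:
            op.collector = real_collector
        else:
            op.collector = info.chained_info.operator
            wire(info.chained_info.operator, real_collector, info.chained_info)

    a, b, c = FakeOperator("a"), FakeOperator("b"), FakeOperator("c")
    info_c = FakeInfo(c)
    info_b = FakeInfo(b, info_c)
    info_a = FakeInfo(a, info_b)
    wire(a, "real collector", info_a)
    print(a.collector.name, b.collector.name, c.collector)   # b c real collector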
@@ -297,11 +297,8 @@ def _send_operations(self):
break
if case(_Identifier.CROSS, _Identifier.CROSSH, _Identifier.CROSST):
collect(set.other.id)
+ collect(set.uses_udf)
collect(set.types)
- collect(len(set.projections))
- for p in set.projections:
- collect(p[0])
- collect(p[1])
collect(set.name)
break
if case(_Identifier.REDUCE, _Identifier.GROUPREDUCE):
@@ -312,11 +309,8 @@ def _send_operations(self):
collect(set.key1)
collect(set.key2)
collect(set.other.id)
+ collect(set.uses_udf)
collect(set.types)
- collect(len(set.projections))
- for p in set.projections:
- collect(p[0])
- collect(p[1])
collect(set.name)
break
if case(_Identifier.MAP, _Identifier.MAPPARTITION, _Identifier.FLATMAP, _Identifier.FILTER):
diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/OperationInfo.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/OperationInfo.py
index c47fab57d7c42..3605d7f418384 100644
--- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/OperationInfo.py
+++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/OperationInfo.py
@@ -23,6 +23,8 @@ def __init__(self, info=None):
if info is None:
self.parent = None
self.other = None
+ self.parent_set = None
+ self.other_set = None
self.identifier = None
self.field = None
self.order = None
@@ -31,6 +33,7 @@ def __init__(self, info=None):
self.key2 = None
self.types = None
self.operator = None
+ self.uses_udf = False
self.name = None
self.delimiter_line = "\n"
self.delimiter_field = ","
@@ -43,6 +46,7 @@ def __init__(self, info=None):
self.bcvars = []
self.id = None
self.to_err = False
+ self.chained_info = None
else:
self.__dict__.update(info.__dict__)
diff --git a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py
index c9bc404265bb2..3e718d317240d 100644
--- a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py
+++ b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py
@@ -22,7 +22,8 @@
from flink.functions.MapPartitionFunction import MapPartitionFunction
from flink.functions.ReduceFunction import ReduceFunction
from flink.functions.GroupReduceFunction import GroupReduceFunction
-from flink.plan.Constants import INT, STRING, FLOAT, BOOL, BYTES, CUSTOM, Order, WriteMode
+from flink.plan.Constants import Order, WriteMode
+from flink.plan.Constants import INT, STRING
import struct
#Utilities
@@ -73,7 +74,7 @@ def map_partition(self, iterator, collector):
d4 = env.from_elements((1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False))
- d5 = env.from_elements((4.4, 4.3, 1), (4.3, 4.4, 1), (4.2, 4.1, 3), (4.1, 4.1, 3))
+ d5 = env.from_elements((1, 2.4), (1, 3.7), (1, 0.4), (1, 5.4))
d6 = env.from_elements(1, 1, 12)
@@ -89,19 +90,19 @@ def map_partition(self, iterator, collector):
#Types
env.from_elements(bytearray(b"hello"), bytearray(b"world"))\
- .map(Id(), BYTES).map_partition(Verify([bytearray(b"hello"), bytearray(b"world")], "Byte"), STRING).output()
+ .map(Id()).map_partition(Verify([bytearray(b"hello"), bytearray(b"world")], "Byte")).output()
env.from_elements(1, 2, 3, 4, 5)\
- .map(Id(), INT).map_partition(Verify([1,2,3,4,5], "Int"), STRING).output()
+ .map(Id()).map_partition(Verify([1,2,3,4,5], "Int")).output()
env.from_elements(True, True, False)\
- .map(Id(), BOOL).map_partition(Verify([True, True, False], "Bool"), STRING).output()
+ .map(Id()).map_partition(Verify([True, True, False], "Bool")).output()
env.from_elements(1.4, 1.7, 12312.23)\
- .map(Id(), FLOAT).map_partition(Verify([1.4, 1.7, 12312.23], "Float"), STRING).output()
+ .map(Id()).map_partition(Verify([1.4, 1.7, 12312.23], "Float")).output()
env.from_elements("hello", "world")\
- .map(Id(), STRING).map_partition(Verify(["hello", "world"], "String"), STRING).output()
+ .map(Id()).map_partition(Verify(["hello", "world"], "String")).output()
#Custom Serialization
class Ext(MapPartitionFunction):
@@ -125,16 +126,16 @@ def deserialize(self, read):
env.register_type(MyObj, MySerializer(), MyDeserializer())
env.from_elements(MyObj(2), MyObj(4)) \
- .map(Id(), CUSTOM).map_partition(Ext(), INT) \
- .map_partition(Verify([2, 4], "CustomTypeSerialization"), STRING).output()
+ .map(Id()).map_partition(Ext()) \
+ .map_partition(Verify([2, 4], "CustomTypeSerialization")).output()
#Map
class Mapper(MapFunction):
def map(self, value):
return value * value
d1 \
- .map((lambda x: x * x), INT).map(Mapper(), INT) \
- .map_partition(Verify([1, 1296, 20736], "Map"), STRING).output()
+ .map((lambda x: x * x)).map(Mapper()) \
+ .map_partition(Verify([1, 1296, 20736], "Map")).output()
#FlatMap
class FlatMap(FlatMapFunction):
@@ -142,8 +143,8 @@ def flat_map(self, value, collector):
collector.collect(value)
collector.collect(value * 2)
d1 \
- .flat_map(FlatMap(), INT).flat_map(FlatMap(), INT) \
- .map_partition(Verify([1, 2, 2, 4, 6, 12, 12, 24, 12, 24, 24, 48], "FlatMap"), STRING).output()
+ .flat_map(FlatMap()).flat_map(FlatMap()) \
+ .map_partition(Verify([1, 2, 2, 4, 6, 12, 12, 24, 12, 24, 24, 48], "FlatMap")).output()
#MapPartition
class MapPartition(MapPartitionFunction):
@@ -151,8 +152,8 @@ def map_partition(self, iterator, collector):
for value in iterator:
collector.collect(value * 2)
d1 \
- .map_partition(MapPartition(), INT) \
- .map_partition(Verify([2, 12, 24], "MapPartition"), STRING).output()
+ .map_partition(MapPartition()) \
+ .map_partition(Verify([2, 12, 24], "MapPartition")).output()
#Filter
class Filter(FilterFunction):
@@ -164,7 +165,7 @@ def filter(self, value):
return value > self.limit
d1 \
.filter(Filter(5)).filter(Filter(8)) \
- .map_partition(Verify([12], "Filter"), STRING).output()
+ .map_partition(Verify([12], "Filter")).output()
#Reduce
class Reduce(ReduceFunction):
@@ -176,7 +177,11 @@ def reduce(self, value1, value2):
return (value1[0] + value2[0], value1[1] + value2[1], value1[2], value1[3] or value2[3])
d1 \
.reduce(Reduce()) \
- .map_partition(Verify([19], "AllReduce"), STRING).output()
+ .map_partition(Verify([19], "AllReduce")).output()
+
+ d4 \
+ .group_by(2).reduce(Reduce2()) \
+ .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "GroupedReduce")).output()
#GroupReduce
class GroupReduce(GroupReduceFunction):
@@ -193,9 +198,31 @@ class GroupReduce2(GroupReduceFunction):
def reduce(self, iterator, collector):
for value in iterator:
collector.collect(value)
+
+ d4 \
+ .reduce_group(GroupReduce2()) \
+ .map_partition(Verify([(1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False)], "AllGroupReduce")).output()
+ d4 \
+ .group_by(lambda x: x[2]).reduce_group(GroupReduce(), combinable=False) \
+ .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "GroupReduceWithKeySelector")).output()
d4 \
- .group_by(2).reduce_group(GroupReduce(), (INT, FLOAT, STRING, BOOL), combinable=False) \
- .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "AllGroupReduce"), STRING).output()
+ .group_by(2).reduce_group(GroupReduce()) \
+ .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "GroupReduce")).output()
+ d5 \
+ .group_by(0).sort_group(1, Order.ASCENDING).reduce_group(GroupReduce2(), combinable=True) \
+ .map_partition(Verify([(1, 0.4), (1, 2.4), (1, 3.7), (1, 5.4)], "SortedGroupReduceAsc")).output()
+ d5 \
+ .group_by(0).sort_group(1, Order.DESCENDING).reduce_group(GroupReduce2(), combinable=True) \
+ .map_partition(Verify([(1, 5.4), (1, 3.7), (1, 2.4), (1, 0.4)], "SortedGroupReduceDes")).output()
+ d5 \
+ .group_by(lambda x: x[0]).sort_group(1, Order.DESCENDING).reduce_group(GroupReduce2(), combinable=True) \
+ .map_partition(Verify([(1, 5.4), (1, 3.7), (1, 2.4), (1, 0.4)], "SortedGroupReduceKeySelG")).output()
+ d5 \
+ .group_by(0).sort_group(lambda x: x[1], Order.DESCENDING).reduce_group(GroupReduce2(), combinable=True) \
+ .map_partition(Verify([(1, 5.4), (1, 3.7), (1, 2.4), (1, 0.4)], "SortedGroupReduceKeySelS")).output()
+ d5 \
+ .group_by(lambda x: x[0]).sort_group(lambda x: x[1], Order.DESCENDING).reduce_group(GroupReduce2(), combinable=True) \
+ .map_partition(Verify([(1, 5.4), (1, 3.7), (1, 2.4), (1, 0.4)], "SortedGroupReduceKeySelGS")).output()
#Execution
env.set_parallelism(1)
diff --git a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main2.py b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main2.py
index 56e325059df33..6bf1fabe9beac 100644
--- a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main2.py
+++ b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main2.py
@@ -22,7 +22,6 @@
from flink.functions.CrossFunction import CrossFunction
from flink.functions.JoinFunction import JoinFunction
from flink.functions.CoGroupFunction import CoGroupFunction
-from flink.plan.Constants import BOOL, INT, FLOAT, STRING
#Utilities
@@ -85,28 +84,32 @@ def join(self, value1, value2):
else:
return value2[0] + str(value1[1])
d2 \
- .join(d3).where(2).equal_to(0).using(Join(), STRING) \
- .map_partition(Verify(["hello1", "world0.4"], "Join"), STRING).output()
+ .join(d3).where(2).equal_to(0).using(Join()) \
+ .map_partition(Verify(["hello1", "world0.4"], "Join")).output()
+ d2 \
+ .join(d3).where(lambda x: x[2]).equal_to(0).using(Join()) \
+ .map_partition(Verify(["hello1", "world0.4"], "JoinWithKeySelector")).output()
d2 \
.join(d3).where(2).equal_to(0).project_first(0, 3).project_second(0) \
- .map_partition(Verify([(1, True, "hello"), (2, False, "world")], "Project Join"), STRING).output()
+ .map_partition(Verify([(1, True, "hello"), (2, False, "world")], "Project Join")).output()
d2 \
.join(d3).where(2).equal_to(0) \
- .map_partition(Verify([((1, 0.5, "hello", True), ("hello",)), ((2, 0.4, "world", False), ("world",))], "Default Join"), STRING).output()
+ .map_partition(Verify([((1, 0.5, "hello", True), ("hello",)), ((2, 0.4, "world", False), ("world",))], "Default Join")).output()
#Cross
class Cross(CrossFunction):
def cross(self, value1, value2):
return (value1, value2[3])
d1 \
- .cross(d2).using(Cross(), (INT, BOOL)) \
- .map_partition(Verify([(1, True), (1, False), (6, True), (6, False), (12, True), (12, False)], "Cross"), STRING).output()
+ .cross(d2).using(Cross()) \
+ .map_partition(Verify([(1, True), (1, False), (6, True), (6, False), (12, True), (12, False)], "Cross")).output()
d1 \
.cross(d3) \
- .map_partition(Verify([(1, ("hello",)), (1, ("world",)), (6, ("hello",)), (6, ("world",)), (12, ("hello",)), (12, ("world",))], "Default Cross"), STRING).output()
+ .map_partition(Verify([(1, ("hello",)), (1, ("world",)), (6, ("hello",)), (6, ("world",)), (12, ("hello",)), (12, ("world",))], "Default Cross")).output()
+
d2 \
.cross(d3).project_second(0).project_first(0, 1) \
- .map_partition(Verify([("hello", 1, 0.5), ("world", 1, 0.5), ("hello", 2, 0.4), ("world", 2, 0.4)], "Project Cross"), STRING).output()
+ .map_partition(Verify([("hello", 1, 0.5), ("world", 1, 0.5), ("hello", 2, 0.4), ("world", 2, 0.4)], "Project Cross")).output()
#CoGroup
class CoGroup(CoGroupFunction):
@@ -114,8 +117,8 @@ def co_group(self, iterator1, iterator2, collector):
while iterator1.has_next() and iterator2.has_next():
collector.collect((iterator1.next(), iterator2.next()))
d4 \
- .co_group(d5).where(0).equal_to(2).using(CoGroup(), ((INT, FLOAT, STRING, BOOL), (FLOAT, FLOAT, INT))) \
- .map_partition(Verify([((1, 0.5, "hello", True), (4.4, 4.3, 1)), ((1, 0.4, "hello", False), (4.3, 4.4, 1))], "CoGroup"), STRING).output()
+ .co_group(d5).where(0).equal_to(2).using(CoGroup()) \
+ .map_partition(Verify([((1, 0.5, "hello", True), (4.4, 4.3, 1)), ((1, 0.4, "hello", False), (4.3, 4.4, 1))], "CoGroup")).output()
#Broadcast
class MapperBcv(MapFunction):
@@ -123,22 +126,23 @@ def map(self, value):
factor = self.context.get_broadcast_variable("test")[0][0]
return value * factor
d1 \
- .map(MapperBcv(), INT).with_broadcast_set("test", d2) \
- .map_partition(Verify([1, 6, 12], "Broadcast"), STRING).output()
+ .map(MapperBcv()).with_broadcast_set("test", d2) \
+ .map_partition(Verify([1, 6, 12], "Broadcast")).output()
#Misc
class Mapper(MapFunction):
def map(self, value):
return value * value
d1 \
- .map(Mapper(), INT).map((lambda x: x * x), INT) \
- .map_partition(Verify([1, 1296, 20736], "Chained Lambda"), STRING).output()
+ .map(Mapper()).map((lambda x: x * x)) \
+ .map_partition(Verify([1, 1296, 20736], "Chained Lambda")).output()
d2 \
.project(0, 1, 2) \
- .map_partition(Verify([(1, 0.5, "hello"), (2, 0.4, "world")], "Project"), STRING).output()
+ .map_partition(Verify([(1, 0.5, "hello"), (2, 0.4, "world")], "Project")).output()
d2 \
.union(d4) \
- .map_partition(Verify2([(1, 0.5, "hello", True), (2, 0.4, "world", False), (1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False)], "Union"), STRING).output()
+ .map_partition(Verify2([(1, 0.5, "hello", True), (2, 0.4, "world", False), (1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False)], "Union")).output()
+
#Execution
env.set_parallelism(1)
diff --git a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_type_deduction.py b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_type_deduction.py
deleted file mode 100644
index 1ff3f923bc734..0000000000000
--- a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_type_deduction.py
+++ /dev/null
@@ -1,73 +0,0 @@
-################################################################################
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-################################################################################
-from flink.plan.Environment import get_environment
-from flink.plan.Constants import BOOL, STRING
-from flink.functions.MapPartitionFunction import MapPartitionFunction
-
-
-class Verify(MapPartitionFunction):
- def __init__(self, msg):
- super(Verify, self).__init__()
- self.msg = msg
-
- def map_partition(self, iterator, collector):
- if self.msg is None:
- return
- else:
- raise Exception("Type Deduction failed: " + self.msg)
-
-if __name__ == "__main__":
- env = get_environment()
-
- d1 = env.from_elements(("hello", 4, 3.2, True))
-
- d2 = env.from_elements("world")
-
- direct_from_source = d1.filter(lambda x:True)
-
- msg = None
-
- if direct_from_source._info.types != ("hello", 4, 3.2, True):
- msg = "Error deducting type directly from source."
-
- from_common_udf = d1.map(lambda x: x[3], BOOL).filter(lambda x:True)
-
- if from_common_udf._info.types != BOOL:
- msg = "Error deducting type from common udf."
-
- through_projection = d1.project(3, 2).filter(lambda x:True)
-
- if through_projection._info.types != (True, 3.2):
- msg = "Error deducting type through projection."
-
- through_default_op = d1.cross(d2).filter(lambda x:True)
-
- if through_default_op._info.types != (("hello", 4, 3.2, True), "world"):
- msg = "Error deducting type through default J/C." +str(through_default_op._info.types)
-
- through_prj_op = d1.cross(d2).project_first(1, 0).project_second().project_first(3, 2).filter(lambda x:True)
-
- if through_prj_op._info.types != (4, "hello", "world", True, 3.2):
- msg = "Error deducting type through projection J/C. "+str(through_prj_op._info.types)
-
-
- env = get_environment()
-
- env.from_elements("dummy").map_partition(Verify(msg), STRING).output()
-
- env.execute(local=True)