From ab8470714715cb02b32e3299f36ea64ebe85695e Mon Sep 17 00:00:00 2001 From: zentol Date: Wed, 20 Jan 2016 14:50:53 +0100 Subject: [PATCH] [FLINK-2501] [py] Remove the need to specify types for transformations Full changelog: Major changes =============== - Users no longer have to supply information about types - Values are now stored as byte arrays on the Java side in * a plain byte[] most of the time, * a T2 within a join/cross * a T2, b[]> within keyed operations. - Every value contains information about its type at the beginning of each byte array. - Implemented KeySelectors Minor =============== - improved error messages in several places - defaultable operations now use a "usesUDF" flag - reshuffled type ID's; tuple type encoded as 1-25 - broadcast variables are now sent via the tcp socket - ProjectJoin/-Cross now executes projection on python side Java --------------- - Sort field now stored as String, continuation of FLINK-2431 - object->byte[] serializer code moved into separate utility class Python --------------- - Fixed NullSerializer not taking a read method argument - Serializer/Deserializer interface added - Refactored DataSet structure * Set and ReduceSet merged into DataSet - configure() now takes an OperationInfo argument - Simplified GroupReduce tests - removed unused Function._open() - simplified chaining setup - most functions now use super.configure() --- docs/apis/batch/dataset_transformations.md | 20 +- docs/apis/batch/python.md | 82 ++-- .../flink/python/api/PythonOperationInfo.java | 51 +- .../flink/python/api/PythonPlanBinder.java | 77 +-- .../python/api/functions/PythonCoGroup.java | 2 +- .../api/functions/PythonMapPartition.java | 3 +- .../{ => util}/IdentityGroupReduce.java | 5 +- .../api/functions/util/KeyDiscarder.java | 29 ++ .../functions/util/NestedKeyDiscarder.java | 30 ++ .../api/functions/util/SerializerMap.java | 32 ++ .../functions/util/StringDeserializerMap.java | 26 + .../util/StringTupleDeserializerMap.java | 27 ++ .../api/streaming/data/PythonReceiver.java | 185 ++----- .../api/streaming/data/PythonSender.java | 264 ++-------- .../api/streaming/data/PythonStreamer.java | 76 +-- .../streaming/plan/PythonPlanReceiver.java | 199 ++++++-- .../api/streaming/plan/PythonPlanSender.java | 87 +--- .../streaming/util/SerializationUtils.java | 283 +++++++++++ .../python/api/flink/connection/Collector.py | 181 +++---- .../python/api/flink/connection/Connection.py | 8 + .../python/api/flink/connection/Constants.py | 23 +- .../python/api/flink/connection/Iterator.py | 194 +++++--- .../python/api/flink/example/TPCHQuery10.py | 2 +- .../python/api/flink/example/TPCHQuery3.py | 4 +- .../api/flink/example/TriangleEnumeration.py | 14 +- .../api/flink/example/WebLogAnalysis.py | 10 +- .../python/api/flink/example/WordCount.py | 5 +- .../api/flink/functions/CoGroupFunction.py | 7 +- .../python/api/flink/functions/Function.py | 57 +-- .../flink/functions/GroupReduceFunction.py | 15 +- .../flink/functions/KeySelectorFunction.py | 28 ++ .../api/flink/functions/ReduceFunction.py | 15 +- .../flink/python/api/flink/plan/Constants.py | 15 +- .../flink/python/api/flink/plan/DataSet.py | 451 ++++++++++-------- .../python/api/flink/plan/Environment.py | 22 +- .../python/api/flink/plan/OperationInfo.py | 4 + .../org/apache/flink/python/api/test_main.py | 65 ++- .../org/apache/flink/python/api/test_main2.py | 38 +- .../flink/python/api/test_type_deduction.py | 73 --- 39 files changed, 1426 insertions(+), 1283 deletions(-) rename 
flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/{ => util}/IdentityGroupReduce.java (91%) create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/KeyDiscarder.java create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/NestedKeyDiscarder.java create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/SerializerMap.java create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringDeserializerMap.java create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringTupleDeserializerMap.java create mode 100644 flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/util/SerializationUtils.java create mode 100644 flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/KeySelectorFunction.py delete mode 100644 flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_type_deduction.py diff --git a/docs/apis/batch/dataset_transformations.md b/docs/apis/batch/dataset_transformations.md index b9d7f0c3929da..7315ec2eaf5af 100644 --- a/docs/apis/batch/dataset_transformations.md +++ b/docs/apis/batch/dataset_transformations.md @@ -71,7 +71,7 @@ val intSums = intPairs.map { pair => pair._1 + pair._2 }
~~~python - intSums = intPairs.map(lambda x: sum(x), INT) + intSums = intPairs.map(lambda x: sum(x)) ~~~
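Since the Java side now stores every value as a serialized byte array that carries its own type tag, follow-up transformations no longer repeat a result type either. A minimal sketch using only calls that appear elsewhere in these docs (`intPairs` is assumed to be a DataSet of 2-tuples of ints obtained from `env.from_elements`):

~~~python
intPairs = env.from_elements((1, 2), (3, 4))
# no type hint on either map; the result type travels with the serialized value
intSums = intPairs.map(lambda pair: pair[0] + pair[1])
doubled = intSums.map(lambda x: x * 2)
doubled.output()
~~~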
@@ -115,7 +115,7 @@ val words = textLines.flatMap { _.split(" ") }
~~~python - words = lines.flat_map(lambda x,c: [line.split() for line in x], STRING) + words = lines.flat_map(lambda x,c: [line.split() for line in x]) ~~~
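For comparison, the word-count style used later in this patch; a sketch assuming `lines` is a DataSet of strings. The lambda receives one element plus a collector argument and must return an iterable, and every entry of the returned list becomes one record of the result DataSet:

~~~python
# x is a single input line, c the (unused) collector argument
tokens = lines.flat_map(lambda x, c: [(1, word) for word in x.lower().split()])
~~~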
@@ -163,7 +163,7 @@ val counts = textLines.mapPartition { in => Some(in.size) }
~~~python - counts = lines.map_partition(lambda x,c: [sum(1 for _ in x)], INT) + counts = lines.map_partition(lambda x,c: [sum(1 for _ in x)]) ~~~
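Unlike `map`, the function here sees a whole partition as an iterator and may emit any number of records. A small sketch (again assuming `lines` holds strings) that drops empty lines, one partition at a time:

~~~python
# x iterates over one partition, c is the collector; the returned list may be
# empty or hold many records
nonEmpty = lines.map_partition(lambda x, c: [line for line in x if line])
~~~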
@@ -459,7 +459,7 @@ Works analogous to grouping by Case Class fields in *Reduce* transformations. for key in dic.keys(): collector.collect(key) - output = data.group_by(0).reduce_group(DistinctReduce(), STRING) + output = data.group_by(0).reduce_group(DistinctReduce()) ~~~ @@ -539,7 +539,7 @@ val output = input.groupBy(0).sortGroup(1, Order.ASCENDING).reduceGroup { for key in dic.keys(): collector.collect(key) - output = data.group_by(0).sort_group(1, Order.ASCENDING).reduce_group(DistinctReduce(), STRING) + output = data.group_by(0).sort_group(1, Order.ASCENDING).reduce_group(DistinctReduce()) ~~~ @@ -644,7 +644,7 @@ class MyCombinableGroupReducer def combine(self, iterator, collector): return self.reduce(iterator, collector) -data.reduce_group(GroupReduce(), (STRING, INT, FLOAT), combinable=True) +data.reduce_group(GroupReduce(), combinable=True) ~~~ @@ -864,7 +864,7 @@ val output = input.reduceGroup(new MyGroupReducer())
~~~python - output = data.reduce_group(MyGroupReducer(), ... ) + output = data.reduce_group(MyGroupReducer()) ~~~
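Putting the group-reduce pieces together, a sketch of a combinable reducer as it looks after this change; `data` is assumed to hold `(count, word)` tuples, and the combiner simply delegates to `reduce`, as in the `MyCombinableGroupReducer` example above:

~~~python
from flink.functions.GroupReduceFunction import GroupReduceFunction

class Adder(GroupReduceFunction):
    def reduce(self, iterator, collector):
        count, word = iterator.next()
        count += sum([x[0] for x in iterator])
        collector.collect((count, word))

    def combine(self, iterator, collector):
        # pre-aggregation on each partition; same logic as the final reduce
        return self.reduce(iterator, collector)

output = data.group_by(1).reduce_group(Adder(), combinable=True)
~~~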
@@ -1223,7 +1223,7 @@ val weightedRatings = ratings.join(weights).where("category").equalTo(0) { weightedRatings = ratings.join(weights).where(0).equal_to(0). \ - with(new PointWeighter(), (STRING, FLOAT)); + with(new PointWeighter()); ~~~ @@ -1719,7 +1719,7 @@ val distances = coords1.cross(coords2) { def cross(self, c1, c2): return (c1[0], c2[0], sqrt(pow(c1[1] - c2.[1], 2) + pow(c1[2] - c2[2], 2))) - distances = coords1.cross(coords2).using(Euclid(), (INT,INT,FLOAT)) + distances = coords1.cross(coords2).using(Euclid()) ~~~ #### Cross with Projection @@ -1878,7 +1878,7 @@ val output = iVals.coGroup(dVals).where(0).equalTo(0) { collector.collect(value[1] * i) - output = ivals.co_group(dvals).where(0).equal_to(0).using(CoGroup(), DOUBLE) + output = ivals.co_group(dvals).where(0).equal_to(0).using(CoGroup()) ~~~ diff --git a/docs/apis/batch/python.md b/docs/apis/batch/python.md index 74da97e933774..7e2060bc910d6 100644 --- a/docs/apis/batch/python.md +++ b/docs/apis/batch/python.md @@ -50,7 +50,6 @@ to run it locally. {% highlight python %} from flink.plan.Environment import get_environment -from flink.plan.Constants import INT, STRING from flink.functions.GroupReduceFunction import GroupReduceFunction class Adder(GroupReduceFunction): @@ -59,18 +58,17 @@ class Adder(GroupReduceFunction): count += sum([x[0] for x in iterator]) collector.collect((count, word)) -if __name__ == "__main__": - env = get_environment() - data = env.from_elements("Who's there?", - "I think I hear them. Stand, ho! Who's there?") +env = get_environment() +data = env.from_elements("Who's there?", + "I think I hear them. Stand, ho! Who's there?") - data \ - .flat_map(lambda x, c: [(1, word) for word in x.lower().split()], (INT, STRING)) \ - .group_by(1) \ - .reduce_group(Adder(), (INT, STRING), combinable=True) \ - .output() +data \ + .flat_map(lambda x, c: [(1, word) for word in x.lower().split()]) \ + .group_by(1) \ + .reduce_group(Adder(), combinable=True) \ + .output() - env.execute(local=True) +env.execute(local=True) {% endhighlight %} {% top %} @@ -78,8 +76,8 @@ if __name__ == "__main__": Program Skeleton ---------------- -As we already saw in the example, Flink programs look like regular python -programs with a `if __name__ == "__main__":` block. Each program consists of the same basic parts: +As we already saw in the example, Flink programs look like regular python programs. +Each program consists of the same basic parts: 1. Obtain an `Environment`, 2. Load/create the initial data, @@ -117,7 +115,7 @@ methods on DataSet with your own custom transformation function. For example, a map transformation looks like this: {% highlight python %} -data.map(lambda x: x*2, INT) +data.map(lambda x: x*2) {% endhighlight %} This will create a new DataSet by doubling every value in the original DataSet. @@ -197,7 +195,7 @@ examples.

Takes one element and produces one element.

{% highlight python %} -data.map(lambda x: x * 2, INT) +data.map(lambda x: x * 2) {% endhighlight %} @@ -208,8 +206,7 @@ data.map(lambda x: x * 2, INT)

Takes one element and produces zero, one, or more elements.

{% highlight python %} data.flat_map( - lambda x,c: [(1,word) for word in line.lower().split() for line in x], - (INT, STRING) + lambda x,c: [(1,word) for word in x.lower().split()]) {% endhighlight %} @@ -221,7 +218,7 @@ as an `Iterator` and can produce an arbitrary number of result values. The number of elements in each partition depends on the degree-of-parallelism and previous operations.

{% highlight python %} -data.map_partition(lambda x,c: [value * 2 for value in x], INT) +data.map_partition(lambda x,c: [value * 2 for value in x]) {% endhighlight %} @@ -260,7 +257,7 @@ class Adder(GroupReduceFunction): count += sum([x[0] for x in iterator) collector.collect((count, word)) -data.reduce_group(Adder(), (INT, STRING)) +data.reduce_group(Adder()) {% endhighlight %} @@ -392,24 +389,33 @@ They are also the only way to define an optional `combine` function for a reduce Lambda functions allow the easy insertion of one-liners. Note that a lambda function has to return an iterable, if the operation can return multiple values. (All functions receiving a collector argument) -Flink requires type information at the time when it prepares the program for execution -(when the main method of the program is called). This is done by passing an exemplary -object that has the desired type. This holds also for tuples. +{% top %} + +Data Types +---------- + +Flink's Python API currently only offers native support for primitive python types (int, float, bool, string) and byte arrays. +The type support can be extended by passing a serializer, deserializer and type class to the environment. {% highlight python %} -(INT, STRING) -{% endhighlight %} +class MyObj(object): + def __init__(self, i): + self.value = i -Would denote a tuple containing an int and a string. Note that for Operations that work strictly on tuples (like cross), no braces are required. -There are a few Constants defined in flink.plan.Constants that allow this in a more readable fashion. +class MySerializer(object): + def serialize(self, value): + return struct.pack(">i", value.value) -{% top %} -Data Types ----------- +class MyDeserializer(object): + def _deserialize(self, read): + i = struct.unpack(">i", read(4))[0] + return MyObj(i) + -Flink's Python API currently only supports primitive python types (int, float, bool, string) and byte arrays. +env.register_custom_type(MyObj, MySerializer(), MyDeserializer()) +{% endhighlight %} #### Tuples/Lists @@ -419,7 +425,7 @@ a fix number of fields of various types (up to 25). Every field of a tuple can b {% highlight python %} word_counts = env.from_elements(("hello", 1), ("world",2)) -counts = word_counts.map(lambda x: x[1], INT) +counts = word_counts.map(lambda x: x[1]) {% endhighlight %} When working with operators that require a Key for grouping or matching records, @@ -455,16 +461,16 @@ Collection-based: {% highlight python %} env = get_environment -# read text file from local files system +\# read text file from local files system localLiens = env.read_text("file:#/path/to/my/textfile") - read text file from a HDFS running at nnHost:nnPort +\# read text file from a HDFS running at nnHost:nnPort hdfsLines = env.read_text("hdfs://nnHost:nnPort/path/to/my/textfile") - read a CSV file with three fields +\# read a CSV file with three fields, schema defined using constants defined in flink.plan.Constants csvInput = env.read_csv("hdfs:///the/CSV/file", (INT, STRING, DOUBLE)) - create a set from some given elements +\# create a set from some given elements values = env.from_elements("Foo", "bar", "foobar", "fubar") {% endhighlight %} @@ -530,7 +536,7 @@ toBroadcast = env.from_elements(1, 2, 3) data = env.from_elements("a", "b") # 2. 
Broadcast the DataSet -data.map(MapperBcv(), INT).with_broadcast_set("bcv", toBroadcast) +data.map(MapperBcv()).with_broadcast_set("bcv", toBroadcast) {% endhighlight %} Make sure that the names (`bcv` in the previous example) match when registering and @@ -568,9 +574,9 @@ execution environment as follows: env = get_environment() env.set_degree_of_parallelism(3) -text.flat_map(lambda x,c: x.lower().split(), (INT, STRING)) \ +text.flat_map(lambda x,c: x.lower().split()) \ .group_by(1) \ - .reduce_group(Adder(), (INT, STRING), combinable=True) \ + .reduce_group(Adder(), combinable=True) \ .output() env.execute() diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonOperationInfo.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonOperationInfo.java index 6ecd6837e0bfd..30a7133f2a2d7 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonOperationInfo.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonOperationInfo.java @@ -32,10 +32,9 @@ public class PythonOperationInfo { public String[] keys2; //join/cogroup keys public TypeInformation types; //typeinformation about output type public AggregationEntry[] aggregates; - public ProjectionEntry[] projections; //projectFirst/projectSecond public Object[] values; public int count; - public int field; + public String field; public int[] fields; public Order order; public String path; @@ -46,6 +45,7 @@ public class PythonOperationInfo { public WriteMode writeMode; public boolean toError; public String name; + public boolean usesUDF; public PythonOperationInfo(PythonPlanStreamer streamer, Operation identifier) throws IOException { Object tmpType; @@ -127,7 +127,7 @@ public PythonOperationInfo(PythonPlanStreamer streamer, Operation identifier) th case REBALANCE: return; case SORT: - field = (Integer) streamer.getRecord(true); + field = "f0.f" + (Integer) streamer.getRecord(true); int encodedOrder = (Integer) streamer.getRecord(true); switch (encodedOrder) { case 0: @@ -162,15 +162,9 @@ public PythonOperationInfo(PythonPlanStreamer streamer, Operation identifier) th case CROSS_H: case CROSS_T: otherID = (Integer) streamer.getRecord(true); + usesUDF = (Boolean) streamer.getRecord(); tmpType = streamer.getRecord(); types = tmpType == null ? null : getForObject(tmpType); - int cProjectCount = (Integer) streamer.getRecord(true); - projections = new ProjectionEntry[cProjectCount]; - for (int x = 0; x < cProjectCount; x++) { - String side = (String) streamer.getRecord(); - int[] keys = toIntArray((Tuple) streamer.getRecord(true)); - projections[x] = new ProjectionEntry(ProjectionSide.valueOf(side.toUpperCase()), keys); - } name = (String) streamer.getRecord(); return; case REDUCE: @@ -185,15 +179,9 @@ public PythonOperationInfo(PythonPlanStreamer streamer, Operation identifier) th keys1 = normalizeKeys(streamer.getRecord(true)); keys2 = normalizeKeys(streamer.getRecord(true)); otherID = (Integer) streamer.getRecord(true); + usesUDF = (Boolean) streamer.getRecord(); tmpType = streamer.getRecord(); types = tmpType == null ? 
null : getForObject(tmpType); - int jProjectCount = (Integer) streamer.getRecord(true); - projections = new ProjectionEntry[jProjectCount]; - for (int x = 0; x < jProjectCount; x++) { - String side = (String) streamer.getRecord(); - int[] keys = toIntArray((Tuple) streamer.getRecord(true)); - projections[x] = new ProjectionEntry(ProjectionSide.valueOf(side.toUpperCase()), keys); - } name = (String) streamer.getRecord(); return; case MAPPARTITION: @@ -221,7 +209,6 @@ public String toString() { sb.append("Keys2: ").append(Arrays.toString(keys2)).append("\n"); sb.append("Keys: ").append(Arrays.toString(keys)).append("\n"); sb.append("Aggregates: ").append(Arrays.toString(aggregates)).append("\n"); - sb.append("Projections: ").append(Arrays.toString(projections)).append("\n"); sb.append("Count: ").append(count).append("\n"); sb.append("Field: ").append(field).append("\n"); sb.append("Order: ").append(order.toString()).append("\n"); @@ -260,26 +247,6 @@ public String toString() { } } - public static class ProjectionEntry { - public ProjectionSide side; - public int[] keys; - - public ProjectionEntry(ProjectionSide side, int[] keys) { - this.side = side; - this.keys = keys; - } - - @Override - public String toString() { - return side + " - " + Arrays.toString(keys); - } - } - - public enum ProjectionSide { - FIRST, - SECOND - } - public enum DatasizeHint { NONE, TINY, @@ -296,24 +263,24 @@ private static String[] normalizeKeys(Object keys) { if (tupleKeys.getField(0) instanceof Integer) { String[] stringKeys = new String[tupleKeys.getArity()]; for (int x = 0; x < stringKeys.length; x++) { - stringKeys[x] = "f" + (Integer) tupleKeys.getField(x); + stringKeys[x] = "f0.f" + (Integer) tupleKeys.getField(x); } return stringKeys; } if (tupleKeys.getField(0) instanceof String) { return tupleToStringArray(tupleKeys); } - throw new RuntimeException("Key argument contains field that is neither an int nor a String."); + throw new RuntimeException("Key argument contains field that is neither an int nor a String: " + tupleKeys); } if (keys instanceof int[]) { int[] intKeys = (int[]) keys; String[] stringKeys = new String[intKeys.length]; for (int x = 0; x < stringKeys.length; x++) { - stringKeys[x] = "f" + intKeys[x]; + stringKeys[x] = "f0.f" + intKeys[x]; } return stringKeys; } - throw new RuntimeException("Key argument is neither an int[] nor a Tuple."); + throw new RuntimeException("Key argument is neither an int[] nor a Tuple: " + keys.toString()); } private static int[] toIntArray(Object key) { diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java index 2e64a56349f7f..f07a975a319e1 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java @@ -28,10 +28,7 @@ import org.apache.flink.api.java.operators.AggregateOperator; import org.apache.flink.api.java.operators.CoGroupRawOperator; import org.apache.flink.api.java.operators.CrossOperator.DefaultCross; -import org.apache.flink.api.java.operators.CrossOperator.ProjectCross; import org.apache.flink.api.java.operators.Grouping; -import org.apache.flink.api.java.operators.JoinOperator.DefaultJoin; -import org.apache.flink.api.java.operators.JoinOperator.ProjectJoin; import org.apache.flink.api.java.operators.Keys; import org.apache.flink.api.java.operators.SortedGrouping; 
import org.apache.flink.api.java.operators.UdfOperator; @@ -42,14 +39,18 @@ import org.apache.flink.configuration.GlobalConfiguration; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.fs.Path; +import org.apache.flink.python.api.functions.util.NestedKeyDiscarder; +import org.apache.flink.python.api.functions.util.StringTupleDeserializerMap; import org.apache.flink.python.api.PythonOperationInfo.DatasizeHint; import static org.apache.flink.python.api.PythonOperationInfo.DatasizeHint.HUGE; import static org.apache.flink.python.api.PythonOperationInfo.DatasizeHint.NONE; import static org.apache.flink.python.api.PythonOperationInfo.DatasizeHint.TINY; -import org.apache.flink.python.api.PythonOperationInfo.ProjectionEntry; import org.apache.flink.python.api.functions.PythonCoGroup; -import org.apache.flink.python.api.functions.IdentityGroupReduce; +import org.apache.flink.python.api.functions.util.IdentityGroupReduce; import org.apache.flink.python.api.functions.PythonMapPartition; +import org.apache.flink.python.api.functions.util.KeyDiscarder; +import org.apache.flink.python.api.functions.util.SerializerMap; +import org.apache.flink.python.api.functions.util.StringDeserializerMap; import org.apache.flink.python.api.streaming.plan.PythonPlanStreamer; import org.apache.flink.runtime.filecache.FileCache; import org.slf4j.Logger; @@ -410,34 +411,35 @@ private void createCsvSource(PythonOperationInfo info) throws IOException { sets.put(info.setID, env.createInput(new TupleCsvInputFormat(new Path(info.path), info.lineDelimiter, info.fieldDelimiter, (TupleTypeInfo) info.types), info.types) - .name("CsvSource")); + .name("CsvSource").map(new SerializerMap()).name("CsvSourcePostStep")); } private void createTextSource(PythonOperationInfo info) throws IOException { - sets.put(info.setID, env.readTextFile(info.path).name("TextSource")); + sets.put(info.setID, env.readTextFile(info.path).name("TextSource").map(new SerializerMap()).name("TextSourcePostStep")); } private void createValueSource(PythonOperationInfo info) throws IOException { - sets.put(info.setID, env.fromElements(info.values).name("ValueSource")); + sets.put(info.setID, env.fromElements(info.values).name("ValueSource").map(new SerializerMap()).name("ValueSourcePostStep")); } private void createSequenceSource(PythonOperationInfo info) throws IOException { - sets.put(info.setID, env.generateSequence(info.from, info.to).name("SequenceSource")); + sets.put(info.setID, env.generateSequence(info.from, info.to).name("SequenceSource").map(new SerializerMap()).name("SequenceSourcePostStep")); } private void createCsvSink(PythonOperationInfo info) throws IOException { DataSet parent = (DataSet) sets.get(info.parentID); - parent.writeAsCsv(info.path, info.lineDelimiter, info.fieldDelimiter, info.writeMode).name("CsvSink"); + parent.map(new StringTupleDeserializerMap()).name("CsvSinkPreStep") + .writeAsCsv(info.path, info.lineDelimiter, info.fieldDelimiter, info.writeMode).name("CsvSink"); } private void createTextSink(PythonOperationInfo info) throws IOException { DataSet parent = (DataSet) sets.get(info.parentID); - parent.writeAsText(info.path, info.writeMode).name("TextSink"); + parent.map(new StringDeserializerMap()).writeAsText(info.path, info.writeMode).name("TextSink"); } private void createPrintSink(PythonOperationInfo info) throws IOException { DataSet parent = (DataSet) sets.get(info.parentID); - parent.output(new PrintingOutputFormat(info.toError)); + parent.map(new 
StringDeserializerMap()).name("PrintSinkPreStep").output(new PrintingOutputFormat(info.toError)); } private void createBroadcastVariable(PythonOperationInfo info) throws IOException { @@ -471,7 +473,7 @@ private void createAggregationOperation(PythonOperationInfo info) throws IOExcep private void createDistinctOperation(PythonOperationInfo info) throws IOException { DataSet op = (DataSet) sets.get(info.parentID); - sets.put(info.setID, info.keys.length == 0 ? op.distinct() : op.distinct(info.keys).name("Distinct")); + sets.put(info.setID, info.keys.length == 0 ? op.distinct() : op.distinct(info.keys).name("Distinct").map(new KeyDiscarder()).name("DistinctPostStep")); } private void createFirstOperation(PythonOperationInfo info) throws IOException { @@ -486,7 +488,7 @@ private void createGroupOperation(PythonOperationInfo info) throws IOException { private void createHashPartitionOperation(PythonOperationInfo info) throws IOException { DataSet op1 = (DataSet) sets.get(info.parentID); - sets.put(info.setID, op1.partitionByHash(info.keys)); + sets.put(info.setID, op1.partitionByHash(info.keys).map(new KeyDiscarder()).name("HashPartitionPostStep")); } private void createProjectOperation(PythonOperationInfo info) throws IOException { @@ -546,23 +548,10 @@ private void createCrossOperation(DatasizeHint mode, PythonOperationInfo info) { default: throw new IllegalArgumentException("Invalid Cross mode specified: " + mode); } - if (info.types != null && (info.projections == null || info.projections.length == 0)) { + if (info.usesUDF) { sets.put(info.setID, defaultResult.mapPartition(new PythonMapPartition(info.setID, info.types)).name(info.name)); - } else if (info.projections.length == 0) { - sets.put(info.setID, defaultResult.name("DefaultCross")); } else { - ProjectCross project = null; - for (ProjectionEntry pe : info.projections) { - switch (pe.side) { - case FIRST: - project = project == null ? defaultResult.projectFirst(pe.keys) : project.projectFirst(pe.keys); - break; - case SECOND: - project = project == null ? defaultResult.projectSecond(pe.keys) : project.projectSecond(pe.keys); - break; - } - } - sets.put(info.setID, project.name("ProjectCross")); + sets.put(info.setID, defaultResult.name("DefaultCross")); } } @@ -616,38 +605,22 @@ private void createJoinOperation(DatasizeHint mode, PythonOperationInfo info) { DataSet op1 = (DataSet) sets.get(info.parentID); DataSet op2 = (DataSet) sets.get(info.otherID); - if (info.types != null && (info.projections == null || info.projections.length == 0)) { - sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode).name("PythonJoinPreStep") + if (info.usesUDF) { + sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode) .mapPartition(new PythonMapPartition(info.setID, info.types)).name(info.name)); } else { - DefaultJoin defaultResult = createDefaultJoin(op1, op2, info.keys1, info.keys2, mode); - if (info.projections.length == 0) { - sets.put(info.setID, defaultResult.name("DefaultJoin")); - } else { - ProjectJoin project = null; - for (ProjectionEntry pe : info.projections) { - switch (pe.side) { - case FIRST: - project = project == null ? defaultResult.projectFirst(pe.keys) : project.projectFirst(pe.keys); - break; - case SECOND: - project = project == null ? 
defaultResult.projectSecond(pe.keys) : project.projectSecond(pe.keys); - break; - } - } - sets.put(info.setID, project.name("ProjectJoin")); - } + sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode)); } } - private DefaultJoin createDefaultJoin(DataSet op1, DataSet op2, String[] firstKeys, String[] secondKeys, DatasizeHint mode) { + private DataSet createDefaultJoin(DataSet op1, DataSet op2, String[] firstKeys, String[] secondKeys, DatasizeHint mode) { switch (mode) { case NONE: - return op1.join(op2).where(firstKeys).equalTo(secondKeys); + return op1.join(op2).where(firstKeys).equalTo(secondKeys).map(new NestedKeyDiscarder()).name("DefaultJoinPostStep"); case HUGE: - return op1.joinWithHuge(op2).where(firstKeys).equalTo(secondKeys); + return op1.joinWithHuge(op2).where(firstKeys).equalTo(secondKeys).map(new NestedKeyDiscarder()).name("DefaultJoinPostStep"); case TINY: - return op1.joinWithTiny(op2).where(firstKeys).equalTo(secondKeys); + return op1.joinWithTiny(op2).where(firstKeys).equalTo(secondKeys).map(new NestedKeyDiscarder()).name("DefaultJoinPostStep"); default: throw new IllegalArgumentException("Invalid join mode specified."); } diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonCoGroup.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonCoGroup.java index 2349aa919f838..33d88c3182d76 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonCoGroup.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonCoGroup.java @@ -33,7 +33,7 @@ public class PythonCoGroup extends RichCoGroupFunction typeInformation) { this.typeInformation = typeInformation; - streamer = new PythonStreamer(this, id); + streamer = new PythonStreamer(this, id, true); } /** diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonMapPartition.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonMapPartition.java index 50b2cf4667e2c..6282210c4863d 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonMapPartition.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/PythonMapPartition.java @@ -14,6 +14,7 @@ import java.io.IOException; import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.configuration.Configuration; @@ -33,7 +34,7 @@ public class PythonMapPartition extends RichMapPartitionFunction typeInformation) { this.typeInformation = typeInformation; - streamer = new PythonStreamer(this, id); + streamer = new PythonStreamer(this, id, typeInformation instanceof PrimitiveArrayTypeInfo); } /** diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/IdentityGroupReduce.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/IdentityGroupReduce.java similarity index 91% rename from flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/IdentityGroupReduce.java rename to flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/IdentityGroupReduce.java index a4201532526d1..1e7bbe6e260b9 100644 
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/IdentityGroupReduce.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/IdentityGroupReduce.java @@ -10,11 +10,14 @@ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ -package org.apache.flink.python.api.functions; +package org.apache.flink.python.api.functions.util; import org.apache.flink.util.Collector; import org.apache.flink.api.common.functions.GroupReduceFunction; +/* +Utility function to group and sort data. +*/ public class IdentityGroupReduce implements GroupReduceFunction { @Override public final void reduce(Iterable values, Collector out) throws Exception { diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/KeyDiscarder.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/KeyDiscarder.java new file mode 100644 index 0000000000000..118bc9f7976a7 --- /dev/null +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/KeyDiscarder.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package org.apache.flink.python.api.functions.util; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields; +import org.apache.flink.api.java.tuple.Tuple; +import org.apache.flink.api.java.tuple.Tuple2; + +/* +Utility function to extract the value from a Key-Value Tuple. +*/ +@ForwardedFields("f1->*") +public class KeyDiscarder implements MapFunction, byte[]> { + @Override + public byte[] map(Tuple2 value) throws Exception { + return value.f1; + } +} diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/NestedKeyDiscarder.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/NestedKeyDiscarder.java new file mode 100644 index 0000000000000..d59eb73c0b90f --- /dev/null +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/NestedKeyDiscarder.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package org.apache.flink.python.api.functions.util; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields; +import org.apache.flink.api.java.tuple.Tuple; +import org.apache.flink.api.java.tuple.Tuple2; + +/* +Utility function to extract values from 2 Key-Value Tuples after a DefaultJoin. +*/ +@ForwardedFields("f0.f1->f0; f1.f1->f1") +public class NestedKeyDiscarder implements MapFunction> { + @Override + public Tuple2 map(IN value) throws Exception { + Tuple2, Tuple2> x = (Tuple2, Tuple2>) value; + return new Tuple2<>(x.f0.f1, x.f1.f1); + } +} diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/SerializerMap.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/SerializerMap.java new file mode 100644 index 0000000000000..fba83f9dfe763 --- /dev/null +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/SerializerMap.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package org.apache.flink.python.api.functions.util; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.python.api.streaming.util.SerializationUtils; +import org.apache.flink.python.api.streaming.util.SerializationUtils.Serializer; + +/* +Utility function to serialize values, usually directly from data sources. +*/ +public class SerializerMap implements MapFunction { + private Serializer serializer = null; + + @Override + public byte[] map(IN value) throws Exception { + if (serializer == null) { + serializer = SerializationUtils.getSerializer(value); + } + return serializer.serialize(value); + } +} diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringDeserializerMap.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringDeserializerMap.java new file mode 100644 index 0000000000000..d89fc41322eb0 --- /dev/null +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringDeserializerMap.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. 
See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package org.apache.flink.python.api.functions.util; + +import org.apache.flink.api.common.functions.MapFunction; + +/* +Utility function to deserialize strings, used for non-CSV sinks. +*/ +public class StringDeserializerMap implements MapFunction { + @Override + public String map(byte[] value) throws Exception { + //discard type byte and size + return new String(value, 5, value.length - 5); + } +} diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringTupleDeserializerMap.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringTupleDeserializerMap.java new file mode 100644 index 0000000000000..b6d60e1e7a3b1 --- /dev/null +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/functions/util/StringTupleDeserializerMap.java @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package org.apache.flink.python.api.functions.util; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.tuple.Tuple1; + +/* +Utility function to deserialize strings, used for CSV sinks. 
+*/ +public class StringTupleDeserializerMap implements MapFunction> { + @Override + public Tuple1 map(byte[] value) throws Exception { + //5 = string type byte + string size + return new Tuple1<>(new String(value, 5, value.length - 5)); + } +} diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonReceiver.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonReceiver.java index 9ed047cb6e349..3ee0fde616d19 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonReceiver.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonReceiver.java @@ -20,24 +20,13 @@ import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import org.apache.flink.api.java.tuple.Tuple; +import org.apache.flink.api.java.tuple.Tuple2; import static org.apache.flink.python.api.PythonPlanBinder.FLINK_TMP_DATA_DIR; import static org.apache.flink.python.api.PythonPlanBinder.MAPPED_FILE_SIZE; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BOOLEAN; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTE; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTES; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_DOUBLE; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_FLOAT; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_INTEGER; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_LONG; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_NULL; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_SHORT; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_STRING; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_TUPLE; -import org.apache.flink.python.api.types.CustomTypeWrapper; import org.apache.flink.util.Collector; /** - * General-purpose class to read data from memory-mapped files. + * This class is used to read data from memory-mapped files. */ public class PythonReceiver implements Serializable { private static final long serialVersionUID = -2474088929850009968L; @@ -47,11 +36,18 @@ public class PythonReceiver implements Serializable { private FileChannel inputChannel; private MappedByteBuffer fileBuffer; + private final boolean readAsByteArray; + private Deserializer deserializer = null; + public PythonReceiver(boolean usesByteArray) { + readAsByteArray = usesByteArray; + } + //=====Setup======================================================================================================== public void open(String path) throws IOException { setupMappedFile(path); + deserializer = readAsByteArray ? new ByteArrayDeserializer() : new TupleDeserializer(); } private void setupMappedFile(String inputFilePath) throws FileNotFoundException, IOException { @@ -81,6 +77,8 @@ private void closeMappedFile() throws IOException { inputRAF.close(); } + + //=====IO=========================================================================================================== /** * Reads a buffer of the given size from the memory-mapped file, and collects all records contained. This method * assumes that all values in the buffer are of the same type. This method does NOT take care of synchronization. 
@@ -94,173 +92,46 @@ private void closeMappedFile() throws IOException { public void collectBuffer(Collector c, int bufferSize) throws IOException { fileBuffer.position(0); - if (deserializer == null) { - byte type = fileBuffer.get(); - deserializer = getDeserializer(type); - } while (fileBuffer.position() < bufferSize) { c.collect(deserializer.deserialize()); } } //=====Deserializer================================================================================================= - private Deserializer getDeserializer(byte type) { - switch (type) { - case TYPE_TUPLE: - return new TupleDeserializer(); - case TYPE_BOOLEAN: - return new BooleanDeserializer(); - case TYPE_BYTE: - return new ByteDeserializer(); - case TYPE_BYTES: - return new BytesDeserializer(); - case TYPE_SHORT: - return new ShortDeserializer(); - case TYPE_INTEGER: - return new IntDeserializer(); - case TYPE_LONG: - return new LongDeserializer(); - case TYPE_STRING: - return new StringDeserializer(); - case TYPE_FLOAT: - return new FloatDeserializer(); - case TYPE_DOUBLE: - return new DoubleDeserializer(); - case TYPE_NULL: - return new NullDeserializer(); - default: - return new CustomTypeDeserializer(type); - - } - } - private interface Deserializer { public T deserialize(); } - private class CustomTypeDeserializer implements Deserializer { - private final byte type; - - public CustomTypeDeserializer(byte type) { - this.type = type; - } - - @Override - public CustomTypeWrapper deserialize() { - int size = fileBuffer.getInt(); - byte[] data = new byte[size]; - fileBuffer.get(data); - return new CustomTypeWrapper(type, data); - } - } - - private class BooleanDeserializer implements Deserializer { - @Override - public Boolean deserialize() { - return fileBuffer.get() == 1; - } - } - - private class ByteDeserializer implements Deserializer { - @Override - public Byte deserialize() { - return fileBuffer.get(); - } - } - - private class ShortDeserializer implements Deserializer { - @Override - public Short deserialize() { - return fileBuffer.getShort(); - } - } - - private class IntDeserializer implements Deserializer { - @Override - public Integer deserialize() { - return fileBuffer.getInt(); - } - } - - private class LongDeserializer implements Deserializer { - @Override - public Long deserialize() { - return fileBuffer.getLong(); - } - } - - private class FloatDeserializer implements Deserializer { - @Override - public Float deserialize() { - return fileBuffer.getFloat(); - } - } - - private class DoubleDeserializer implements Deserializer { - @Override - public Double deserialize() { - return fileBuffer.getDouble(); - } - } - - private class StringDeserializer implements Deserializer { - private int size; - - @Override - public String deserialize() { - size = fileBuffer.getInt(); - byte[] buffer = new byte[size]; - fileBuffer.get(buffer); - return new String(buffer); - } - } - - private class NullDeserializer implements Deserializer { - @Override - public Object deserialize() { - return null; - } - } - - private class BytesDeserializer implements Deserializer { + private class ByteArrayDeserializer implements Deserializer { @Override public byte[] deserialize() { - int length = fileBuffer.getInt(); - byte[] result = new byte[length]; - fileBuffer.get(result); - return result; - } - - } - - private class TupleDeserializer implements Deserializer { - Deserializer[] deserializer = null; - Tuple reuse; - - public TupleDeserializer() { int size = fileBuffer.getInt(); - reuse = createTuple(size); - deserializer = new 
Deserializer[size]; - for (int x = 0; x < deserializer.length; x++) { - deserializer[x] = getDeserializer(fileBuffer.get()); - } + byte[] value = new byte[size]; + fileBuffer.get(value); + return value; } + } + private class TupleDeserializer implements Deserializer> { @Override - public Tuple deserialize() { - for (int x = 0; x < deserializer.length; x++) { - reuse.setField(deserializer[x].deserialize(), x); + public Tuple2 deserialize() { + int keyTupleSize = fileBuffer.get(); + Tuple keys = createTuple(keyTupleSize); + for (int x = 0; x < keyTupleSize; x++) { + byte[] data = new byte[fileBuffer.getInt()]; + fileBuffer.get(data); + keys.setField(data, x); } - return reuse; + byte[] value = new byte[fileBuffer.getInt()]; + fileBuffer.get(value); + return new Tuple2(keys, value); } } public static Tuple createTuple(int size) { try { return Tuple.getTupleClass(size).newInstance(); - } catch (InstantiationException e) { - throw new RuntimeException(e); - } catch (IllegalAccessException e) { + } catch (InstantiationException | IllegalAccessException e) { throw new RuntimeException(e); } } diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonSender.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonSender.java index 1d17243c0ac7e..3cd5f4db16d36 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonSender.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonSender.java @@ -20,30 +20,19 @@ import java.nio.ByteBuffer; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; -import java.util.ArrayList; import java.util.Iterator; -import java.util.List; import org.apache.flink.api.java.tuple.Tuple; +import org.apache.flink.api.java.tuple.Tuple2; import static org.apache.flink.python.api.PythonPlanBinder.FLINK_TMP_DATA_DIR; import static org.apache.flink.python.api.PythonPlanBinder.MAPPED_FILE_SIZE; -import org.apache.flink.python.api.types.CustomTypeWrapper; /** * General-purpose class to write data to memory-mapped files. */ -public class PythonSender implements Serializable { - public static final byte TYPE_TUPLE = (byte) 11; - public static final byte TYPE_BOOLEAN = (byte) 10; - public static final byte TYPE_BYTE = (byte) 9; - public static final byte TYPE_SHORT = (byte) 8; - public static final byte TYPE_INTEGER = (byte) 7; - public static final byte TYPE_LONG = (byte) 6; - public static final byte TYPE_DOUBLE = (byte) 4; - public static final byte TYPE_FLOAT = (byte) 5; - public static final byte TYPE_CHAR = (byte) 3; - public static final byte TYPE_STRING = (byte) 2; - public static final byte TYPE_BYTES = (byte) 1; - public static final byte TYPE_NULL = (byte) 0; +public class PythonSender implements Serializable { + public static final byte TYPE_ARRAY = (byte) 63; + public static final byte TYPE_KEY_VALUE = (byte) 62; + public static final byte TYPE_VALUE_VALUE = (byte) 61; private File outputFile; private RandomAccessFile outputRAF; @@ -95,7 +84,7 @@ public void reset() { fileBuffer.clear(); } - //=====Serialization================================================================================================ + //=====IO=========================================================================================================== /** * Writes a single record to the memory-mapped file. This method does NOT take care of synchronization. 
The user * must guarantee that the file may be written to before calling this method. This method essentially reserves the @@ -173,75 +162,24 @@ public int sendBuffer(Iterator i, int group) throws IOException { return size; } - private enum SupportedTypes { - TUPLE, BOOLEAN, BYTE, BYTES, CHARACTER, SHORT, INTEGER, LONG, FLOAT, DOUBLE, STRING, OTHER, NULL, CUSTOMTYPEWRAPPER - } - //=====Serializer=================================================================================================== - private Serializer getSerializer(Object value) throws IOException { - String className = value.getClass().getSimpleName().toUpperCase(); - if (className.startsWith("TUPLE")) { - className = "TUPLE"; + private Serializer getSerializer(Object value) { + if (value instanceof byte[]) { + return new ArraySerializer(); } - if (className.startsWith("BYTE[]")) { - className = "BYTES"; + if (((Tuple2) value).f0 instanceof byte[]) { + return new ValuePairSerializer(); } - SupportedTypes type = SupportedTypes.valueOf(className); - switch (type) { - case TUPLE: - fileBuffer.put(TYPE_TUPLE); - fileBuffer.putInt(((Tuple) value).getArity()); - return new TupleSerializer((Tuple) value); - case BOOLEAN: - fileBuffer.put(TYPE_BOOLEAN); - return new BooleanSerializer(); - case BYTE: - fileBuffer.put(TYPE_BYTE); - return new ByteSerializer(); - case BYTES: - fileBuffer.put(TYPE_BYTES); - return new BytesSerializer(); - case CHARACTER: - fileBuffer.put(TYPE_CHAR); - return new CharSerializer(); - case SHORT: - fileBuffer.put(TYPE_SHORT); - return new ShortSerializer(); - case INTEGER: - fileBuffer.put(TYPE_INTEGER); - return new IntSerializer(); - case LONG: - fileBuffer.put(TYPE_LONG); - return new LongSerializer(); - case STRING: - fileBuffer.put(TYPE_STRING); - return new StringSerializer(); - case FLOAT: - fileBuffer.put(TYPE_FLOAT); - return new FloatSerializer(); - case DOUBLE: - fileBuffer.put(TYPE_DOUBLE); - return new DoubleSerializer(); - case NULL: - fileBuffer.put(TYPE_NULL); - return new NullSerializer(); - case CUSTOMTYPEWRAPPER: - fileBuffer.put(((CustomTypeWrapper) value).getType()); - return new CustomTypeSerializer(); - default: - throw new IllegalArgumentException("Unknown Type encountered: " + type); + if (((Tuple2) value).f0 instanceof Tuple) { + return new KeyValuePairSerializer(); } + throw new IllegalArgumentException("This object can't be serialized: " + value.toString()); } private abstract class Serializer { protected ByteBuffer buffer; - public Serializer(int capacity) { - buffer = ByteBuffer.allocate(capacity); - } - public ByteBuffer serialize(T value) { - buffer.clear(); serializeInternal(value); buffer.flip(); return buffer; @@ -250,171 +188,39 @@ public ByteBuffer serialize(T value) { public abstract void serializeInternal(T value); } - private class CustomTypeSerializer extends Serializer { - public CustomTypeSerializer() { - super(0); - } + private class ArraySerializer extends Serializer { @Override - public void serializeInternal(CustomTypeWrapper value) { - byte[] bytes = value.getData(); - buffer = ByteBuffer.wrap(bytes); - buffer.position(bytes.length); - } - } - - private class ByteSerializer extends Serializer { - public ByteSerializer() { - super(1); - } - - @Override - public void serializeInternal(Byte value) { + public void serializeInternal(byte[] value) { + buffer = ByteBuffer.allocate(value.length + 1); + buffer.put(TYPE_ARRAY); buffer.put(value); } } - private class BooleanSerializer extends Serializer { - public BooleanSerializer() { - super(1); - } - - @Override - 
public void serializeInternal(Boolean value) { - buffer.put(value ? (byte) 1 : (byte) 0); - } - } - - private class CharSerializer extends Serializer { - public CharSerializer() { - super(4); - } - - @Override - public void serializeInternal(Character value) { - buffer.put((value + "").getBytes()); - } - } - - private class ShortSerializer extends Serializer { - public ShortSerializer() { - super(2); - } - - @Override - public void serializeInternal(Short value) { - buffer.putShort(value); - } - } - - private class IntSerializer extends Serializer { - public IntSerializer() { - super(4); - } - - @Override - public void serializeInternal(Integer value) { - buffer.putInt(value); - } - } - - private class LongSerializer extends Serializer { - public LongSerializer() { - super(8); - } - - @Override - public void serializeInternal(Long value) { - buffer.putLong(value); - } - } - - private class StringSerializer extends Serializer { - public StringSerializer() { - super(0); - } - - @Override - public void serializeInternal(String value) { - byte[] bytes = value.getBytes(); - buffer = ByteBuffer.allocate(bytes.length + 4); - buffer.putInt(bytes.length); - buffer.put(bytes); - } - } - - private class FloatSerializer extends Serializer { - public FloatSerializer() { - super(4); - } - - @Override - public void serializeInternal(Float value) { - buffer.putFloat(value); - } - } - - private class DoubleSerializer extends Serializer { - public DoubleSerializer() { - super(8); - } - + private class ValuePairSerializer extends Serializer> { @Override - public void serializeInternal(Double value) { - buffer.putDouble(value); + public void serializeInternal(Tuple2 value) { + buffer = ByteBuffer.allocate(1 + value.f0.length + value.f1.length); + buffer.put(TYPE_VALUE_VALUE); + buffer.put(value.f0); + buffer.put(value.f1); } } - private class NullSerializer extends Serializer { - public NullSerializer() { - super(0); - } - - @Override - public void serializeInternal(Object value) { - } - } - - private class BytesSerializer extends Serializer { - public BytesSerializer() { - super(0); - } - - @Override - public void serializeInternal(byte[] value) { - buffer = ByteBuffer.allocate(4 + value.length); - buffer.putInt(value.length); - buffer.put(value); - } - } - - private class TupleSerializer extends Serializer { - private final Serializer[] serializer; - private final List buffers; - - public TupleSerializer(Tuple value) throws IOException { - super(0); - serializer = new Serializer[value.getArity()]; - buffers = new ArrayList(); - for (int x = 0; x < serializer.length; x++) { - serializer[x] = getSerializer(value.getField(x)); - } - } - + private class KeyValuePairSerializer extends Serializer> { @Override - public void serializeInternal(Tuple value) { - int length = 0; - for (int x = 0; x < serializer.length; x++) { - serializer[x].buffer.clear(); - serializer[x].serializeInternal(value.getField(x)); - length += serializer[x].buffer.position(); - buffers.add(serializer[x].buffer); + public void serializeInternal(Tuple2 value) { + int keySize = 0; + for (int x = 0; x < value.f0.getArity(); x++) { + keySize += ((byte[]) value.f0.getField(x)).length; } - buffer = ByteBuffer.allocate(length); - for (ByteBuffer b : buffers) { - b.flip(); - buffer.put(b); + buffer = ByteBuffer.allocate(5 + keySize + value.f1.length); + buffer.put(TYPE_KEY_VALUE); + buffer.put((byte) value.f0.getArity()); + for (int x = 0; x < value.f0.getArity(); x++) { + buffer.put((byte[]) value.f0.getField(x)); } - buffers.clear(); + 
buffer.put(value.f1); } } } diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonStreamer.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonStreamer.java index 1e369629c09fe..556805178618a 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonStreamer.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/data/PythonStreamer.java @@ -34,6 +34,8 @@ import static org.apache.flink.python.api.PythonPlanBinder.FLINK_TMP_DATA_DIR; import static org.apache.flink.python.api.PythonPlanBinder.PLANBINDER_CONFIG_BCVAR_COUNT; import static org.apache.flink.python.api.PythonPlanBinder.PLANBINDER_CONFIG_BCVAR_NAME_PREFIX; +import org.apache.flink.python.api.streaming.util.SerializationUtils.IntSerializer; +import org.apache.flink.python.api.streaming.util.SerializationUtils.StringSerializer; import org.apache.flink.util.Collector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -58,8 +60,6 @@ public class PythonStreamer implements Serializable { private String inputFilePath; private String outputFilePath; - private final byte[] buffer = new byte[4]; - private Process process; private Thread shutdownThread; protected ServerSocket server; @@ -75,13 +75,13 @@ public class PythonStreamer implements Serializable { protected final AbstractRichFunction function; - public PythonStreamer(AbstractRichFunction function, int id) { + public PythonStreamer(AbstractRichFunction function, int id, boolean usesByteArray) { this.id = id; this.usePython3 = PythonPlanBinder.usePython3; this.debug = DEBUG; planArguments = PythonPlanBinder.arguments.toString(); sender = new PythonSender(); - receiver = new PythonReceiver(); + receiver = new PythonReceiver(usesByteArray); this.function = function; } @@ -167,9 +167,9 @@ public void run() { */ public void close() throws IOException { try { - socket.close(); - sender.close(); - receiver.close(); + socket.close(); + sender.close(); + receiver.close(); } catch (Exception e) { LOG.error("Exception occurred while closing Streamer. :" + e.getMessage()); } @@ -202,31 +202,18 @@ private void destroyProcess() throws IOException { } } } - - private void sendWriteNotification(int size, boolean hasNext) throws IOException { - byte[] tmp = new byte[5]; - putInt(tmp, 0, size); - tmp[4] = hasNext ? 0 : SIGNAL_LAST; - out.write(tmp, 0, 5); + + private void sendWriteNotification(int size, boolean hasNext) throws IOException { + out.writeInt(size); + out.writeByte(hasNext ? 0 : SIGNAL_LAST); out.flush(); } private void sendReadConfirmation() throws IOException { - out.write(new byte[1], 0, 1); + out.writeByte(1); out.flush(); } - private void checkForError() { - if (getInt(buffer, 0) == -2) { - try { //wait before terminating to ensure that the complete error message is printed - Thread.sleep(2000); - } catch (InterruptedException ex) { - } - throw new RuntimeException( - "External process for task " + function.getRuntimeContext().getTaskName() + " terminated prematurely." + msg); - } - } - /** * Sends all broadcast-variables encoded in the configuration to the external process. 
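The rewritten method in the hunk below writes a simple frame: a 4-byte variable count, then for each variable a 4-byte-length-prefixed name followed by its values, each value preceded by a 0x01 marker byte and the run terminated by a single 0x00 byte (this is the layout _receive_broadcast_variables on the Python side reads back). A minimal standalone sketch of that layout; the helper name frame_broadcast_variables is hypothetical and not part of the patch, and the values are assumed to be already-serialized byte strings:

from struct import pack

def frame_broadcast_variables(variables):
    # variables: mapping of variable name -> list of already-serialized value byte strings
    out = [pack(">i", len(variables))]                  # 4-byte variable count
    for name, values in variables.items():
        encoded = name.encode("utf-8")
        out.append(pack(">i", len(encoded)) + encoded)  # length-prefixed name
        for value in values:
            out.append(b"\x01" + value)                 # 0x01 marker, then the value bytes
        out.append(b"\x00")                             # end of this variable's values
    return b"".join(out)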
* @@ -243,26 +230,19 @@ public final void sendBroadCastVariables(Configuration config) throws IOExceptio names[x] = config.getString(PLANBINDER_CONFIG_BCVAR_NAME_PREFIX + x, null); } - in.readFully(buffer, 0, 4); - checkForError(); - int size = sender.sendRecord(broadcastCount); - sendWriteNotification(size, false); + out.write(new IntSerializer().serializeWithoutTypeInfo(broadcastCount)); + StringSerializer stringSerializer = new StringSerializer(); for (String name : names) { Iterator bcv = function.getRuntimeContext().getBroadcastVariable(name).iterator(); - in.readFully(buffer, 0, 4); - checkForError(); - size = sender.sendRecord(name); - sendWriteNotification(size, false); + out.write(stringSerializer.serializeWithoutTypeInfo(name)); - while (bcv.hasNext() || sender.hasRemaining(0)) { - in.readFully(buffer, 0, 4); - checkForError(); - size = sender.sendBuffer(bcv, 0); - sendWriteNotification(size, bcv.hasNext() || sender.hasRemaining(0)); + while (bcv.hasNext()) { + out.writeByte(1); + out.write((byte[]) bcv.next()); } - sender.reset(); + out.writeByte(0); } } catch (SocketTimeoutException ste) { throw new RuntimeException("External process for task " + function.getRuntimeContext().getTaskName() + " stopped responding." + msg); @@ -281,8 +261,7 @@ public final void streamBufferWithoutGroups(Iterator i, Collector c) throws IOEx int size; if (i.hasNext()) { while (true) { - in.readFully(buffer, 0, 4); - int sig = getInt(buffer, 0); + int sig = in.readInt(); switch (sig) { case SIGNAL_BUFFER_REQUEST: if (i.hasNext() || sender.hasRemaining(0)) { @@ -326,8 +305,7 @@ public final void streamBufferWithGroups(Iterator i1, Iterator i2, Collector c) int size; if (i1.hasNext() || i2.hasNext()) { while (true) { - in.readFully(buffer, 0, 4); - int sig = getInt(buffer, 0); + int sig = in.readInt(); switch (sig) { case SIGNAL_BUFFER_REQUEST_G0: if (i1.hasNext() || sender.hasRemaining(0)) { @@ -361,16 +339,4 @@ public final void streamBufferWithGroups(Iterator i1, Iterator i2, Collector c) throw new RuntimeException("External process for task " + function.getRuntimeContext().getTaskName() + " stopped responding." 
+ msg); } } - - protected final static int getInt(byte[] array, int offset) { - return (array[offset] << 24) | (array[offset + 1] & 0xff) << 16 | (array[offset + 2] & 0xff) << 8 | (array[offset + 3] & 0xff); - } - - protected final static void putInt(byte[] array, int offset, int value) { - array[offset] = (byte) (value >> 24); - array[offset + 1] = (byte) (value >> 16); - array[offset + 2] = (byte) (value >> 8); - array[offset + 3] = (byte) (value); - } - } diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanReceiver.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanReceiver.java index ed02ce4e4468d..5095e3bd32b3a 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanReceiver.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanReceiver.java @@ -18,17 +18,15 @@ import java.io.Serializable; import org.apache.flink.api.java.tuple.Tuple; import static org.apache.flink.python.api.streaming.data.PythonReceiver.createTuple; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BOOLEAN; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTE; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTES; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_DOUBLE; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_FLOAT; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_INTEGER; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_LONG; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_NULL; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_SHORT; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_STRING; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_TUPLE; +import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_BOOLEAN; +import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_BYTE; +import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_BYTES; +import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_DOUBLE; +import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_FLOAT; +import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_INTEGER; +import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_LONG; +import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_NULL; +import static org.apache.flink.python.api.streaming.util.SerializationUtils.TYPE_STRING; import org.apache.flink.python.api.types.CustomTypeWrapper; /** @@ -46,62 +44,157 @@ public Object getRecord() throws IOException { } public Object getRecord(boolean normalized) throws IOException { - return receiveField(normalized); + return getDeserializer().deserialize(normalized); } - private Object receiveField(boolean normalized) throws IOException { + private Deserializer getDeserializer() throws IOException { byte type = (byte) input.readByte(); - switch (type) { - case TYPE_TUPLE: - int tupleSize = input.readByte(); - Tuple tuple = createTuple(tupleSize); - for (int x = 0; x < tupleSize; x++) { - tuple.setField(receiveField(normalized), x); + if (type > 0 && type < 
26) { + Deserializer[] d = new Deserializer[type]; + for (int x = 0; x < d.length; x++) { + d[x] = getDeserializer(); } - return tuple; + return new TupleDeserializer(d); + } + switch (type) { case TYPE_BOOLEAN: - return input.readByte() == 1; + return new BooleanDeserializer(); case TYPE_BYTE: - return (byte) input.readByte(); - case TYPE_SHORT: - if (normalized) { - return (int) input.readShort(); - } else { - return input.readShort(); - } + return new ByteDeserializer(); case TYPE_INTEGER: - return input.readInt(); + return new IntDeserializer(); case TYPE_LONG: - if (normalized) { - return new Long(input.readLong()).intValue(); - } else { - return input.readLong(); - } + return new LongDeserializer(); case TYPE_FLOAT: - if (normalized) { - return (double) input.readFloat(); - } else { - return input.readFloat(); - } + return new FloatDeserializer(); case TYPE_DOUBLE: - return input.readDouble(); + return new DoubleDeserializer(); case TYPE_STRING: - int stringSize = input.readInt(); - byte[] string = new byte[stringSize]; - input.readFully(string); - return new String(string); + return new StringDeserializer(); case TYPE_BYTES: - int bytessize = input.readInt(); - byte[] bytes = new byte[bytessize]; - input.readFully(bytes); - return bytes; + return new BytesDeserializer(); case TYPE_NULL: - return null; + return new NullDeserializer(); default: - int size = input.readInt(); - byte[] data = new byte[size]; - input.readFully(data); - return new CustomTypeWrapper(type, data); + return new CustomTypeDeserializer(type); + } + } + + private abstract class Deserializer { + public T deserialize() throws IOException { + return deserialize(false); + } + + public abstract T deserialize(boolean normalized) throws IOException; + } + + private class TupleDeserializer extends Deserializer { + Deserializer[] deserializer; + + public TupleDeserializer(Deserializer[] deserializer) { + this.deserializer = deserializer; + } + + @Override + public Tuple deserialize(boolean normalized) throws IOException { + Tuple result = createTuple(deserializer.length); + for (int x = 0; x < result.getArity(); x++) { + result.setField(deserializer[x].deserialize(normalized), x); + } + return result; + } + } + + private class CustomTypeDeserializer extends Deserializer { + private final byte type; + + public CustomTypeDeserializer(byte type) { + this.type = type; + } + + @Override + public CustomTypeWrapper deserialize(boolean normalized) throws IOException { + int size = input.readInt(); + byte[] data = new byte[size]; + input.read(data); + return new CustomTypeWrapper(type, data); + } + } + + private class BooleanDeserializer extends Deserializer { + @Override + public Boolean deserialize(boolean normalized) throws IOException { + return input.readBoolean(); + } + } + + private class ByteDeserializer extends Deserializer { + @Override + public Byte deserialize(boolean normalized) throws IOException { + return input.readByte(); + } + } + + private class IntDeserializer extends Deserializer { + @Override + public Integer deserialize(boolean normalized) throws IOException { + return input.readInt(); + } + } + + private class LongDeserializer extends Deserializer { + @Override + public Object deserialize(boolean normalized) throws IOException { + if (normalized) { + return new Long(input.readLong()).intValue(); + } else { + return input.readLong(); + } + } + } + + private class FloatDeserializer extends Deserializer { + @Override + public Object deserialize(boolean normalized) throws IOException { + if (normalized) { + 
return (double) input.readFloat(); + } else { + return input.readFloat(); + } + } + } + + private class DoubleDeserializer extends Deserializer { + @Override + public Double deserialize(boolean normalized) throws IOException { + return input.readDouble(); + } + } + + private class StringDeserializer extends Deserializer { + @Override + public String deserialize(boolean normalized) throws IOException { + int size = input.readInt(); + byte[] buffer = new byte[size]; + input.read(buffer); + return new String(buffer); + } + } + + private class NullDeserializer extends Deserializer { + @Override + public Object deserialize(boolean normalized) throws IOException { + return null; + } + } + + private class BytesDeserializer extends Deserializer { + @Override + public byte[] deserialize(boolean normalized) throws IOException { + int size = input.readInt(); + byte[] buffer = new byte[size]; + input.read(buffer); + return buffer; } } } diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanSender.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanSender.java index 16a1eba8d5a09..7b6b68ad28bd0 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanSender.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/plan/PythonPlanSender.java @@ -16,19 +16,7 @@ import java.io.IOException; import java.io.OutputStream; import java.io.Serializable; -import org.apache.flink.api.java.tuple.Tuple; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BOOLEAN; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTE; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_BYTES; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_DOUBLE; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_FLOAT; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_INTEGER; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_LONG; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_NULL; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_SHORT; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_STRING; -import static org.apache.flink.python.api.streaming.data.PythonSender.TYPE_TUPLE; -import org.apache.flink.python.api.types.CustomTypeWrapper; +import org.apache.flink.python.api.streaming.util.SerializationUtils; /** * Instances of this class can be used to send data to the plan process. @@ -41,76 +29,7 @@ public PythonPlanSender(OutputStream output) { } public void sendRecord(Object record) throws IOException { - String className = record.getClass().getSimpleName().toUpperCase(); - if (className.startsWith("TUPLE")) { - className = "TUPLE"; - } - if (className.startsWith("BYTE[]")) { - className = "BYTES"; - } - SupportedTypes type = SupportedTypes.valueOf(className); - switch (type) { - case TUPLE: - output.write(TYPE_TUPLE); - int arity = ((Tuple) record).getArity(); - output.writeInt(arity); - for (int x = 0; x < arity; x++) { - sendRecord(((Tuple) record).getField(x)); - } - return; - case BOOLEAN: - output.write(TYPE_BOOLEAN); - output.write(((Boolean) record) ? 
(byte) 1 : (byte) 0); - return; - case BYTE: - output.write(TYPE_BYTE); - output.write((Byte) record); - return; - case BYTES: - output.write(TYPE_BYTES); - output.write((byte[]) record, 0, ((byte[]) record).length); - return; - case CHARACTER: - output.write(TYPE_STRING); - output.writeChars(((Character) record) + ""); - return; - case SHORT: - output.write(TYPE_SHORT); - output.writeShort((Short) record); - return; - case INTEGER: - output.write(TYPE_INTEGER); - output.writeInt((Integer) record); - return; - case LONG: - output.write(TYPE_LONG); - output.writeLong((Long) record); - return; - case STRING: - output.write(TYPE_STRING); - output.writeBytes((String) record); - return; - case FLOAT: - output.write(TYPE_FLOAT); - output.writeFloat((Float) record); - return; - case DOUBLE: - output.write(TYPE_DOUBLE); - output.writeDouble((Double) record); - return; - case NULL: - output.write(TYPE_NULL); - return; - case CUSTOMTYPEWRAPPER: - output.write(((CustomTypeWrapper) record).getType()); - output.write(((CustomTypeWrapper) record).getData()); - return; - default: - throw new IllegalArgumentException("Unknown Type encountered: " + type); - } - } - - private enum SupportedTypes { - TUPLE, BOOLEAN, BYTE, BYTES, CHARACTER, SHORT, INTEGER, LONG, FLOAT, DOUBLE, STRING, OTHER, NULL, CUSTOMTYPEWRAPPER + byte[] data = SerializationUtils.getSerializer(record).serialize(record); + output.write(data); } } \ No newline at end of file diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/util/SerializationUtils.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/util/SerializationUtils.java new file mode 100644 index 0000000000000..6c83a61c7b7f3 --- /dev/null +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/util/SerializationUtils.java @@ -0,0 +1,283 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package org.apache.flink.python.api.streaming.util; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import org.apache.flink.api.java.tuple.Tuple; +import org.apache.flink.python.api.types.CustomTypeWrapper; + +public class SerializationUtils { + public static final byte TYPE_BOOLEAN = (byte) 34; + public static final byte TYPE_BYTE = (byte) 33; + public static final byte TYPE_INTEGER = (byte) 32; + public static final byte TYPE_LONG = (byte) 31; + public static final byte TYPE_DOUBLE = (byte) 30; + public static final byte TYPE_FLOAT = (byte) 29; + public static final byte TYPE_STRING = (byte) 28; + public static final byte TYPE_BYTES = (byte) 27; + public static final byte TYPE_NULL = (byte) 26; + + private enum SupportedTypes { + TUPLE, BOOLEAN, BYTE, BYTES, INTEGER, LONG, FLOAT, DOUBLE, STRING, NULL, CUSTOMTYPEWRAPPER + } + + public static Serializer getSerializer(Object value) { + String className = value.getClass().getSimpleName().toUpperCase(); + if (className.startsWith("TUPLE")) { + className = "TUPLE"; + } + if (className.startsWith("BYTE[]")) { + className = "BYTES"; + } + SupportedTypes type = SupportedTypes.valueOf(className); + switch (type) { + case TUPLE: + return new TupleSerializer((Tuple) value); + case BOOLEAN: + return new BooleanSerializer(); + case BYTE: + return new ByteSerializer(); + case BYTES: + return new BytesSerializer(); + case INTEGER: + return new IntSerializer(); + case LONG: + return new LongSerializer(); + case STRING: + return new StringSerializer(); + case FLOAT: + return new FloatSerializer(); + case DOUBLE: + return new DoubleSerializer(); + case NULL: + return new NullSerializer(); + case CUSTOMTYPEWRAPPER: + return new CustomTypeWrapperSerializer((CustomTypeWrapper) value); + default: + throw new IllegalArgumentException("Unsupported Type encountered: " + type); + } + } + + public static abstract class Serializer { + private byte[] typeInfo = null; + + public byte[] serialize(IN value) { + if (typeInfo == null) { + typeInfo = new byte[getTypeInfoSize()]; + ByteBuffer typeBuffer = ByteBuffer.wrap(typeInfo); + putTypeInfo(typeBuffer); + } + byte[] bytes = serializeWithoutTypeInfo(value); + byte[] total = new byte[typeInfo.length + bytes.length]; + ByteBuffer.wrap(total).put(typeInfo).put(bytes); + return total; + } + + public abstract byte[] serializeWithoutTypeInfo(IN value); + + protected abstract void putTypeInfo(ByteBuffer buffer); + + protected int getTypeInfoSize() { + return 1; + } + } + + public static class CustomTypeWrapperSerializer extends Serializer { + private final byte type; + + public CustomTypeWrapperSerializer(CustomTypeWrapper value) { + this.type = value.getType(); + } + + @Override + public byte[] serializeWithoutTypeInfo(CustomTypeWrapper value) { + byte[] result = new byte[4 + value.getData().length]; + ByteBuffer.wrap(result).putInt(value.getData().length).put(value.getData()); + return result; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(type); + } + } + + public static class ByteSerializer extends Serializer { + @Override + public byte[] serializeWithoutTypeInfo(Byte value) { + return new byte[]{value}; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(TYPE_BYTE); + } + } + + public static class BooleanSerializer extends Serializer { + @Override + public byte[] serializeWithoutTypeInfo(Boolean value) { + return new byte[]{value ? 
(byte) 1 : (byte) 0}; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(TYPE_BOOLEAN); + } + } + + public static class IntSerializer extends Serializer { + @Override + public byte[] serializeWithoutTypeInfo(Integer value) { + byte[] data = new byte[4]; + ByteBuffer.wrap(data).putInt(value); + return data; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(TYPE_INTEGER); + } + } + + public static class LongSerializer extends Serializer { + @Override + public byte[] serializeWithoutTypeInfo(Long value) { + byte[] data = new byte[8]; + ByteBuffer.wrap(data).putLong(value); + return data; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(TYPE_LONG); + } + } + + public static class StringSerializer extends Serializer { + @Override + public byte[] serializeWithoutTypeInfo(String value) { + byte[] string = value.getBytes(); + byte[] data = new byte[4 + string.length]; + ByteBuffer.wrap(data).putInt(string.length).put(string); + return data; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(TYPE_STRING); + } + } + + public static class FloatSerializer extends Serializer { + @Override + public byte[] serializeWithoutTypeInfo(Float value) { + byte[] data = new byte[4]; + ByteBuffer.wrap(data).putFloat(value); + return data; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(TYPE_FLOAT); + } + } + + public static class DoubleSerializer extends Serializer { + @Override + public byte[] serializeWithoutTypeInfo(Double value) { + byte[] data = new byte[8]; + ByteBuffer.wrap(data).putDouble(value); + return data; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(TYPE_DOUBLE); + } + } + + public static class NullSerializer extends Serializer { + @Override + public byte[] serializeWithoutTypeInfo(Object value) { + return new byte[0]; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(TYPE_NULL); + } + } + + public static class BytesSerializer extends Serializer { + @Override + public byte[] serializeWithoutTypeInfo(byte[] value) { + byte[] data = new byte[4 + value.length]; + ByteBuffer.wrap(data).putInt(value.length).put(value); + return data; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put(TYPE_BYTES); + } + } + + public static class TupleSerializer extends Serializer { + private final Serializer[] serializer; + + public TupleSerializer(Tuple value) { + serializer = new Serializer[value.getArity()]; + for (int x = 0; x < serializer.length; x++) { + serializer[x] = getSerializer(value.getField(x)); + } + } + + @Override + public byte[] serializeWithoutTypeInfo(Tuple value) { + ArrayList bits = new ArrayList(); + + int totalSize = 0; + for (int x = 0; x < serializer.length; x++) { + byte[] bit = serializer[x].serializeWithoutTypeInfo(value.getField(x)); + bits.add(bit); + totalSize += bit.length; + } + int pointer = 0; + byte[] data = new byte[totalSize]; + for (byte[] bit : bits) { + System.arraycopy(bit, 0, data, pointer, bit.length); + pointer += bit.length; + } + return data; + } + + @Override + public void putTypeInfo(ByteBuffer buffer) { + buffer.put((byte) serializer.length); + for (Serializer s : serializer) { + s.putTypeInfo(buffer); + } + } + + @Override + public int getTypeInfoSize() { + int size = 1; + for (Serializer s : serializer) { + size += s.getTypeInfoSize(); + } + return size; + } + } +} diff --git 
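SerializationUtils above prepends a compact type header to every serialized value: a leading byte between 1 and 25 denotes a tuple and doubles as its arity, followed by the type bytes of the nested fields, while IDs 26 through 34 cover null, byte[], String, float, double, long, int, byte and boolean. A minimal standalone sketch of the same scheme as the Python Collector applies it; the function name type_header is hypothetical and only the types handled on the Python side are shown:

from struct import pack

TYPE_NULL, TYPE_BYTES, TYPE_STRING = b"\x1a", b"\x1b", b"\x1c"
TYPE_DOUBLE, TYPE_LONG, TYPE_BOOLEAN = b"\x1e", b"\x1f", b"\x22"

def type_header(value):
    if isinstance(value, (list, tuple)):
        # arity is assumed to stay at or below 25, matching the 1..25 tuple range
        return pack(">i", len(value))[3:4] + b"".join(type_header(f) for f in value)
    if value is None:
        return TYPE_NULL
    if isinstance(value, str):
        return TYPE_STRING
    if isinstance(value, bool):      # checked before int, since bool is a subclass of int
        return TYPE_BOOLEAN
    if isinstance(value, int):
        return TYPE_LONG
    if isinstance(value, bytearray):
        return TYPE_BYTES
    if isinstance(value, float):
        return TYPE_DOUBLE
    raise TypeError("unsupported type")

# e.g. type_header((1, "hello", True)) == b"\x03\x1f\x1c\x22"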
a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py index b5674b98cb7f8..4c0db328ccb25 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py @@ -16,11 +16,11 @@ # limitations under the License. ################################################################################ from struct import pack -import sys - from flink.connection.Constants import Types -from flink.plan.Constants import _Dummy + +#=====Compatibility==================================================================================================== +import sys PY2 = sys.version_info[0] == 2 PY3 = sys.version_info[0] == 3 @@ -30,57 +30,128 @@ stringtype = str +#=====Collector======================================================================================================== class Collector(object): - def __init__(self, con, env): + def __init__(self, con, env, info): self._connection = con self._serializer = None self._env = env + self._as_array = isinstance(info.types, bytearray) def _close(self): self._connection.send_end_signal() def collect(self, value): - self._serializer = _get_serializer(self._connection.write, value, self._env._types) + self._serializer = ArraySerializer(value, self._env._types) if self._as_array else KeyValuePairSerializer(value, self._env._types) self.collect = self._collect self.collect(value) def _collect(self, value): - self._connection.write(self._serializer.serialize(value)) + serialized_value = self._serializer.serialize(value) + self._connection.write(serialized_value) + + +class PlanCollector(object): + def __init__(self, con, env): + self._connection = con + self._env = env + + def _close(self): + self._connection.send_end_signal() + + def collect(self, value): + type = _get_type_info(value, self._env._types) + serializer = _get_serializer(value, self._env._types) + self._connection.write(b"".join([type, serializer.serialize(value)])) + + +#=====Serializer======================================================================================================= +class Serializer(object): + def serialize(self, value): + pass + + +class KeyValuePairSerializer(Serializer): + def __init__(self, value, custom_types): + self._typeK = [_get_type_info(key, custom_types) for key in value[0]] + self._typeV = _get_type_info(value[1], custom_types) + self._typeK_length = [len(type) for type in self._typeK] + self._typeV_length = len(self._typeV) + self._serializerK = [_get_serializer(key, custom_types) for key in value[0]] + self._serializerV = _get_serializer(value[1], custom_types) + + def serialize(self, value): + bits = [pack(">i", len(value[0]))[3:4]] + for i in range(len(value[0])): + x = self._serializerK[i].serialize(value[0][i]) + bits.append(pack(">i", len(x) + self._typeK_length[i])) + bits.append(self._typeK[i]) + bits.append(x) + v = self._serializerV.serialize(value[1]) + bits.append(pack(">i", len(v) + self._typeV_length)) + bits.append(self._typeV) + bits.append(v) + return b"".join(bits) + +class ArraySerializer(Serializer): + def __init__(self, value, custom_types): + self._type = _get_type_info(value, custom_types) + self._type_length = len(self._type) + self._serializer = _get_serializer(value, custom_types) -def _get_serializer(write, value, custom_types): 
+ def serialize(self, value): + serialized_value = self._serializer.serialize(value) + return b"".join([pack(">i", len(serialized_value) + self._type_length), self._type, serialized_value]) + + +def _get_type_info(value, custom_types): if isinstance(value, (list, tuple)): - write(Types.TYPE_TUPLE) - write(pack(">I", len(value))) - return TupleSerializer(write, value, custom_types) + return b"".join([pack(">i", len(value))[3:4], b"".join([_get_type_info(field, custom_types) for field in value])]) + elif value is None: + return Types.TYPE_NULL + elif isinstance(value, stringtype): + return Types.TYPE_STRING + elif isinstance(value, bool): + return Types.TYPE_BOOLEAN + elif isinstance(value, int) or PY2 and isinstance(value, long): + return Types.TYPE_LONG + elif isinstance(value, bytearray): + return Types.TYPE_BYTES + elif isinstance(value, float): + return Types.TYPE_DOUBLE + else: + for entry in custom_types: + if isinstance(value, entry[1]): + return entry[0] + raise Exception("Unsupported Type encountered.") + + +def _get_serializer(value, custom_types): + if isinstance(value, (list, tuple)): + return TupleSerializer(value, custom_types) elif value is None: - write(Types.TYPE_NULL) return NullSerializer() elif isinstance(value, stringtype): - write(Types.TYPE_STRING) return StringSerializer() elif isinstance(value, bool): - write(Types.TYPE_BOOLEAN) return BooleanSerializer() elif isinstance(value, int) or PY2 and isinstance(value, long): - write(Types.TYPE_LONG) return LongSerializer() elif isinstance(value, bytearray): - write(Types.TYPE_BYTES) return ByteArraySerializer() elif isinstance(value, float): - write(Types.TYPE_DOUBLE) return FloatSerializer() else: for entry in custom_types: if isinstance(value, entry[1]): - write(entry[0]) - return CustomTypeSerializer(entry[2]) + return CustomTypeSerializer(entry[0], entry[2]) raise Exception("Unsupported Type encountered.") -class CustomTypeSerializer(object): - def __init__(self, serializer): +class CustomTypeSerializer(Serializer): + def __init__(self, id, serializer): + self._id = id self._serializer = serializer def serialize(self, value): @@ -88,9 +159,9 @@ def serialize(self, value): return b"".join([pack(">i",len(msg)), msg]) -class TupleSerializer(object): - def __init__(self, write, value, custom_types): - self.serializer = [_get_serializer(write, field, custom_types) for field in value] +class TupleSerializer(Serializer): + def __init__(self, value, custom_types): + self.serializer = [_get_serializer(field, custom_types) for field in value] def serialize(self, value): bits = [] @@ -99,83 +170,33 @@ def serialize(self, value): return b"".join(bits) -class BooleanSerializer(object): +class BooleanSerializer(Serializer): def serialize(self, value): return pack(">?", value) -class FloatSerializer(object): +class FloatSerializer(Serializer): def serialize(self, value): return pack(">d", value) -class LongSerializer(object): +class LongSerializer(Serializer): def serialize(self, value): return pack(">q", value) -class ByteArraySerializer(object): +class ByteArraySerializer(Serializer): def serialize(self, value): value = bytes(value) return pack(">I", len(value)) + value -class StringSerializer(object): +class StringSerializer(Serializer): def serialize(self, value): value = value.encode("utf-8") return pack(">I", len(value)) + value -class NullSerializer(object): +class NullSerializer(Serializer): def serialize(self, value): - return b"" - - -class TypedCollector(object): - def __init__(self, con, env): - self._connection = con - 
self._env = env - - def collect(self, value): - if not isinstance(value, (list, tuple)): - self._send_field(value) - else: - self._connection.write(Types.TYPE_TUPLE) - meta = pack(">I", len(value)) - self._connection.write(bytes([meta[3]]) if PY3 else meta[3]) - for field in value: - self.collect(field) - - def _send_field(self, value): - if value is None: - self._connection.write(Types.TYPE_NULL) - elif isinstance(value, stringtype): - value = value.encode("utf-8") - size = pack(">I", len(value)) - self._connection.write(b"".join([Types.TYPE_STRING, size, value])) - elif isinstance(value, bytes): - size = pack(">I", len(value)) - self._connection.write(b"".join([Types.TYPE_BYTES, size, value])) - elif isinstance(value, bool): - data = pack(">?", value) - self._connection.write(b"".join([Types.TYPE_BOOLEAN, data])) - elif isinstance(value, int) or PY2 and isinstance(value, long): - data = pack(">q", value) - self._connection.write(b"".join([Types.TYPE_LONG, data])) - elif isinstance(value, float): - data = pack(">d", value) - self._connection.write(b"".join([Types.TYPE_DOUBLE, data])) - elif isinstance(value, bytearray): - value = bytes(value) - size = pack(">I", len(value)) - self._connection.write(b"".join([Types.TYPE_BYTES, size, value])) - elif isinstance(value, _Dummy): - self._connection.write(pack(">i", 127)[3:]) - self._connection.write(pack(">i", 0)) - else: - for entry in self._env._types: - if isinstance(value, entry[1]): - self._connection.write(entry[0]) - self._connection.write(CustomTypeSerializer(entry[2]).serialize(value)) - return - raise Exception("Unsupported Type encountered.") \ No newline at end of file + return b"" \ No newline at end of file diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Connection.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Connection.py index 680f49529a21a..293f5e9327021 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Connection.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Connection.py @@ -115,6 +115,8 @@ def read(self, des_size, ignored=None): self._read_buffer() old_offset = self._input_offset self._input_offset += des_size + if self._input_offset > self._input_size: + raise Exception("BufferUnderFlowException") return self._input[old_offset:self._input_offset] def _read_buffer(self): @@ -140,6 +142,12 @@ def reset(self): self._input_offset = 0 self._input = b"" + def read_secondary(self, des_size): + return recv_all(self._socket, des_size) + + def write_secondary(self, data): + self._socket.send(data) + class TwinBufferingTCPMappedFileConnection(BufferingTCPMappedFileConnection): def __init__(self, input_file, output_file, port): diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Constants.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Constants.py index 0ca2232629529..01a16fa3880c4 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Constants.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Constants.py @@ -18,14 +18,15 @@ class Types(object): - TYPE_TUPLE = b'\x0B' - TYPE_BOOLEAN = b'\x0A' - TYPE_BYTE = b'\x09' - TYPE_SHORT = b'\x08' - TYPE_INTEGER = b'\x07' - TYPE_LONG = b'\x06' - TYPE_DOUBLE = b'\x04' - TYPE_FLOAT = b'\x05' - 
TYPE_STRING = b'\x02' - TYPE_NULL = b'\x00' - TYPE_BYTES = b'\x01' + TYPE_ARRAY = b'\x3F' + TYPE_KEY_VALUE = b'\x3E' + TYPE_VALUE_VALUE = b'\x3D' + TYPE_BOOLEAN = b'\x22' + TYPE_BYTE = b'\x21' + TYPE_INTEGER = b'\x20' + TYPE_LONG = b'\x1F' + TYPE_DOUBLE = b'\x1E' + TYPE_FLOAT = b'\x1D' + TYPE_STRING = b'\x1C' + TYPE_BYTES = b'\x1B' + TYPE_NULL = b'\x1A' diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py index 3425cfa1fd953..25321949455db 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py @@ -26,6 +26,7 @@ from flink.connection.Constants import Types +#=====Iterator========================================================================================================== class ListIterator(defIter.Iterator): def __init__(self, values): super(ListIterator, self).__init__() @@ -76,7 +77,7 @@ def next(self): else: self.cur = None self.empty = True - return tmp + return tmp[1] else: raise StopIteration @@ -93,7 +94,7 @@ def next_group(self): self.empty = False def _extract_keys(self, x): - return [x[k] for k in self.keys] + return [x[0][k] for k in self.keys] def _extract_keys_id(self, x): return x @@ -175,6 +176,7 @@ def __init__(self, con, env, group=0): self._group = group self._deserializer = None self._env = env + self._size = 0 def __next__(self): return self.next() @@ -184,8 +186,35 @@ def _read(self, des_size): def next(self): if self.has_next(): + custom_types = self._env._types + read = self._read if self._deserializer is None: - self._deserializer = _get_deserializer(self._group, self._connection.read, self._env._types) + type = read(1) + if type == Types.TYPE_ARRAY: + key_des = _get_deserializer(read, custom_types) + self._deserializer = ArrayDeserializer(key_des) + return key_des.deserialize(read) + elif type == Types.TYPE_KEY_VALUE: + size = ord(read(1)) + key_des = [] + keys = [] + for _ in range(size): + new_d = _get_deserializer(read, custom_types) + key_des.append(new_d) + keys.append(new_d.deserialize(read)) + val_des = _get_deserializer(read, custom_types) + val = val_des.deserialize(read) + self._deserializer = KeyValueDeserializer(key_des, val_des) + return (tuple(keys), val) + elif type == Types.TYPE_VALUE_VALUE: + des1 = _get_deserializer(read, custom_types) + field1 = des1.deserialize(read) + des2 = _get_deserializer(read, custom_types) + field2 = des2.deserialize(read) + self._deserializer = ValueValueDeserializer(des1, des2) + return (field1, field2) + else: + raise Exception("Invalid type ID encountered: " + str(ord(type))) return self._deserializer.deserialize(self._read) else: raise StopIteration @@ -197,6 +226,16 @@ def _reset(self): self._deserializer = None +class PlanIterator(object): + def __init__(self, con, env): + self._connection = con + self._env = env + + def next(self): + deserializer = _get_deserializer(self._connection.read, self._env._types) + return deserializer.deserialize(self._connection.read) + + class DummyIterator(Iterator): def __init__(self): super(Iterator, self).__init__() @@ -211,12 +250,11 @@ def has_next(self): return False -def _get_deserializer(group, read, custom_types, type=None): - if type is None: - type = read(1, group) - return _get_deserializer(group, read, custom_types, type) - elif type == 
Types.TYPE_TUPLE: - return TupleDeserializer(read, group, custom_types) +#=====Deserializer====================================================================================================== +def _get_deserializer(read, custom_types): + type = read(1) + if 0 < ord(type) < 26: + return TupleDeserializer([_get_deserializer(read, custom_types) for _ in range(ord(type))]) elif type == Types.TYPE_BYTE: return ByteDeserializer() elif type == Types.TYPE_BYTES: @@ -238,101 +276,125 @@ def _get_deserializer(group, read, custom_types, type=None): else: for entry in custom_types: if type == entry[0]: - return entry[3] - raise Exception("Unable to find deserializer for type ID " + str(type)) + return CustomTypeDeserializer(entry[3]) + raise Exception("Unable to find deserializer for type ID " + str(ord(type))) -class TupleDeserializer(object): - def __init__(self, read, group, custom_types): - size = unpack(">I", read(4, group))[0] - self.deserializer = [_get_deserializer(group, read, custom_types) for _ in range(size)] +class Deserializer(object): + def get_type_info_size(self): + return 1 def deserialize(self, read): - return tuple([s.deserialize(read) for s in self.deserializer]) + pass + +class ArrayDeserializer(Deserializer): + def __init__(self, deserializer): + self._deserializer = deserializer + self._d_skip = deserializer.get_type_info_size() -class ByteDeserializer(object): + def deserialize(self, read): + read(1) #array type + read(self._d_skip) + return self._deserializer.deserialize(read) + + +class KeyValueDeserializer(Deserializer): + def __init__(self, key_deserializer, value_deserializer): + self._key_deserializer = [(k, k.get_type_info_size()) for k in key_deserializer] + self._value_deserializer = value_deserializer + self._value_deserializer_skip = value_deserializer.get_type_info_size() + + def deserialize(self, read): + fields = [] + read(1) #key value type + read(1) #key count + for dk in self._key_deserializer: + read(dk[1]) + fields.append(dk[0].deserialize(read)) + dv = self._value_deserializer + read(self._value_deserializer_skip) + return (tuple(fields), dv.deserialize(read)) + + +class ValueValueDeserializer(Deserializer): + def __init__(self, d1, d2): + self._d1 = d1 + self._d1_skip = self._d1.get_type_info_size() + self._d2 = d2 + self._d2_skip = self._d2.get_type_info_size() + + def deserialize(self, read): + read(1) + read(self._d1_skip) + f1 = self._d1.deserialize(read) + read(self._d2_skip) + f2 = self._d2.deserialize(read) + return (f1, f2) + + +class CustomTypeDeserializer(Deserializer): + def __init__(self, deserializer): + self._deserializer = deserializer + + def deserialize(self, read): + read(4) #discard binary size + return self._deserializer.deserialize(read) + + +class TupleDeserializer(Deserializer): + def __init__(self, deserializer): + self._deserializer = deserializer + + def get_type_info_size(self): + return 1 + sum([d.get_type_info_size() for d in self._deserializer]) + + def deserialize(self, read): + return tuple([s.deserialize(read) for s in self._deserializer]) + + +class ByteDeserializer(Deserializer): def deserialize(self, read): return unpack(">c", read(1))[0] -class ByteArrayDeserializer(object): +class ByteArrayDeserializer(Deserializer): def deserialize(self, read): size = unpack(">i", read(4))[0] return bytearray(read(size)) if size else bytearray(b"") -class BooleanDeserializer(object): +class BooleanDeserializer(Deserializer): def deserialize(self, read): return unpack(">?", read(1))[0] -class FloatDeserializer(object): +class 
FloatDeserializer(Deserializer): def deserialize(self, read): return unpack(">f", read(4))[0] -class DoubleDeserializer(object): +class DoubleDeserializer(Deserializer): def deserialize(self, read): return unpack(">d", read(8))[0] -class IntegerDeserializer(object): +class IntegerDeserializer(Deserializer): def deserialize(self, read): return unpack(">i", read(4))[0] -class LongDeserializer(object): +class LongDeserializer(Deserializer): def deserialize(self, read): return unpack(">q", read(8))[0] -class StringDeserializer(object): +class StringDeserializer(Deserializer): def deserialize(self, read): length = unpack(">i", read(4))[0] return read(length).decode("utf-8") if length else "" -class NullDeserializer(object): - def deserialize(self): +class NullDeserializer(Deserializer): + def deserialize(self, read): return None - - -class TypedIterator(object): - def __init__(self, con, env): - self._connection = con - self._env = env - - def next(self): - read = self._connection.read - type = read(1) - if type == Types.TYPE_TUPLE: - size = unpack(">i", read(4))[0] - return tuple([self.next() for x in range(size)]) - elif type == Types.TYPE_BYTE: - return unpack(">c", read(1))[0] - elif type == Types.TYPE_BYTES: - size = unpack(">i", read(4))[0] - return bytearray(read(size)) if size else bytearray(b"") - elif type == Types.TYPE_BOOLEAN: - return unpack(">?", read(1))[0] - elif type == Types.TYPE_FLOAT: - return unpack(">f", read(4))[0] - elif type == Types.TYPE_DOUBLE: - return unpack(">d", read(8))[0] - elif type == Types.TYPE_INTEGER: - return unpack(">i", read(4))[0] - elif type == Types.TYPE_LONG: - return unpack(">q", read(8))[0] - elif type == Types.TYPE_STRING: - length = unpack(">i", read(4))[0] - return read(length).decode("utf-8") if length else "" - elif type == Types.TYPE_NULL: - return None - else: - for entry in self._env._types: - if type == entry[0]: - return entry[3]() - raise Exception("Unable to find deserializer for type ID " + str(type)) - - diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery10.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery10.py index cc9e7cf1963d9..c1079875d9bbd 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery10.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery10.py @@ -76,7 +76,7 @@ def join(self, value1, value2): STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING], '\n', '|') \ .project(0,5,6,8) \ .filter(LineitemFilter()) \ - .map(ComputeRevenue(), [INT, FLOAT]) + .map(ComputeRevenue()) nation = env \ .read_csv(sys.argv[4], [INT, STRING, INT, STRING], '\n', '|') \ diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery3.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery3.py index 3eb72c9792f77..aaa4e55cf91c9 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery3.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TPCHQuery3.py @@ -87,13 +87,13 @@ def join(self, value1, value2): .join(order) \ .where(0) \ .equal_to(1) \ - .using(CustomerOrderJoin(),[INT, FLOAT, STRING, INT]) + .using(CustomerOrderJoin()) result = customerWithOrder \ .join(lineitem) \ .where(0) \ .equal_to(0) \ - .using(CustomerOrderLineitemJoin(), 
[INT, FLOAT, STRING, INT]) \ + .using(CustomerOrderLineitemJoin()) \ .group_by(0, 2, 3) \ .reduce(SumReducer()) diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TriangleEnumeration.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TriangleEnumeration.py index b1b3ef4325ab0..f13cc0418687c 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TriangleEnumeration.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/TriangleEnumeration.py @@ -16,7 +16,7 @@ # limitations under the License. ################################################################################ from flink.plan.Environment import get_environment -from flink.plan.Constants import INT, Order +from flink.plan.Constants import Order from flink.functions.FlatMapFunction import FlatMapFunction from flink.functions.GroupReduceFunction import GroupReduceFunction from flink.functions.ReduceFunction import ReduceFunction @@ -123,27 +123,27 @@ def join(self, value1, value2): (1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 5), (3, 4), (3, 7), (3, 8), (5, 6), (7, 8)) edges_with_degrees = edges \ - .flat_map(EdgeDuplicator(), [INT, INT]) \ + .flat_map(EdgeDuplicator()) \ .group_by(0) \ .sort_group(1, Order.ASCENDING) \ - .reduce_group(DegreeCounter(), [INT, INT, INT, INT]) \ + .reduce_group(DegreeCounter()) \ .group_by(0, 2) \ .reduce(DegreeJoiner()) edges_by_degree = edges_with_degrees \ - .map(EdgeByDegreeProjector(), [INT, INT]) + .map(EdgeByDegreeProjector()) edges_by_id = edges_by_degree \ - .map(EdgeByIdProjector(), [INT, INT]) + .map(EdgeByIdProjector()) triangles = edges_by_degree \ .group_by(0) \ .sort_group(1, Order.ASCENDING) \ - .reduce_group(TriadBuilder(), [INT, INT, INT]) \ + .reduce_group(TriadBuilder()) \ .join(edges_by_id) \ .where(1, 2) \ .equal_to(0, 1) \ - .using(TriadFilter(), [INT, INT, INT]) + .using(TriadFilter()) triangles.output() diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WebLogAnalysis.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WebLogAnalysis.py index 676043fdc873c..1ea3e78e1dbcf 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WebLogAnalysis.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WebLogAnalysis.py @@ -19,7 +19,7 @@ from datetime import datetime from flink.plan.Environment import get_environment -from flink.plan.Constants import INT, STRING, FLOAT, WriteMode +from flink.plan.Constants import WriteMode from flink.functions.CoGroupFunction import CoGroupFunction from flink.functions.FilterFunction import FilterFunction @@ -54,16 +54,16 @@ def co_group(self, iterator1, iterator2, collector): sys.exit("Usage: ./bin/pyflink.sh WebLogAnalysis ") documents = env \ - .read_csv(sys.argv[1], [STRING, STRING], "\n", "|") \ + .read_csv(sys.argv[1], "\n", "|") \ .filter(DocumentFilter()) \ .project(0) ranks = env \ - .read_csv(sys.argv[2], [INT, STRING, INT], "\n", "|") \ + .read_csv(sys.argv[2], "\n", "|") \ .filter(RankFilter()) visits = env \ - .read_csv(sys.argv[3], [STRING, STRING, STRING, FLOAT, STRING, STRING, STRING, STRING, INT], "\n", "|") \ + .read_csv(sys.argv[3], "\n", "|") \ .project(1,2) \ .filter(VisitFilter()) \ .project(0) @@ -78,7 +78,7 @@ def co_group(self, iterator1, iterator2, collector): .co_group(visits) \ 
.where(1) \ .equal_to(0) \ - .using(AntiJoinVisits(), [INT, STRING, INT]) + .using(AntiJoinVisits()) result.write_csv(sys.argv[4], '\n', '|', WriteMode.OVERWRITE) diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WordCount.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WordCount.py index 71c2e28337183..2ab724c384b04 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WordCount.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/example/WordCount.py @@ -18,7 +18,6 @@ import sys from flink.plan.Environment import get_environment -from flink.plan.Constants import INT, STRING from flink.functions.FlatMapFunction import FlatMapFunction from flink.functions.GroupReduceFunction import GroupReduceFunction @@ -47,9 +46,9 @@ def reduce(self, iterator, collector): data = env.from_elements("hello","world","hello","car","tree","data","hello") result = data \ - .flat_map(Tokenizer(), (INT, STRING)) \ + .flat_map(Tokenizer()) \ .group_by(1) \ - .reduce_group(Adder(), (INT, STRING), combinable=True) \ + .reduce_group(Adder(), combinable=True) \ if len(sys.argv) == 3: result.write_csv(sys.argv[2]) diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py index 9c55787ae243a..4cb337ab4a2e2 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py @@ -25,13 +25,16 @@ def __init__(self): self._keys1 = None self._keys2 = None - def _configure(self, input_file, output_file, port, env): + def _configure(self, input_file, output_file, port, env, info): self._connection = Connection.TwinBufferingTCPMappedFileConnection(input_file, output_file, port) self._iterator = Iterator.Iterator(self._connection, env, 0) self._iterator2 = Iterator.Iterator(self._connection, env, 1) self._cgiter = Iterator.CoGroupIterator(self._iterator, self._iterator2, self._keys1, self._keys2) + self._collector = Collector.Collector(self._connection, env, info) self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) - self._configure_chain(Collector.Collector(self._connection, env)) + if info.chained_info is not None: + info.chained_info.operator._configure_chain(self.context, self._collector, info.chained_info) + self._collector = info.chained_info.operator def _run(self): collector = self._collector diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py index f874a25261a80..dfe6a283acf4c 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py @@ -16,9 +16,9 @@ # limitations under the License. 
################################################################################ from abc import ABCMeta, abstractmethod -import sys from collections import deque from flink.connection import Connection, Iterator, Collector +from flink.connection.Iterator import IntegerDeserializer, StringDeserializer, _get_deserializer from flink.functions import RuntimeContext @@ -30,55 +30,56 @@ def __init__(self): self._iterator = None self._collector = None self.context = None - self._chain_operator = None + self._env = None - def _configure(self, input_file, output_file, port, env): + def _configure(self, input_file, output_file, port, env, info): self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port) self._iterator = Iterator.Iterator(self._connection, env) + self._collector = Collector.Collector(self._connection, env, info) self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) - self._configure_chain(Collector.Collector(self._connection, env)) + self._env = env + if info.chained_info is not None: + info.chained_info.operator._configure_chain(self.context, self._collector, info.chained_info) + self._collector = info.chained_info.operator - def _configure_chain(self, collector): - if self._chain_operator is not None: - self._collector = self._chain_operator - self._collector.context = self.context - self._collector._configure_chain(collector) - self._collector._open() - else: + def _configure_chain(self, context, collector, info): + self.context = context + if info.chained_info is None: self._collector = collector - - def _chain(self, operator): - self._chain_operator = operator + else: + self._collector = info.chained_info.operator + info.chained_info.operator._configure_chain(context, collector, info.chained_info) @abstractmethod def _run(self): pass - def _open(self): - pass - def _close(self): self._collector._close() - self._connection.close() + if self._connection is not None: + self._connection.close() def _go(self): self._receive_broadcast_variables() self._run() def _receive_broadcast_variables(self): - broadcast_count = self._iterator.next() - self._iterator._reset() - self._connection.reset() + con = self._connection + deserializer_int = IntegerDeserializer() + broadcast_count = deserializer_int.deserialize(con.read_secondary) + deserializer_string = StringDeserializer() for _ in range(broadcast_count): - name = self._iterator.next() - self._iterator._reset() - self._connection.reset() + name = deserializer_string.deserialize(con.read_secondary) bc = deque() - while(self._iterator.has_next()): - bc.append(self._iterator.next()) + if con.read_secondary(1) == b"\x01": + serializer_data = _get_deserializer(con.read_secondary, self._env._types) + value = serializer_data.deserialize(con.read_secondary) + bc.append(value) + while con.read_secondary(1) == b"\x01": + con.read_secondary(serializer_data.get_type_info_size()) #skip type info + value = serializer_data.deserialize(con.read_secondary) + bc.append(value) self.context._add_broadcast_variable(name, bc) - self._iterator._reset() - self._connection.reset() diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py index b758c19a1167b..340497da64c1e 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py +++ 
b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py @@ -24,21 +24,14 @@ class GroupReduceFunction(Function.Function): def __init__(self): super(GroupReduceFunction, self).__init__() - self._keys = None - def _configure(self, input_file, output_file, port, env): - self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port) - self._iterator = Iterator.Iterator(self._connection, env) - if self._keys is None: + def _configure(self, input_file, output_file, port, env, info): + super(GroupReduceFunction, self)._configure(input_file, output_file, port, env, info) + if info.key1 is None: self._run = self._run_all_group_reduce else: self._run = self._run_grouped_group_reduce - self._group_iterator = Iterator.GroupIterator(self._iterator, self._keys) - self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) - self._collector = Collector.Collector(self._connection, env) - - def _set_grouping_keys(self, keys): - self._keys = keys + self._group_iterator = Iterator.GroupIterator(self._iterator, info.key1) def _run(self): pass diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/KeySelectorFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/KeySelectorFunction.py new file mode 100644 index 0000000000000..a75961f9341aa --- /dev/null +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/KeySelectorFunction.py @@ -0,0 +1,28 @@ +# ############################################################################### +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + + +class KeySelectorFunction: + def __call__(self, value): + return self.get_key(value) + + def callable(self): + return True + + def get_key(self, value): + pass diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py index 45a22da75c0f7..95e8b8a1df69d 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py @@ -23,21 +23,14 @@ class ReduceFunction(Function.Function): def __init__(self): super(ReduceFunction, self).__init__() - self._keys = None - def _configure(self, input_file, output_file, port, env): - self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port) - self._iterator = Iterator.Iterator(self._connection, env) - if self._keys is None: + def _configure(self, input_file, output_file, port, env, info): + super(ReduceFunction, self)._configure(input_file, output_file, port, env, info) + if info.key1 is None: self._run = self._run_all_reduce else: self._run = self._run_grouped_reduce - self._group_iterator = Iterator.GroupIterator(self._iterator, self._keys) - self._collector = Collector.Collector(self._connection, env) - self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) - - def _set_grouping_keys(self, keys): - self._keys = keys + self._group_iterator = Iterator.GroupIterator(self._iterator, info.key1) def _run(self): pass diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py index 0c9fe80efe8d4..888146348ddde 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py @@ -63,11 +63,6 @@ class Order(object): PY2 = sys.version_info[0] == 2 PY3 = sys.version_info[0] == 3 - -class _Dummy(object): - pass - - if PY2: BOOL = True INT = 1 @@ -75,11 +70,17 @@ class _Dummy(object): FLOAT = 2.5 STRING = "type" BYTES = bytearray(b"byte") - CUSTOM = _Dummy() elif PY3: BOOL = True INT = 1 FLOAT = 2.5 STRING = "type" BYTES = bytearray(b"byte") - CUSTOM = _Dummy() + + +def _createKeyValueTypeInfo(keyCount): + return (tuple([BYTES for _ in range(keyCount)]), BYTES) + + +def _createArrayTypeInfo(): + return BYTES diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/DataSet.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/DataSet.py index 25ec8b8a6c3e1..eda8d025da9b1 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/DataSet.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/DataSet.py @@ -15,11 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
################################################################################ -import inspect import copy import types as TYPES -from flink.plan.Constants import _Identifier, WriteMode, STRING +from flink.plan.Constants import _Identifier, WriteMode, _createKeyValueTypeInfo, _createArrayTypeInfo from flink.plan.OperationInfo import OperationInfo from flink.functions.CoGroupFunction import CoGroupFunction from flink.functions.FilterFunction import FilterFunction @@ -30,52 +29,33 @@ from flink.functions.MapFunction import MapFunction from flink.functions.MapPartitionFunction import MapPartitionFunction from flink.functions.ReduceFunction import ReduceFunction +from flink.functions.KeySelectorFunction import KeySelectorFunction -def deduct_output_type(dataset): - skip = set([_Identifier.GROUP, _Identifier.SORT, _Identifier.UNION]) - source = set([_Identifier.SOURCE_CSV, _Identifier.SOURCE_TEXT, _Identifier.SOURCE_VALUE]) - default = set([_Identifier.CROSS, _Identifier.CROSSH, _Identifier.CROSST, _Identifier.JOINT, _Identifier.JOINH, _Identifier.JOIN]) - - while True: - dataset_type = dataset.identifier - if dataset_type in skip: - dataset = dataset.parent - continue - if dataset_type in source: - if dataset_type == _Identifier.SOURCE_TEXT: - return STRING - if dataset_type == _Identifier.SOURCE_VALUE: - return dataset.values[0] - if dataset_type == _Identifier.SOURCE_CSV: - return dataset.types - if dataset_type == _Identifier.PROJECTION: - return tuple([deduct_output_type(dataset.parent)[k] for k in dataset.keys]) - if dataset_type in default: - if dataset.operator is not None: #udf-join/cross - return dataset.types - if len(dataset.projections) == 0: #defaultjoin/-cross - return (deduct_output_type(dataset.parent), deduct_output_type(dataset.other)) - else: #projectjoin/-cross - t1 = deduct_output_type(dataset.parent) - t2 = deduct_output_type(dataset.other) - out_type = [] - for prj in dataset.projections: - if len(prj[1]) == 0: #projection on non-tuple dataset - if prj[0] == "first": - out_type.append(t1) - else: - out_type.append(t2) - else: #projection on tuple dataset - for key in prj[1]: - if prj[0] == "first": - out_type.append(t1[key]) - else: - out_type.append(t2[key]) - return tuple(out_type) - return dataset.types - - -class Set(object): + +class Stringify(MapFunction): + def map(self, value): + if isinstance(value, (tuple, list)): + return "(" + b", ".join([self.map(x) for x in value]) + ")" + else: + return str(value) + + +class CsvStringify(MapFunction): + def __init__(self, f_delim): + super(CsvStringify, self).__init__() + self.delim = f_delim + + def map(self, value): + return self.delim.join([self._map(field) for field in value]) + + def _map(self, value): + if isinstance(value, (tuple, list)): + return "(" + b", ".join([self.map(x) for x in value]) + ")" + else: + return str(value) + + +class DataSet(object): def __init__(self, env, info): self._env = env self._info = info @@ -86,6 +66,9 @@ def output(self, to_error=False): """ Writes a DataSet to the standard output stream (stdout). """ + self.map(Stringify())._output(to_error) + + def _output(self, to_error): child = OperationInfo() child.identifier = _Identifier.SINK_PRINT child.parent = self._info @@ -100,6 +83,9 @@ def write_text(self, path, write_mode=WriteMode.NO_OVERWRITE): :param path: he path pointing to the location the text file is written to. 
:param write_mode: OutputFormat.WriteMode value, indicating whether files should be overwritten """ + return self.map(Stringify())._write_text(path, write_mode) + + def _write_text(self, path, write_mode): child = OperationInfo() child.identifier = _Identifier.SINK_TEXT child.parent = self._info @@ -116,6 +102,9 @@ def write_csv(self, path, line_delimiter="\n", field_delimiter=',', write_mode=W :param path: The path pointing to the location the CSV file is written to. :param write_mode: OutputFormat.WriteMode value, indicating whether files should be overwritten """ + return self.map(CsvStringify(field_delimiter))._write_csv(path, line_delimiter, field_delimiter, write_mode) + + def _write_csv(self, path, line_delimiter, field_delimiter, write_mode): child = OperationInfo() child.identifier = _Identifier.SINK_CSV child.path = path @@ -126,7 +115,7 @@ def write_csv(self, path, line_delimiter="\n", field_delimiter=',', write_mode=W self._info.sinks.append(child) self._env._sinks.append(child) - def reduce_group(self, operator, types, combinable=False): + def reduce_group(self, operator, combinable=False): """ Applies a GroupReduce transformation. @@ -136,7 +125,6 @@ def reduce_group(self, operator, types, combinable=False): emit any number of output elements including none. :param operator: The GroupReduceFunction that is applied on the DataSet. - :param types: The type of the resulting DataSet. :return:A GroupReduceOperator that represents the reduced DataSet. """ if isinstance(operator, TYPES.FunctionType): @@ -148,17 +136,12 @@ def reduce_group(self, operator, types, combinable=False): child.identifier = _Identifier.GROUPREDUCE child.parent = self._info child.operator = operator - child.types = types + child.types = _createArrayTypeInfo() child.name = "PythonGroupReduce" self._info.children.append(child) self._env._sets.append(child) return child_set - -class ReduceSet(Set): - def __init__(self, env, info): - super(ReduceSet, self).__init__(env, info) - def reduce(self, operator): """ Applies a Reduce transformation on a non-grouped DataSet. @@ -179,16 +162,11 @@ def reduce(self, operator): child.parent = self._info child.operator = operator child.name = "PythonReduce" - child.types = deduct_output_type(self._info) + child.types = _createArrayTypeInfo() self._info.children.append(child) self._env._sets.append(child) return child_set - -class DataSet(ReduceSet): - def __init__(self, env, info): - super(DataSet, self).__init__(env, info) - def project(self, *fields): """ Applies a Project transformation on a Tuple DataSet. @@ -201,14 +179,7 @@ def project(self, *fields): :return: The projected DataSet. """ - child = OperationInfo() - child_set = DataSet(self._env, child) - child.identifier = _Identifier.PROJECTION - child.parent = self._info - child.keys = fields - self._info.children.append(child) - self._env._sets.append(child) - return child_set + return self.map(lambda x: tuple([x[key] for key in fields])) def group_by(self, *keys): """ @@ -223,6 +194,9 @@ def group_by(self, *keys): :param keys: One or more field positions on which the DataSet will be grouped. :return:A Grouping on which a transformation needs to be applied to obtain a transformed DataSet. 
""" + return self.map(lambda x: x)._group_by(keys) + + def _group_by(self, keys): child = OperationInfo() child_chain = [] child_set = UnsortedGrouping(self._env, child, child_chain) @@ -251,9 +225,8 @@ def co_group(self, other_set): other_set._info.children.append(child) child_set = CoGroupOperatorWhere(self._env, child) child.identifier = _Identifier.COGROUP - child.parent = self._info - child.other = other_set._info - self._info.children.append(child) + child.parent_set = self + child.other_set = other_set return child_set def cross(self, other_set): @@ -323,14 +296,13 @@ def filter(self, operator): child.identifier = _Identifier.FILTER child.parent = self._info child.operator = operator - child.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__) child.name = "PythonFilter" - child.types = deduct_output_type(self._info) + child.types = _createArrayTypeInfo() self._info.children.append(child) self._env._sets.append(child) return child_set - def flat_map(self, operator, types): + def flat_map(self, operator): """ Applies a FlatMap transformation on a DataSet. @@ -338,7 +310,6 @@ def flat_map(self, operator, types): Each FlatMapFunction call can return any number of elements including none. :param operator: The FlatMapFunction that is called for each element of the DataSet. - :param types: The type of the resulting DataSet. :return:A FlatMapOperator that represents the transformed DataSe """ if isinstance(operator, TYPES.FunctionType): @@ -350,8 +321,7 @@ def flat_map(self, operator, types): child.identifier = _Identifier.FLATMAP child.parent = self._info child.operator = operator - child.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__) - child.types = types + child.types = _createArrayTypeInfo() child.name = "PythonFlatMap" self._info.children.append(child) self._env._sets.append(child) @@ -398,14 +368,11 @@ def _join(self, other_set, identifier): child = OperationInfo() child_set = JoinOperatorWhere(self._env, child) child.identifier = identifier - child.parent = self._info - child.other = other_set._info - self._info.children.append(child) - other_set._info.children.append(child) - self._env._sets.append(child) + child.parent_set = self + child.other_set = other_set return child_set - def map(self, operator, types): + def map(self, operator): """ Applies a Map transformation on a DataSet. @@ -413,7 +380,6 @@ def map(self, operator, types): Each MapFunction call returns exactly one element. :param operator: The MapFunction that is called for each element of the DataSet. - :param types: The type of the resulting DataSet :return:A MapOperator that represents the transformed DataSet """ if isinstance(operator, TYPES.FunctionType): @@ -425,14 +391,13 @@ def map(self, operator, types): child.identifier = _Identifier.MAP child.parent = self._info child.operator = operator - child.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__) - child.types = types + child.types = _createArrayTypeInfo() child.name = "PythonMap" self._info.children.append(child) self._env._sets.append(child) return child_set - def map_partition(self, operator, types): + def map_partition(self, operator): """ Applies a MapPartition transformation on a DataSet. @@ -444,7 +409,6 @@ def map_partition(self, operator, types): sees is non deterministic and depends on the degree of parallelism of the operation. :param operator: The MapFunction that is called for each element of the DataSet. 
- :param types: The type of the resulting DataSet :return:A MapOperator that represents the transformed DataSet """ if isinstance(operator, TYPES.FunctionType): @@ -456,8 +420,7 @@ def map_partition(self, operator, types): child.identifier = _Identifier.MAPPARTITION child.parent = self._info child.operator = operator - child.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__) - child.types = types + child.types = _createArrayTypeInfo() child.name = "PythonMapPartition" self._info.children.append(child) self._env._sets.append(child) @@ -506,7 +469,10 @@ def __init__(self, env, info, child_chain): info.id = env._counter env._counter += 1 - def reduce_group(self, operator, types, combinable=False): + def _finalize(self): + pass + + def reduce_group(self, operator, combinable=False): """ Applies a GroupReduce transformation. @@ -516,21 +482,21 @@ def reduce_group(self, operator, types, combinable=False): emit any number of output elements including none. :param operator: The GroupReduceFunction that is applied on the DataSet. - :param types: The type of the resulting DataSet. :return:A GroupReduceOperator that represents the reduced DataSet. """ + self._finalize() if isinstance(operator, TYPES.FunctionType): f = operator operator = GroupReduceFunction() operator.reduce = f - operator._set_grouping_keys(self._child_chain[0].keys) child = OperationInfo() child_set = OperatorSet(self._env, child) child.identifier = _Identifier.GROUPREDUCE child.parent = self._info child.operator = operator - child.types = types + child.types = _createArrayTypeInfo() child.name = "PythonGroupReduce" + child.key1 = self._child_chain[0].keys self._info.children.append(child) self._env._sets.append(child) @@ -573,26 +539,96 @@ def reduce(self, operator): :param operator:The ReduceFunction that is applied on the DataSet. :return:A ReduceOperator that represents the reduced DataSet. 
""" - operator._set_grouping_keys(self._child_chain[0].keys) - for i in self._child_chain: - self._env._sets.append(i) + self._finalize() + if isinstance(operator, TYPES.FunctionType): + f = operator + operator = ReduceFunction() + operator.reduce = f child = OperationInfo() child_set = OperatorSet(self._env, child) child.identifier = _Identifier.REDUCE child.parent = self._info child.operator = operator child.name = "PythonReduce" - child.types = deduct_output_type(self._info) + child.types = _createArrayTypeInfo() + child.key1 = self._child_chain[0].keys self._info.children.append(child) self._env._sets.append(child) return child_set + def _finalize(self): + grouping = self._child_chain[0] + keys = grouping.keys + f = None + if isinstance(keys[0], TYPES.FunctionType): + f = lambda x: (keys[0](x),) + if isinstance(keys[0], KeySelectorFunction): + f = lambda x: (keys[0].get_key(x),) + if f is None: + f = lambda x: tuple([x[key] for key in keys]) + + grouping.parent.operator.map = lambda x: (f(x), x) + grouping.parent.types = _createKeyValueTypeInfo(len(keys)) + grouping.keys = tuple([i for i in range(len(grouping.keys))]) + class SortedGrouping(Grouping): def __init__(self, env, info, child_chain): super(SortedGrouping, self).__init__(env, info, child_chain) + def _finalize(self): + grouping = self._child_chain[0] + sortings = self._child_chain[1:] + + #list of used index keys to prevent duplicates and determine final index + index_keys = set() + + if not isinstance(grouping.keys[0], (TYPES.FunctionType, KeySelectorFunction)): + index_keys = index_keys.union(set(grouping.keys)) + + #list of sorts using indices + index_sorts = [] + #list of sorts using functions + ksl_sorts = [] + for s in sortings: + if not isinstance(s.field, (TYPES.FunctionType, KeySelectorFunction)): + index_keys.add(s.field) + index_sorts.append(s) + else: + ksl_sorts.append(s) + + used_keys = sorted(index_keys) + #all data gathered + + #construct list of extractor lambdas + lambdas = [] + i = 0 + for key in used_keys: + lambdas.append(lambda x, k=key: x[k]) + i += 1 + if isinstance(grouping.keys[0], (TYPES.FunctionType, KeySelectorFunction)): + lambdas.append(grouping.keys[0]) + for ksl_op in ksl_sorts: + lambdas.append(ksl_op.field) + + grouping.parent.operator.map = lambda x: (tuple([l(x) for l in lambdas]), x) + grouping.parent.types = _createKeyValueTypeInfo(len(lambdas)) + #modify keys + ksl_offset = len(used_keys) + if not isinstance(grouping.keys[0], (TYPES.FunctionType, KeySelectorFunction)): + grouping.keys = tuple([used_keys.index(key) for key in grouping.keys]) + else: + grouping.keys = (ksl_offset,) + ksl_offset += 1 + + for iop in index_sorts: + iop.field = used_keys.index(iop.field) + + for kop in ksl_sorts: + kop.field = ksl_offset + ksl_offset += 1 + class CoGroupOperatorWhere(object): def __init__(self, env, info): @@ -609,6 +645,18 @@ def where(self, *fields): :param fields: The indexes of the Tuple fields of the first co-grouped DataSets that should be used as keys. :return: An incomplete CoGroup transformation. 
""" + f = None + if isinstance(fields[0], TYPES.FunctionType): + f = lambda x: (fields[0](x),) + if isinstance(fields[0], KeySelectorFunction): + f = lambda x: (fields[0].get_key(x),) + if f is None: + f = lambda x: tuple([x[key] for key in fields]) + + new_parent_set = self._info.parent_set.map(lambda x: (f(x), x)) + new_parent_set._info.types = _createKeyValueTypeInfo(len(fields)) + self._info.parent = new_parent_set._info + self._info.parent.children.append(self._info) self._info.key1 = fields return CoGroupOperatorTo(self._env, self._info) @@ -628,6 +676,18 @@ def equal_to(self, *fields): :param fields: The indexes of the Tuple fields of the second co-grouped DataSet that should be used as keys. :return: An incomplete CoGroup transformation. """ + f = None + if isinstance(fields[0], TYPES.FunctionType): + f = lambda x: (fields[0](x),) + if isinstance(fields[0], KeySelectorFunction): + f = lambda x: (fields[0].get_key(x),) + if f is None: + f = lambda x: tuple([x[key] for key in fields]) + + new_other_set = self._info.other_set.map(lambda x: (f(x), x)) + new_other_set._info.types = _createKeyValueTypeInfo(len(fields)) + self._info.other = new_other_set._info + self._info.other.children.append(self._info) self._info.key2 = fields return CoGroupOperatorUsing(self._env, self._info) @@ -637,7 +697,7 @@ def __init__(self, env, info): self._env = env self._info = info - def using(self, operator, types): + def using(self, operator): """ Finalizes a CoGroup transformation. @@ -645,7 +705,6 @@ def using(self, operator, types): Each CoGroupFunction call returns an arbitrary number of keys. :param operator: The CoGroupFunction that is called for all groups of elements with identical keys. - :param types: The type of the resulting DataSet. :return:An CoGroupOperator that represents the co-grouped result DataSet. """ if isinstance(operator, TYPES.FunctionType): @@ -653,11 +712,12 @@ def using(self, operator, types): operator = CoGroupFunction() operator.co_group = f new_set = OperatorSet(self._env, self._info) + self._info.key1 = tuple([x for x in range(len(self._info.key1))]) + self._info.key2 = tuple([x for x in range(len(self._info.key2))]) operator._keys1 = self._info.key1 operator._keys2 = self._info.key2 self._info.operator = operator - self._info.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__) - self._info.types = types + self._info.types = _createArrayTypeInfo() self._info.name = "PythonCoGroup" self._env._sets.append(self._info) return new_set @@ -679,7 +739,19 @@ def where(self, *fields): :return:An incomplete Join transformation. """ - self._info.key1 = fields + f = None + if isinstance(fields[0], TYPES.FunctionType): + f = lambda x: (fields[0](x),) + if isinstance(fields[0], KeySelectorFunction): + f = lambda x: (fields[0].get_key(x),) + if f is None: + f = lambda x: tuple([x[key] for key in fields]) + + new_parent_set = self._info.parent_set.map(lambda x: (f(x), x)) + new_parent_set._info.types = _createKeyValueTypeInfo(len(fields)) + self._info.parent = new_parent_set._info + self._info.parent.children.append(self._info) + self._info.key1 = tuple([x for x in range(len(fields))]) return JoinOperatorTo(self._env, self._info) @@ -698,81 +770,115 @@ def equal_to(self, *fields): :param fields:The indexes of the Tuple fields of the second join DataSet that should be used as keys. :return:An incomplete Join Transformation. 
""" - self._info.key2 = fields + f = None + if isinstance(fields[0], TYPES.FunctionType): + f = lambda x: (fields[0](x),) + if isinstance(fields[0], KeySelectorFunction): + f = lambda x: (fields[0].get_key(x),) + if f is None: + f = lambda x: tuple([x[key] for key in fields]) + + new_other_set = self._info.other_set.map(lambda x: (f(x), x)) + new_other_set._info.types = _createKeyValueTypeInfo(len(fields)) + self._info.other = new_other_set._info + self._info.other.children.append(self._info) + self._info.key2 = tuple([x for x in range(len(fields))]) + self._env._sets.append(self._info) return JoinOperator(self._env, self._info) -class JoinOperatorProjection(DataSet): +class Projector(DataSet): def __init__(self, env, info): - super(JoinOperatorProjection, self).__init__(env, info) + super(Projector, self).__init__(env, info) def project_first(self, *fields): """ - Initiates a ProjectJoin transformation. + Initiates a Project transformation. - Projects the first join input. - If the first join input is a Tuple DataSet, fields can be selected by their index. - If the first join input is not a Tuple DataSet, no parameters should be passed. + Projects the first input. + If the first input is a Tuple DataSet, fields can be selected by their index. + If the first input is not a Tuple DataSet, no parameters should be passed. :param fields: The indexes of the selected fields. - :return: An incomplete JoinProjection. + :return: An incomplete Projection. """ - self._info.projections.append(("first", fields)) + for field in fields: + self._info.projections.append((0, field)) + self._info.operator.map = lambda x : tuple([x[side][index] for side, index in self._info.projections]) return self def project_second(self, *fields): """ - Initiates a ProjectJoin transformation. + Initiates a Project transformation. - Projects the second join input. - If the second join input is a Tuple DataSet, fields can be selected by their index. - If the second join input is not a Tuple DataSet, no parameters should be passed. + Projects the second input. + If the second input is a Tuple DataSet, fields can be selected by their index. + If the second input is not a Tuple DataSet, no parameters should be passed. :param fields: The indexes of the selected fields. - :return: An incomplete JoinProjection. + :return: An incomplete Projection. """ - self._info.projections.append(("second", fields)) + for field in fields: + self._info.projections.append((1, field)) + self._info.operator.map = lambda x : tuple([x[side][index] for side, index in self._info.projections]) return self -class JoinOperator(DataSet): - def __init__(self, env, info): - super(JoinOperator, self).__init__(env, info) +class Projectable: + def __init__(self): + pass def project_first(self, *fields): """ - Initiates a ProjectJoin transformation. + Initiates a Project transformation. - Projects the first join input. - If the first join input is a Tuple DataSet, fields can be selected by their index. - If the first join input is not a Tuple DataSet, no parameters should be passed. + Projects the first input. + If the first input is a Tuple DataSet, fields can be selected by their index. + If the first input is not a Tuple DataSet, no parameters should be passed. :param fields: The indexes of the selected fields. - :return: An incomplete JoinProjection. + :return: An incomplete Projection. 
""" - return JoinOperatorProjection(self._env, self._info).project_first(*fields) + return Projectable._createProjector(self._env, self._info).project_first(*fields) def project_second(self, *fields): """ - Initiates a ProjectJoin transformation. + Initiates a Project transformation. - Projects the second join input. - If the second join input is a Tuple DataSet, fields can be selected by their index. - If the second join input is not a Tuple DataSet, no parameters should be passed. + Projects the second input. + If the second input is a Tuple DataSet, fields can be selected by their index. + If the second input is not a Tuple DataSet, no parameters should be passed. :param fields: The indexes of the selected fields. - :return: An incomplete JoinProjection. + :return: An incomplete Projection. """ - return JoinOperatorProjection(self._env, self._info).project_second(*fields) + return Projectable._createProjector(self._env, self._info).project_second(*fields) + + @staticmethod + def _createProjector(env, info): + child = OperationInfo() + child_set = Projector(env, child) + child.identifier = _Identifier.MAP + child.operator = MapFunction() + child.parent = info + child.types = _createArrayTypeInfo() + child.name = "Projector" + info.children.append(child) + env._sets.append(child) + return child_set + - def using(self, operator, types): +class JoinOperator(DataSet, Projectable): + def __init__(self, env, info): + super(JoinOperator, self).__init__(env, info) + + def using(self, operator): """ Finalizes a Join transformation. Applies a JoinFunction to each pair of joined elements. Each JoinFunction call returns exactly one element. :param operator:The JoinFunction that is called for each pair of joined elements. - :param types: :return:An Set that represents the joined result DataSet. """ if isinstance(operator, TYPES.FunctionType): @@ -780,84 +886,23 @@ def using(self, operator, types): operator = JoinFunction() operator.join = f self._info.operator = operator - self._info.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__) - self._info.types = types + self._info.types = _createArrayTypeInfo() self._info.name = "PythonJoin" - self._env._sets.append(self._info) + self._info.uses_udf = True return OperatorSet(self._env, self._info) -class CrossOperatorProjection(DataSet): - def __init__(self, env, info): - super(CrossOperatorProjection, self).__init__(env, info) - - def project_first(self, *fields): - """ - Initiates a ProjectCross transformation. - - Projects the first join input. - If the first join input is a Tuple DataSet, fields can be selected by their index. - If the first join input is not a Tuple DataSet, no parameters should be passed. - - :param fields: The indexes of the selected fields. - :return: An incomplete CrossProjection. - """ - self._info.projections.append(("first", fields)) - return self - - def project_second(self, *fields): - """ - Initiates a ProjectCross transformation. - - Projects the second join input. - If the second join input is a Tuple DataSet, fields can be selected by their index. - If the second join input is not a Tuple DataSet, no parameters should be passed. - - :param fields: The indexes of the selected fields. - :return: An incomplete CrossProjection. 
- """ - self._info.projections.append(("second", fields)) - return self - - -class CrossOperator(DataSet): +class CrossOperator(DataSet, Projectable): def __init__(self, env, info): super(CrossOperator, self).__init__(env, info) - def project_first(self, *fields): - """ - Initiates a ProjectCross transformation. - - Projects the first join input. - If the first join input is a Tuple DataSet, fields can be selected by their index. - If the first join input is not a Tuple DataSet, no parameters should be passed. - - :param fields: The indexes of the selected fields. - :return: An incomplete CrossProjection. - """ - return CrossOperatorProjection(self._env, self._info).project_first(*fields) - - def project_second(self, *fields): - """ - Initiates a ProjectCross transformation. - - Projects the second join input. - If the second join input is a Tuple DataSet, fields can be selected by their index. - If the second join input is not a Tuple DataSet, no parameters should be passed. - - :param fields: The indexes of the selected fields. - :return: An incomplete CrossProjection. - """ - return CrossOperatorProjection(self._env, self._info).project_second(*fields) - - def using(self, operator, types): + def using(self, operator): """ Finalizes a Cross transformation. Applies a CrossFunction to each pair of joined elements. Each CrossFunction call returns exactly one element. :param operator:The CrossFunction that is called for each pair of joined elements. - :param types: The type of the resulting DataSet. :return:An Set that represents the joined result DataSet. """ if isinstance(operator, TYPES.FunctionType): @@ -865,7 +910,7 @@ def using(self, operator, types): operator = CrossFunction() operator.cross = f self._info.operator = operator - self._info.meta = str(inspect.getmodule(operator)) + "|" + str(operator.__class__.__name__) - self._info.types = types + self._info.types = _createArrayTypeInfo() self._info.name = "PythonCross" + self._info.uses_udf = True return OperatorSet(self._env, self._info) diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py index a2279c3e763e7..4f2c5e33de9dc 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py @@ -159,8 +159,8 @@ def execute(self, local=False, debug=False): if plan_mode: port = int(sys.stdin.readline().rstrip('\n')) self._connection = Connection.PureTCPConnection(port) - self._iterator = Iterator.TypedIterator(self._connection, self) - self._collector = Collector.TypedCollector(self._connection, self) + self._iterator = Iterator.PlanIterator(self._connection, self) + self._collector = Collector.PlanCollector(self._connection, self) self._send_plan() result = self._receive_result() self._connection.close() @@ -175,13 +175,13 @@ def execute(self, local=False, debug=False): input_path = sys.stdin.readline().rstrip('\n') output_path = sys.stdin.readline().rstrip('\n') + used_set = None operator = None for set in self._sets: if set.id == id: + used_set = set operator = set.operator - if set.id == -id: - operator = set.combineop - operator._configure(input_path, output_path, port, self) + operator._configure(input_path, output_path, port, self, used_set) operator._go() operator._close() sys.stdout.flush() @@ -211,7 +211,7 @@ def 
_find_chains(self): if child_type in chainable: parent = child.parent if parent.operator is not None and len(parent.children) == 1 and len(parent.sinks) == 0: - parent.operator._chain(child.operator) + parent.chained_info = child parent.name += " -> " + child.name parent.types = child.types for grand_child in child.children: @@ -297,11 +297,8 @@ def _send_operations(self): break if case(_Identifier.CROSS, _Identifier.CROSSH, _Identifier.CROSST): collect(set.other.id) + collect(set.uses_udf) collect(set.types) - collect(len(set.projections)) - for p in set.projections: - collect(p[0]) - collect(p[1]) collect(set.name) break if case(_Identifier.REDUCE, _Identifier.GROUPREDUCE): @@ -312,11 +309,8 @@ def _send_operations(self): collect(set.key1) collect(set.key2) collect(set.other.id) + collect(set.uses_udf) collect(set.types) - collect(len(set.projections)) - for p in set.projections: - collect(p[0]) - collect(p[1]) collect(set.name) break if case(_Identifier.MAP, _Identifier.MAPPARTITION, _Identifier.FLATMAP, _Identifier.FILTER): diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/OperationInfo.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/OperationInfo.py index c47fab57d7c42..3605d7f418384 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/OperationInfo.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/OperationInfo.py @@ -23,6 +23,8 @@ def __init__(self, info=None): if info is None: self.parent = None self.other = None + self.parent_set = None + self.other_set = None self.identifier = None self.field = None self.order = None @@ -31,6 +33,7 @@ def __init__(self, info=None): self.key2 = None self.types = None self.operator = None + self.uses_udf = False self.name = None self.delimiter_line = "\n" self.delimiter_field = "," @@ -43,6 +46,7 @@ def __init__(self, info=None): self.bcvars = [] self.id = None self.to_err = False + self.chained_info = None else: self.__dict__.update(info.__dict__) diff --git a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py index c9bc404265bb2..3e718d317240d 100644 --- a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py +++ b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py @@ -22,7 +22,8 @@ from flink.functions.MapPartitionFunction import MapPartitionFunction from flink.functions.ReduceFunction import ReduceFunction from flink.functions.GroupReduceFunction import GroupReduceFunction -from flink.plan.Constants import INT, STRING, FLOAT, BOOL, BYTES, CUSTOM, Order, WriteMode +from flink.plan.Constants import Order, WriteMode +from flink.plan.Constants import INT, STRING import struct #Utilities @@ -73,7 +74,7 @@ def map_partition(self, iterator, collector): d4 = env.from_elements((1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False)) - d5 = env.from_elements((4.4, 4.3, 1), (4.3, 4.4, 1), (4.2, 4.1, 3), (4.1, 4.1, 3)) + d5 = env.from_elements((1, 2.4), (1, 3.7), (1, 0.4), (1, 5.4)) d6 = env.from_elements(1, 1, 12) @@ -89,19 +90,19 @@ def map_partition(self, iterator, collector): #Types env.from_elements(bytearray(b"hello"), bytearray(b"world"))\ - .map(Id(), BYTES).map_partition(Verify([bytearray(b"hello"), bytearray(b"world")], "Byte"), 
STRING).output() + .map(Id()).map_partition(Verify([bytearray(b"hello"), bytearray(b"world")], "Byte")).output() env.from_elements(1, 2, 3, 4, 5)\ - .map(Id(), INT).map_partition(Verify([1,2,3,4,5], "Int"), STRING).output() + .map(Id()).map_partition(Verify([1,2,3,4,5], "Int")).output() env.from_elements(True, True, False)\ - .map(Id(), BOOL).map_partition(Verify([True, True, False], "Bool"), STRING).output() + .map(Id()).map_partition(Verify([True, True, False], "Bool")).output() env.from_elements(1.4, 1.7, 12312.23)\ - .map(Id(), FLOAT).map_partition(Verify([1.4, 1.7, 12312.23], "Float"), STRING).output() + .map(Id()).map_partition(Verify([1.4, 1.7, 12312.23], "Float")).output() env.from_elements("hello", "world")\ - .map(Id(), STRING).map_partition(Verify(["hello", "world"], "String"), STRING).output() + .map(Id()).map_partition(Verify(["hello", "world"], "String")).output() #Custom Serialization class Ext(MapPartitionFunction): @@ -125,16 +126,16 @@ def deserialize(self, read): env.register_type(MyObj, MySerializer(), MyDeserializer()) env.from_elements(MyObj(2), MyObj(4)) \ - .map(Id(), CUSTOM).map_partition(Ext(), INT) \ - .map_partition(Verify([2, 4], "CustomTypeSerialization"), STRING).output() + .map(Id()).map_partition(Ext()) \ + .map_partition(Verify([2, 4], "CustomTypeSerialization")).output() #Map class Mapper(MapFunction): def map(self, value): return value * value d1 \ - .map((lambda x: x * x), INT).map(Mapper(), INT) \ - .map_partition(Verify([1, 1296, 20736], "Map"), STRING).output() + .map((lambda x: x * x)).map(Mapper()) \ + .map_partition(Verify([1, 1296, 20736], "Map")).output() #FlatMap class FlatMap(FlatMapFunction): @@ -142,8 +143,8 @@ def flat_map(self, value, collector): collector.collect(value) collector.collect(value * 2) d1 \ - .flat_map(FlatMap(), INT).flat_map(FlatMap(), INT) \ - .map_partition(Verify([1, 2, 2, 4, 6, 12, 12, 24, 12, 24, 24, 48], "FlatMap"), STRING).output() + .flat_map(FlatMap()).flat_map(FlatMap()) \ + .map_partition(Verify([1, 2, 2, 4, 6, 12, 12, 24, 12, 24, 24, 48], "FlatMap")).output() #MapPartition class MapPartition(MapPartitionFunction): @@ -151,8 +152,8 @@ def map_partition(self, iterator, collector): for value in iterator: collector.collect(value * 2) d1 \ - .map_partition(MapPartition(), INT) \ - .map_partition(Verify([2, 12, 24], "MapPartition"), STRING).output() + .map_partition(MapPartition()) \ + .map_partition(Verify([2, 12, 24], "MapPartition")).output() #Filter class Filter(FilterFunction): @@ -164,7 +165,7 @@ def filter(self, value): return value > self.limit d1 \ .filter(Filter(5)).filter(Filter(8)) \ - .map_partition(Verify([12], "Filter"), STRING).output() + .map_partition(Verify([12], "Filter")).output() #Reduce class Reduce(ReduceFunction): @@ -176,7 +177,11 @@ def reduce(self, value1, value2): return (value1[0] + value2[0], value1[1] + value2[1], value1[2], value1[3] or value2[3]) d1 \ .reduce(Reduce()) \ - .map_partition(Verify([19], "AllReduce"), STRING).output() + .map_partition(Verify([19], "AllReduce")).output() + + d4 \ + .group_by(2).reduce(Reduce2()) \ + .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "GroupedReduce")).output() #GroupReduce class GroupReduce(GroupReduceFunction): @@ -193,9 +198,31 @@ class GroupReduce2(GroupReduceFunction): def reduce(self, iterator, collector): for value in iterator: collector.collect(value) + + d4 \ + .reduce_group(GroupReduce2()) \ + .map_partition(Verify([(1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", 
False)], "AllGroupReduce")).output() + d4 \ + .group_by(lambda x: x[2]).reduce_group(GroupReduce(), combinable=False) \ + .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "GroupReduceWithKeySelector")).output() d4 \ - .group_by(2).reduce_group(GroupReduce(), (INT, FLOAT, STRING, BOOL), combinable=False) \ - .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "AllGroupReduce"), STRING).output() + .group_by(2).reduce_group(GroupReduce()) \ + .map_partition(Verify([(3, 1.4, "hello", True), (2, 0.4, "world", False)], "GroupReduce")).output() + d5 \ + .group_by(0).sort_group(1, Order.ASCENDING).reduce_group(GroupReduce2(), combinable=True) \ + .map_partition(Verify([(1, 0.4), (1, 2.4), (1, 3.7), (1, 5.4)], "SortedGroupReduceAsc")).output() + d5 \ + .group_by(0).sort_group(1, Order.DESCENDING).reduce_group(GroupReduce2(), combinable=True) \ + .map_partition(Verify([(1, 5.4), (1, 3.7), (1, 2.4), (1, 0.4)], "SortedGroupReduceDes")).output() + d5 \ + .group_by(lambda x: x[0]).sort_group(1, Order.DESCENDING).reduce_group(GroupReduce2(), combinable=True) \ + .map_partition(Verify([(1, 5.4), (1, 3.7), (1, 2.4), (1, 0.4)], "SortedGroupReduceKeySelG")).output() + d5 \ + .group_by(0).sort_group(lambda x: x[1], Order.DESCENDING).reduce_group(GroupReduce2(), combinable=True) \ + .map_partition(Verify([(1, 5.4), (1, 3.7), (1, 2.4), (1, 0.4)], "SortedGroupReduceKeySelS")).output() + d5 \ + .group_by(lambda x: x[0]).sort_group(lambda x: x[1], Order.DESCENDING).reduce_group(GroupReduce2(), combinable=True) \ + .map_partition(Verify([(1, 5.4), (1, 3.7), (1, 2.4), (1, 0.4)], "SortedGroupReduceKeySelGS")).output() #Execution env.set_parallelism(1) diff --git a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main2.py b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main2.py index 56e325059df33..6bf1fabe9beac 100644 --- a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main2.py +++ b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main2.py @@ -22,7 +22,6 @@ from flink.functions.CrossFunction import CrossFunction from flink.functions.JoinFunction import JoinFunction from flink.functions.CoGroupFunction import CoGroupFunction -from flink.plan.Constants import BOOL, INT, FLOAT, STRING #Utilities @@ -85,28 +84,32 @@ def join(self, value1, value2): else: return value2[0] + str(value1[1]) d2 \ - .join(d3).where(2).equal_to(0).using(Join(), STRING) \ - .map_partition(Verify(["hello1", "world0.4"], "Join"), STRING).output() + .join(d3).where(2).equal_to(0).using(Join()) \ + .map_partition(Verify(["hello1", "world0.4"], "Join")).output() + d2 \ + .join(d3).where(lambda x: x[2]).equal_to(0).using(Join()) \ + .map_partition(Verify(["hello1", "world0.4"], "JoinWithKeySelector")).output() d2 \ .join(d3).where(2).equal_to(0).project_first(0, 3).project_second(0) \ - .map_partition(Verify([(1, True, "hello"), (2, False, "world")], "Project Join"), STRING).output() + .map_partition(Verify([(1, True, "hello"), (2, False, "world")], "Project Join")).output() d2 \ .join(d3).where(2).equal_to(0) \ - .map_partition(Verify([((1, 0.5, "hello", True), ("hello",)), ((2, 0.4, "world", False), ("world",))], "Default Join"), STRING).output() + .map_partition(Verify([((1, 0.5, "hello", True), ("hello",)), ((2, 0.4, "world", False), ("world",))], "Default Join")).output() #Cross class Cross(CrossFunction): def cross(self, value1, value2): return (value1, value2[3]) d1 \ - 
.cross(d2).using(Cross(), (INT, BOOL)) \ - .map_partition(Verify([(1, True), (1, False), (6, True), (6, False), (12, True), (12, False)], "Cross"), STRING).output() + .cross(d2).using(Cross()) \ + .map_partition(Verify([(1, True), (1, False), (6, True), (6, False), (12, True), (12, False)], "Cross")).output() d1 \ .cross(d3) \ - .map_partition(Verify([(1, ("hello",)), (1, ("world",)), (6, ("hello",)), (6, ("world",)), (12, ("hello",)), (12, ("world",))], "Default Cross"), STRING).output() + .map_partition(Verify([(1, ("hello",)), (1, ("world",)), (6, ("hello",)), (6, ("world",)), (12, ("hello",)), (12, ("world",))], "Default Cross")).output() + d2 \ .cross(d3).project_second(0).project_first(0, 1) \ - .map_partition(Verify([("hello", 1, 0.5), ("world", 1, 0.5), ("hello", 2, 0.4), ("world", 2, 0.4)], "Project Cross"), STRING).output() + .map_partition(Verify([("hello", 1, 0.5), ("world", 1, 0.5), ("hello", 2, 0.4), ("world", 2, 0.4)], "Project Cross")).output() #CoGroup class CoGroup(CoGroupFunction): @@ -114,8 +117,8 @@ def co_group(self, iterator1, iterator2, collector): while iterator1.has_next() and iterator2.has_next(): collector.collect((iterator1.next(), iterator2.next())) d4 \ - .co_group(d5).where(0).equal_to(2).using(CoGroup(), ((INT, FLOAT, STRING, BOOL), (FLOAT, FLOAT, INT))) \ - .map_partition(Verify([((1, 0.5, "hello", True), (4.4, 4.3, 1)), ((1, 0.4, "hello", False), (4.3, 4.4, 1))], "CoGroup"), STRING).output() + .co_group(d5).where(0).equal_to(2).using(CoGroup()) \ + .map_partition(Verify([((1, 0.5, "hello", True), (4.4, 4.3, 1)), ((1, 0.4, "hello", False), (4.3, 4.4, 1))], "CoGroup")).output() #Broadcast class MapperBcv(MapFunction): @@ -123,22 +126,23 @@ def map(self, value): factor = self.context.get_broadcast_variable("test")[0][0] return value * factor d1 \ - .map(MapperBcv(), INT).with_broadcast_set("test", d2) \ - .map_partition(Verify([1, 6, 12], "Broadcast"), STRING).output() + .map(MapperBcv()).with_broadcast_set("test", d2) \ + .map_partition(Verify([1, 6, 12], "Broadcast")).output() #Misc class Mapper(MapFunction): def map(self, value): return value * value d1 \ - .map(Mapper(), INT).map((lambda x: x * x), INT) \ - .map_partition(Verify([1, 1296, 20736], "Chained Lambda"), STRING).output() + .map(Mapper()).map((lambda x: x * x)) \ + .map_partition(Verify([1, 1296, 20736], "Chained Lambda")).output() d2 \ .project(0, 1, 2) \ - .map_partition(Verify([(1, 0.5, "hello"), (2, 0.4, "world")], "Project"), STRING).output() + .map_partition(Verify([(1, 0.5, "hello"), (2, 0.4, "world")], "Project")).output() d2 \ .union(d4) \ - .map_partition(Verify2([(1, 0.5, "hello", True), (2, 0.4, "world", False), (1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False)], "Union"), STRING).output() + .map_partition(Verify2([(1, 0.5, "hello", True), (2, 0.4, "world", False), (1, 0.5, "hello", True), (1, 0.4, "hello", False), (1, 0.5, "hello", True), (2, 0.4, "world", False)], "Union")).output() + #Execution env.set_parallelism(1) diff --git a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_type_deduction.py b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_type_deduction.py deleted file mode 100644 index 1ff3f923bc734..0000000000000 --- a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_type_deduction.py +++ /dev/null @@ -1,73 +0,0 @@ -################################################################################ -# Licensed to the Apache Software 
Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -################################################################################ -from flink.plan.Environment import get_environment -from flink.plan.Constants import BOOL, STRING -from flink.functions.MapPartitionFunction import MapPartitionFunction - - -class Verify(MapPartitionFunction): - def __init__(self, msg): - super(Verify, self).__init__() - self.msg = msg - - def map_partition(self, iterator, collector): - if self.msg is None: - return - else: - raise Exception("Type Deduction failed: " + self.msg) - -if __name__ == "__main__": - env = get_environment() - - d1 = env.from_elements(("hello", 4, 3.2, True)) - - d2 = env.from_elements("world") - - direct_from_source = d1.filter(lambda x:True) - - msg = None - - if direct_from_source._info.types != ("hello", 4, 3.2, True): - msg = "Error deducting type directly from source." - - from_common_udf = d1.map(lambda x: x[3], BOOL).filter(lambda x:True) - - if from_common_udf._info.types != BOOL: - msg = "Error deducting type from common udf." - - through_projection = d1.project(3, 2).filter(lambda x:True) - - if through_projection._info.types != (True, 3.2): - msg = "Error deducting type through projection." - - through_default_op = d1.cross(d2).filter(lambda x:True) - - if through_default_op._info.types != (("hello", 4, 3.2, True), "world"): - msg = "Error deducting type through default J/C." +str(through_default_op._info.types) - - through_prj_op = d1.cross(d2).project_first(1, 0).project_second().project_first(3, 2).filter(lambda x:True) - - if through_prj_op._info.types != (4, "hello", "world", True, 3.2): - msg = "Error deducting type through projection J/C. "+str(through_prj_op._info.types) - - - env = get_environment() - - env.from_elements("dummy").map_partition(Verify(msg), STRING).output() - - env.execute(local=True)
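
For readers reviewing the API change, here is a minimal job sketch written against the transformations as they appear in the updated test_main.py / test_main2.py above: no type arguments are passed any more, and grouping/join keys can be given either as field positions or as key selector lambdas. The element values, the Doubler/Summer classes and the sinks are illustrative assumptions, not part of the patch.

    # Hypothetical job; element values are invented, the API calls mirror the updated tests.
    from flink.plan.Environment import get_environment
    from flink.functions.MapFunction import MapFunction
    from flink.functions.GroupReduceFunction import GroupReduceFunction


    class Doubler(MapFunction):
        def map(self, value):
            # no result-type argument is required any more
            return (value[0], value[1] * 2)


    class Summer(GroupReduceFunction):
        def reduce(self, iterator, collector):
            key = None
            total = 0
            for value in iterator:
                key = value[0]
                total += value[1]
            collector.collect((key, total))


    if __name__ == "__main__":
        env = get_environment()
        d1 = env.from_elements(("a", 1), ("a", 2), ("b", 3))
        d2 = env.from_elements(("a", 10), ("b", 20))

        # grouping by a key selector lambda instead of a field position
        d1.map(Doubler()) \
            .group_by(lambda x: x[0]) \
            .reduce_group(Summer(), combinable=False) \
            .output()

        # join key given as a lambda on one side, as a position on the other
        d1.join(d2).where(lambda x: x[0]).equal_to(0) \
            .project_first(0, 1).project_second(1) \
            .output()

        env.set_parallelism(1)
        env.execute(local=True)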
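
The grouping rewrite in UnsortedGrouping._finalize wraps every record into a (key tuple, record) pair by patching a map in front of the grouped operator and then renumbering the grouping keys to positions inside that key tuple, so downstream operators always see index-based keys regardless of whether the user supplied positions or a selector. A standalone imitation of that extractor logic follows; the record and keys are invented for illustration, and the callable() check stands in for the FunctionType/KeySelectorFunction distinction made in the real code.

    # Standalone imitation of the key-extraction step built by UnsortedGrouping._finalize.
    def make_extractor(keys):
        # a key selector yields a 1-tuple key, field positions yield a tuple of fields
        if callable(keys[0]):
            return lambda x: ((keys[0](x),), x)
        return lambda x: (tuple([x[k] for k in keys]), x)

    record = (1, 0.5, "hello", True)

    by_position = make_extractor((2,))
    by_selector = make_extractor((lambda x: x[2],))

    print(by_position(record))  # (('hello',), (1, 0.5, 'hello', True))
    print(by_selector(record))  # (('hello',), (1, 0.5, 'hello', True))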
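
Project joins and crosses no longer ship a projection list to the Java side; Projector.project_first/project_second now accumulate (side, field) pairs and apply them as a plain map on the Python side. A small standalone sketch of that mapper, using the "Project Join" expectation from test_main2.py as the example input and output:

    # Imitation of the mapper assembled by Projector from (side, field) pairs.
    projections = []
    projections += [(0, f) for f in (0, 3)]   # project_first(0, 3)
    projections += [(1, f) for f in (0,)]     # project_second(0)

    mapper = lambda x: tuple([x[side][index] for side, index in projections])

    joined_pair = ((1, 0.5, "hello", True), ("hello",))
    print(mapper(joined_pair))  # (1, True, 'hello')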