From 211d0bdbf956f6bf17ea40f72d6aca9cd4fb2fad Mon Sep 17 00:00:00 2001 From: Stephan Ewen <sewen@apache.org> Date: Fri, 10 Apr 2015 15:02:26 +0200 Subject: [PATCH] [FLINK-1862] [apis] Add support for non-serializable types for collect() by switching from Java serialization to Flink serialization --- .../SerializedListAccumulator.java | 49 +++++++++---------- .../org/apache/flink/api/java/DataSet.java | 18 ++----- .../java/org/apache/flink/api/java/Utils.java | 12 +++-- .../org/apache/flink/api/scala/DataSet.scala | 19 +++---- .../test/classloading/jar/KMeansForTest.java | 3 +- 5 files changed, 45 insertions(+), 56 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/common/accumulators/SerializedListAccumulator.java b/flink-core/src/main/java/org/apache/flink/api/common/accumulators/SerializedListAccumulator.java index 4ab339b1a6c29..65a8c39bcd597 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/accumulators/SerializedListAccumulator.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/accumulators/SerializedListAccumulator.java @@ -18,8 +18,14 @@ package org.apache.flink.api.common.accumulators; -import org.apache.flink.util.InstantiationUtil; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.InputViewDataInputStreamWrapper; +import org.apache.flink.core.memory.OutputViewDataOutputStreamWrapper; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -37,19 +43,24 @@ public class SerializedListAccumulator<T> implements Accumulator<T, ArrayList<by private static final long serialVersionUID = 1L; private ArrayList<byte[]> localValue = new ArrayList<byte[]>(); + @Override public void add(T value) { - if (value == null) { - throw new NullPointerException("Value to accumulate must nor be null"); - } - + throw new UnsupportedOperationException(); + } + + public void add(T value, TypeSerializer<T> serializer) throws IOException { try { - byte[] byteArray = InstantiationUtil.serializeObject(value); - localValue.add(byteArray); + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + OutputViewDataOutputStreamWrapper out = + new OutputViewDataOutputStreamWrapper(new DataOutputStream(outStream)); + + serializer.serialize(value, out); + localValue.add(outStream.toByteArray()); } catch (IOException e) { - throw new RuntimeException("Serialization of accumulated value failed", e); + throw new IOException("Failed to serialize value '" + value + '\'', e); } } @@ -58,21 +69,6 @@ public ArrayList<byte[]> getLocalValue() { return localValue; } - public ArrayList<T> deserializeLocalValue(ClassLoader classLoader) { - try { - ArrayList<T> arrList = new ArrayList<T>(localValue.size()); - for (byte[] byteArr : localValue) { - @SuppressWarnings("unchecked") - T item = (T) InstantiationUtil.deserializeObject(byteArr, classLoader); - arrList.add(item); - } - return arrList; - } - catch (Exception e) { - throw new RuntimeException("Cannot deserialize accumulator list element", e); - } - } - @Override public void resetLocal() { localValue.clear(); @@ -91,12 +87,15 @@ public SerializedListAccumulator<T> clone() { } @SuppressWarnings("unchecked") - public static <T> List<T> deserializeList(ArrayList<byte[]> data, ClassLoader loader) + public static <T> List<T> deserializeList(ArrayList<byte[]> data, TypeSerializer<T> serializer) throws IOException, ClassNotFoundException { List<T> result = new ArrayList<T>(data.size()); for (byte[] bytes : data) { - result.add((T) InstantiationUtil.deserializeObject(bytes, loader)); + ByteArrayInputStream inStream = new ByteArrayInputStream(bytes); + InputViewDataInputStreamWrapper in = new InputViewDataInputStreamWrapper(new DataInputStream(inStream)); + T val = serializer.deserialize(in); + result.add(val); } return result; } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java index 7cee323bef4b6..a6a0af8877fdb 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java @@ -42,6 +42,7 @@ import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint; import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint; import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.aggregation.Aggregations; import org.apache.flink.api.java.functions.FirstReducer; import org.apache.flink.api.java.functions.FormattingMapper; @@ -411,24 +412,15 @@ public long count() throws Exception { * @see org.apache.flink.api.java.Utils.CollectHelper */ public List<T> collect() throws Exception { - // validate that our type is actually serializable - Class<?> typeClass = getType().getTypeClass(); - ClassLoader cl = typeClass.getClassLoader() == null ? ClassLoader.getSystemClassLoader() - : typeClass.getClassLoader(); - - if (!java.io.Serializable.class.isAssignableFrom(typeClass)) { - throw new UnsupportedOperationException("collect() can only be used with serializable data types. " - + "The DataSet type '" + typeClass.getName() + "' does not implement java.io.Serializable."); - } - final String id = new AbstractID().toString(); - - this.flatMap(new Utils.CollectHelper<T>(id)).output(new DiscardingOutputFormat<T>()); + final TypeSerializer<T> serializer = getType().createSerializer(getExecutionEnvironment().getConfig()); + + this.flatMap(new Utils.CollectHelper<T>(id, serializer)).output(new DiscardingOutputFormat<T>()); JobExecutionResult res = getExecutionEnvironment().execute(); ArrayList<byte[]> accResult = res.getAccumulatorResult(id); try { - return SerializedListAccumulator.deserializeList(accResult, cl); + return SerializedListAccumulator.deserializeList(accResult, serializer); } catch (ClassNotFoundException e) { throw new RuntimeException("Cannot find type class of collected data type.", e); diff --git a/flink-java/src/main/java/org/apache/flink/api/java/Utils.java b/flink-java/src/main/java/org/apache/flink/api/java/Utils.java index 5351484d72f67..38b24a2a84092 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/Utils.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/Utils.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.accumulators.SerializedListAccumulator; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.typeutils.GenericTypeInfo; import java.lang.reflect.Field; @@ -97,21 +98,24 @@ public static class CollectHelper<T> extends RichFlatMapFunction<T, T> { private static final long serialVersionUID = 1L; private final String id; - private final SerializedListAccumulator<T> accumulator; + private final TypeSerializer<T> serializer; + + private SerializedListAccumulator<T> accumulator; - public CollectHelper(String id) { + public CollectHelper(String id, TypeSerializer<T> serializer) { this.id = id; - this.accumulator = new SerializedListAccumulator<T>(); + this.serializer = serializer; } @Override public void open(Configuration parameters) throws Exception { + this.accumulator = new SerializedListAccumulator<T>(); getRuntimeContext().addAccumulator(id, accumulator); } @Override public void flatMap(T value, Collector<T> out) throws Exception { - accumulator.add(value); + accumulator.add(value, serializer); } } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala index 48a5285d43848..3b80a23f60416 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala @@ -26,6 +26,7 @@ import org.apache.flink.api.common.functions._ import org.apache.flink.api.common.io.{FileOutputFormat, OutputFormat} import org.apache.flink.api.common.operators.Order import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod +import org.apache.flink.api.common.typeutils.TypeSerializer import org.apache.flink.api.java.Utils.CountHelper import org.apache.flink.api.java.aggregation.Aggregations import org.apache.flink.api.java.functions.{FirstReducer, KeySelector} @@ -537,24 +538,18 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { */ @throws(classOf[Exception]) def collect(): Seq[T] = { - val typeClass: Class[_] = getType().getTypeClass() - val cl: ClassLoader = if (typeClass.getClassLoader == null) ClassLoader.getSystemClassLoader - else typeClass.getClassLoader - - if (typeClass != null && !classOf[java.io.Serializable].isAssignableFrom(typeClass)) { - throw new UnsupportedOperationException( - "collect() can only be used with serializable data types. " + - "The DataSet type '" + typeClass.getName + "' does not implement java.io.Serializable.") - } - val id = new AbstractID().toString - javaSet.flatMap(new Utils.CollectHelper[T](id)).output(new DiscardingOutputFormat[T]) + val serializer = getType().createSerializer(getExecutionEnvironment.getConfig) + + javaSet.flatMap(new Utils.CollectHelper[T](id, serializer)) + .output(new DiscardingOutputFormat[T]) + val res = getExecutionEnvironment.execute() val accResult: java.util.ArrayList[Array[Byte]] = res.getAccumulatorResult(id) try { - SerializedListAccumulator.deserializeList(accResult, cl).asScala + SerializedListAccumulator.deserializeList(accResult, serializer).asScala } catch { case e: ClassNotFoundException => { diff --git a/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/KMeansForTest.java b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/KMeansForTest.java index 083b2bcad0ca2..794efbd36ea35 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/KMeansForTest.java +++ b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/KMeansForTest.java @@ -28,7 +28,6 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; -import java.io.Serializable; import java.util.Collection; /** @@ -107,7 +106,7 @@ public static void main(String[] args) throws Exception { /** * A simple two-dimensional point. */ - public static class Point implements Serializable { + public static class Point { public double x, y;