[SPARK-3045] Make Serializer interface Java friendly
Author: Reynold Xin <[email protected]>

Closes apache#1948 from rxin/kryo and squashes the following commits:

a3a80d8 [Reynold Xin] [SPARK-3046] use executor's class loader as the default serializer classloader
3d13277 [Reynold Xin] Reverted that in TestJavaSerializerImpl too.
196f3dc [Reynold Xin] Ok one more commit to revert the classloader change.
c49b50c [Reynold Xin] Removed JavaSerializer change.
afbf37d [Reynold Xin] Moved the test case also.
a2e693e [Reynold Xin] Removed the Kryo bug fix from this pull request.
c81bd6c [Reynold Xin] Use defaultClassLoader when executing user specified custom registrator.
68f261e [Reynold Xin] Added license check excludes.
0c28179 [Reynold Xin] [SPARK-3045] Make Serializer interface Java friendly [SPARK-3046] Set executor's class loader as the default serializer class loader
rxin committed Aug 16, 2014
1 parent c9da466 commit a83c772
Showing 8 changed files with 193 additions and 73 deletions.
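
The substance of SPARK-3045 is visible in Serializer.scala below: Serializer, SerializerInstance, SerializationStream, and DeserializationStream change from traits to abstract classes. In the Scala versions Spark used at the time, a trait with concrete members compiles to a JVM interface plus a synthetic TraitName$class holder for the method bodies, so a Java subclass has to reimplement or hand-delegate every concrete member; an abstract class maps onto an ordinary Java class that can simply be extended. A minimal sketch of the pattern, with illustrative names that are not part of the patch:

    // Sketch only (illustrative names, not from the patch).
    // Trait-style API: the concrete member's body lives in a synthetic
    // TraitApi$class, so Java implementors of the interface must
    // reimplement or hand-delegate it.
    trait TraitApi {
      def newInstance(): AnyRef
      def defaultName: String = "default" // concrete member
    }

    // Abstract-class-style API: an ordinary JVM class; Java code can write
    // "class Mine extends ClassApi { ... }" and override only what it needs.
    abstract class ClassApi {
      def newInstance(): AnyRef
      def defaultName: String = "default"
    }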
core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -63,34 +63,35 @@ extends DeserializationStream {
   def close() { objIn.close() }
 }
 
+
 private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
   extends SerializerInstance {
 
-  def serialize[T: ClassTag](t: T): ByteBuffer = {
+  override def serialize[T: ClassTag](t: T): ByteBuffer = {
     val bos = new ByteArrayOutputStream()
     val out = serializeStream(bos)
     out.writeObject(t)
     out.close()
     ByteBuffer.wrap(bos.toByteArray)
   }
 
-  def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
     val bis = new ByteBufferInputStream(bytes)
     val in = deserializeStream(bis)
-    in.readObject().asInstanceOf[T]
+    in.readObject()
   }
 
-  def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
     val bis = new ByteBufferInputStream(bytes)
     val in = deserializeStream(bis, loader)
-    in.readObject().asInstanceOf[T]
+    in.readObject()
   }
 
-  def serializeStream(s: OutputStream): SerializationStream = {
+  override def serializeStream(s: OutputStream): SerializationStream = {
     new JavaSerializationStream(s, counterReset)
   }
 
-  def deserializeStream(s: InputStream): DeserializationStream = {
+  override def deserializeStream(s: InputStream): DeserializationStream = {
     new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
   }
core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -91,7 +91,7 @@ class KryoSerializer(conf: SparkConf)
         Thread.currentThread.setContextClassLoader(classLoader)
         reg.registerClasses(kryo)
       } catch {
-        case e: Exception =>
+        case e: Exception =>
           throw new SparkException(s"Failed to invoke $regCls", e)
       } finally {
         Thread.currentThread.setContextClassLoader(oldClassLoader)
@@ -106,7 +106,7 @@ class KryoSerializer(conf: SparkConf)
     kryo
   }
 
-  def newInstance(): SerializerInstance = {
+  override def newInstance(): SerializerInstance = {
     new KryoSerializerInstance(this)
   }
 }
@@ -115,20 +115,20 @@ private[spark]
 class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
   val output = new KryoOutput(outStream)
 
-  def writeObject[T: ClassTag](t: T): SerializationStream = {
+  override def writeObject[T: ClassTag](t: T): SerializationStream = {
     kryo.writeClassAndObject(output, t)
     this
   }
 
-  def flush() { output.flush() }
-  def close() { output.close() }
+  override def flush() { output.flush() }
+  override def close() { output.close() }
 }
 
 private[spark]
 class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
-  val input = new KryoInput(inStream)
+  private val input = new KryoInput(inStream)
 
-  def readObject[T: ClassTag](): T = {
+  override def readObject[T: ClassTag](): T = {
     try {
       kryo.readClassAndObject(input).asInstanceOf[T]
     } catch {
@@ -138,31 +138,31 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
     }
   }
 
-  def close() {
+  override def close() {
     // Kryo's Input automatically closes the input stream it is using.
     input.close()
   }
 }
 
 private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
-  val kryo = ks.newKryo()
+  private val kryo = ks.newKryo()
 
   // Make these lazy vals to avoid creating a buffer unless we use them
-  lazy val output = ks.newKryoOutput()
-  lazy val input = new KryoInput()
+  private lazy val output = ks.newKryoOutput()
+  private lazy val input = new KryoInput()
 
-  def serialize[T: ClassTag](t: T): ByteBuffer = {
+  override def serialize[T: ClassTag](t: T): ByteBuffer = {
     output.clear()
     kryo.writeClassAndObject(output, t)
     ByteBuffer.wrap(output.toBytes)
   }
 
-  def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
     input.setBuffer(bytes.array)
     kryo.readClassAndObject(input).asInstanceOf[T]
   }
 
-  def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
     val oldClassLoader = kryo.getClassLoader
     kryo.setClassLoader(loader)
     input.setBuffer(bytes.array)
@@ -171,11 +171,11 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
     obj
   }
 
-  def serializeStream(s: OutputStream): SerializationStream = {
+  override def serializeStream(s: OutputStream): SerializationStream = {
     new KryoSerializationStream(kryo, s)
   }
 
-  def deserializeStream(s: InputStream): DeserializationStream = {
+  override def deserializeStream(s: InputStream): DeserializationStream = {
     new KryoDeserializationStream(kryo, s)
   }
 }
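
A side effect of this hunk worth noting: kryo, output, and input on KryoSerializerInstance become private, underlining that an instance carries mutable buffers. The scaladoc in Serializer.scala below calls a SerializerInstance "for use by one thread at a time", so the serializer object is the shared factory and each thread takes its own instance. A hedged usage sketch:

    import java.nio.ByteBuffer
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer

    // Sketch only: KryoSerializer is the shareable factory; the instance
    // holds per-thread mutable state (the kryo object and its buffers).
    val ser = new KryoSerializer(new SparkConf(false))
    val instance = ser.newInstance()              // one of these per thread
    val bytes: ByteBuffer = instance.serialize("hello")
    val back: String = instance.deserialize[String](bytes)
    assert(back == "hello")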
core/src/main/scala/org/apache/spark/serializer/Serializer.scala (25 changes: 6 additions & 19 deletions)
@@ -43,7 +43,7 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
  * They are intended to be used to serialize/de-serialize data within a single Spark application.
  */
 @DeveloperApi
-trait Serializer {
+abstract class Serializer {
 
   /**
    * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should
@@ -61,10 +61,12 @@ trait Serializer {
     this
   }
 
+  /** Creates a new [[SerializerInstance]]. */
   def newInstance(): SerializerInstance
 }
 
 
+@DeveloperApi
 object Serializer {
   def getSerializer(serializer: Serializer): Serializer = {
     if (serializer == null) SparkEnv.get.serializer else serializer
@@ -81,7 +83,7 @@ object Serializer {
  * An instance of a serializer, for use by one thread at a time.
  */
 @DeveloperApi
-trait SerializerInstance {
+abstract class SerializerInstance {
   def serialize[T: ClassTag](t: T): ByteBuffer
 
   def deserialize[T: ClassTag](bytes: ByteBuffer): T
@@ -91,29 +93,14 @@ trait SerializerInstance {
   def serializeStream(s: OutputStream): SerializationStream
 
   def deserializeStream(s: InputStream): DeserializationStream
-
-  def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = {
-    // Default implementation uses serializeStream
-    val stream = new ByteArrayOutputStream()
-    serializeStream(stream).writeAll(iterator)
-    val buffer = ByteBuffer.wrap(stream.toByteArray)
-    buffer.flip()
-    buffer
-  }
-
-  def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
-    // Default implementation uses deserializeStream
-    buffer.rewind()
-    deserializeStream(new ByteBufferInputStream(buffer)).asIterator
-  }
 }
 
 /**
  * :: DeveloperApi ::
  * A stream for writing serialized objects.
  */
 @DeveloperApi
-trait SerializationStream {
+abstract class SerializationStream {
   def writeObject[T: ClassTag](t: T): SerializationStream
   def flush(): Unit
   def close(): Unit
@@ -132,7 +119,7 @@ trait SerializationStream {
  * A stream for reading serialized objects.
  */
 @DeveloperApi
-trait DeserializationStream {
+abstract class DeserializationStream {
   def readObject[T: ClassTag](): T
   def close(): Unit
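
The setDefaultClassLoader hook kept above (the method returning this) is what the SPARK-3046 half of the commit message targets: whoever creates the serializer can install a class loader, typically the executor's, so deserialization can resolve user classes. A hedged sketch of the call pattern; userCodeClassLoader is a stand-in, not the literal executor code:

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.{KryoSerializer, Serializer}

    // Sketch only: userCodeClassLoader stands in for whatever loader can
    // see user classes (on an executor, the loader that fetched user jars).
    val userCodeClassLoader: ClassLoader = Thread.currentThread.getContextClassLoader
    val serializer: Serializer = new KryoSerializer(new SparkConf(false))
    val instance = serializer
      .setDefaultClassLoader(userCodeClassLoader)  // returns `this`, so it chains
      .newInstance()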
core/src/main/scala/org/apache/spark/serializer/package-info.java (whitespace-only change as rendered)
@@ -18,4 +18,4 @@
 /**
  * Pluggable serializers for RDD and shuffle data.
  */
-package org.apache.spark.serializer;
+package org.apache.spark.serializer;
core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java (new file)
@@ -0,0 +1,95 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.serializer;

import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;

import scala.Option;
import scala.reflect.ClassTag;


/**
 * A simple Serializer implementation to make sure the API is Java-friendly.
 */
class TestJavaSerializerImpl extends Serializer {

  @Override
  public SerializerInstance newInstance() {
    return null;
  }

  static class SerializerInstanceImpl extends SerializerInstance {
    @Override
    public <T> ByteBuffer serialize(T t, ClassTag<T> evidence$1) {
      return null;
    }

    @Override
    public <T> T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag<T> evidence$1) {
      return null;
    }

    @Override
    public <T> T deserialize(ByteBuffer bytes, ClassTag<T> evidence$1) {
      return null;
    }

    @Override
    public SerializationStream serializeStream(OutputStream s) {
      return null;
    }

    @Override
    public DeserializationStream deserializeStream(InputStream s) {
      return null;
    }
  }

  static class SerializationStreamImpl extends SerializationStream {

    @Override
    public <T> SerializationStream writeObject(T t, ClassTag<T> evidence$1) {
      return null;
    }

    @Override
    public void flush() {

    }

    @Override
    public void close() {

    }
  }

  static class DeserializationStreamImpl extends DeserializationStream {

    @Override
    public <T> T readObject(ClassTag<T> evidence$1) {
      return null;
    }

    @Override
    public void close() {

    }
  }
}
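
The ClassTag<T> evidence$1 parameters above are the visible cost of Scala context bounds: [T: ClassTag] desugars to a trailing implicit parameter list, and Java sees the implicit as an ordinary argument. The equivalence, sketched in Scala with illustrative names:

    import java.nio.ByteBuffer
    import scala.reflect.ClassTag

    // Sketch: both declarations compile to the same JVM signature, which
    // Java calls as  serialize(t, evidence$1).
    abstract class ContextBoundDemo {
      def serialize[T: ClassTag](t: T): ByteBuffer            // as in Serializer.scala
      def serializeDesugared[T](t: T)(implicit ev: ClassTag[T]): ByteBuffer
    }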
core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala (new file)
@@ -0,0 +1,52 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.serializer

import org.scalatest.FunSuite

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.LocalSparkContext
import org.apache.spark.SparkException


class KryoSerializerResizableOutputSuite extends FunSuite {

  // trial and error showed this will not serialize with 1mb buffer
  val x = (1 to 400000).toArray

  test("kryo without resizable output buffer should fail on large array") {
    val conf = new SparkConf(false)
    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    conf.set("spark.kryoserializer.buffer.mb", "1")
    conf.set("spark.kryoserializer.buffer.max.mb", "1")
    val sc = new SparkContext("local", "test", conf)
    intercept[SparkException](sc.parallelize(x).collect())
    LocalSparkContext.stop(sc)
  }

  test("kryo with resizable output buffer should succeed on large array") {
    val conf = new SparkConf(false)
    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    conf.set("spark.kryoserializer.buffer.mb", "1")
    conf.set("spark.kryoserializer.buffer.max.mb", "2")
    val sc = new SparkContext("local", "test", conf)
    assert(sc.parallelize(x).collect() === x)
    LocalSparkContext.stop(sc)
  }
}
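
Both tests turn on the same pair of knobs: spark.kryoserializer.buffer.mb sets the initial size of Kryo's output buffer, and spark.kryoserializer.buffer.max.mb caps how far it may grow, so serialization fails once a single object outgrows the cap. A configuration sketch using the same keys as the tests; the 64 here is an illustrative ceiling, not a value from the patch:

    import org.apache.spark.SparkConf

    // Sketch: start with a 1 MB buffer but let Kryo grow it up to 64 MB
    // (an assumed ceiling for illustration) before serialization fails.
    val conf = new SparkConf(false)
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryoserializer.buffer.mb", "1")
      .set("spark.kryoserializer.buffer.max.mb", "64")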