Skip to content

Commit

Permalink
CORDA-973 Compression support for serialization (corda#2473)
Browse files Browse the repository at this point in the history
* Serialization magic is now 7 bytes
* Introduce encoding property and whitelist
  • Loading branch information
andr3ej authored Feb 23, 2018
1 parent 2af0fee commit c8672d3
Show file tree
Hide file tree
Showing 26 changed files with 436 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .ci/api-current.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2817,7 +2817,7 @@ public final class net.corda.core.serialization.ObjectWithCompatibleContext exte
public final class net.corda.core.serialization.SerializationAPIKt extends java.lang.Object
@org.jetbrains.annotations.NotNull public static final net.corda.core.serialization.SerializedBytes serialize(Object, net.corda.core.serialization.SerializationFactory, net.corda.core.serialization.SerializationContext)
##
public interface net.corda.core.serialization.SerializationContext
@net.corda.core.DoNotImplement public interface net.corda.core.serialization.SerializationContext
@org.jetbrains.annotations.NotNull public abstract ClassLoader getDeserializationClassLoader()
public abstract boolean getObjectReferencesEnabled()
@org.jetbrains.annotations.NotNull public abstract net.corda.core.utilities.ByteSequence getPreferredSerializationVersion()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package net.corda.core.serialization

import net.corda.core.DoNotImplement
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256
import net.corda.core.serialization.internal.effectiveSerializationEnv
Expand Down Expand Up @@ -99,14 +100,22 @@ abstract class SerializationFactory {
}
}
typealias SerializationMagic = ByteSequence
/**
 * Marker type for an encoding (typically compression) that can be applied to serialized data.
 * It declares no members of its own.
 */
@DoNotImplement
interface SerializationEncoding

/**
* Parameters to serialization and deserialization.
*/
@DoNotImplement
interface SerializationContext {
/**
* When serializing, use the format this header sequence represents.
*/
val preferredSerializationVersion: SerializationMagic
/**
* If non-null, apply this encoding (typically compression) when serializing.
*/
val encoding: SerializationEncoding?
/**
* The class loader to use for deserialization.
*/
Expand All @@ -115,6 +124,10 @@ interface SerializationContext {
* A whitelist that determines (mostly for security purposes) which classes can be serialized and deserialized.
*/
val whitelist: ClassWhitelist
/**
* A whitelist that determines (mostly for security purposes) whether a particular encoding may be used when deserializing.
*/
val encodingWhitelist: EncodingWhitelist
/**
* A map of any additional properties specific to the particular use case.
*/
Expand Down Expand Up @@ -161,6 +174,11 @@ interface SerializationContext {
*/
fun withPreferredSerializationVersion(magic: SerializationMagic): SerializationContext

/**
* A shallow copy of this context but with the given (possibly null) encoding.
*/
fun withEncoding(encoding: SerializationEncoding?): SerializationContext

/**
* The use case that we are serializing for, since it influences the implementations chosen.
*/
Expand Down Expand Up @@ -232,3 +250,8 @@ class SerializedBytes<T : Any>(bytes: ByteArray) : OpaqueBytes(bytes) {
/**
 * A whitelist of classes.
 */
interface ClassWhitelist {
    /**
     * Reports whether [type] has been listed by this whitelist.
     */
    fun hasListed(type: Class<*>): Boolean
}

/**
 * Decides whether a given [SerializationEncoding] is acceptable.
 */
@DoNotImplement
interface EncodingWhitelist {
    /**
     * Reports whether [encoding] is accepted by this whitelist.
     */
    fun acceptEncoding(encoding: SerializationEncoding): Boolean
}
3 changes: 3 additions & 0 deletions node-api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ dependencies {
// For AMQP serialisation.
compile "org.apache.qpid:proton-j:0.21.0"

// Pure-Java Snappy compression
compile 'org.iq80.snappy:snappy:0.4'

// Unit testing helpers.
testCompile "junit:junit:$junit_version"
testCompile "org.assertj:assertj-core:$assertj_version"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ val KRYO_RPC_CLIENT_CONTEXT = SerializationContextImpl(kryoMagic,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.RPCClient)
SerializationContext.UseCase.RPCClient,
null)
val AMQP_RPC_CLIENT_CONTEXT = SerializationContextImpl(amqpMagic,
SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.RPCClient)
SerializationContext.UseCase.RPCClient,
null)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package net.corda.nodeapi.internal.serialization

import java.io.EOFException
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer

/**
 * A small ordinal that is encoded as exactly one byte.
 *
 * Construction fails with [IllegalArgumentException] unless the ordinal lies in `0..127`.
 */
class OrdinalBits(private val ordinal: Int) {
    init {
        require(ordinal >= 0) { "The ordinal must be non-negative." }
        require(ordinal < 128) { "Consider implementing a varint encoding." }
    }

    /**
     * Implemented by types that can emit their [bits] as a single byte.
     */
    interface OrdinalWriter {
        /** The ordinal that this writer emits. */
        val bits: OrdinalBits

        /** How many bytes [writeTo] and [putTo] produce; always 1. */
        val encodedSize: Int get() = 1

        /** Writes the ordinal to [stream] as one byte. */
        fun writeTo(stream: OutputStream) {
            stream.write(bits.ordinal)
        }

        /** Puts the ordinal into [buffer] as one byte and returns the result of that `put` call. */
        fun putTo(buffer: ByteBuffer): ByteBuffer {
            val afterPut = buffer.put(bits.ordinal.toByte())
            return afterPut!!
        }
    }
}

/**
 * Decodes a single-byte ordinal from a stream into the element of [values] at that index.
 *
 * @param values the candidate results, indexed by ordinal (for example an enum's `values()`).
 */
class OrdinalReader<out E : Any>(private val values: Array<E>) {
    // Use the array's component type rather than the first element's runtime class: this works for an
    // empty array, and for enum entries with bodies it names the enum itself instead of the entry subclass.
    private val enumName = values.javaClass.componentType.simpleName

    /**
     * Reads one byte from [stream] and returns the element whose index equals that byte.
     *
     * @throws EOFException if [stream] has no more bytes.
     * @throws NoSuchElementException if the byte read is not a valid index into the values.
     */
    fun readFrom(stream: InputStream): E {
        val ordinal = stream.read()
        if (ordinal == -1) throw EOFException("Expected a $enumName ordinal.")
        return values.getOrNull(ordinal) ?: throw NoSuchElementException("No $enumName with ordinal: $ordinal")
    }
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,58 @@
package net.corda.nodeapi.internal.serialization

import net.corda.core.internal.VisibleForTesting
import net.corda.core.serialization.SerializationEncoding
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.nodeapi.internal.serialization.OrdinalBits.OrdinalWriter
import org.iq80.snappy.SnappyFramedInputStream
import org.iq80.snappy.SnappyFramedOutputStream
import java.io.OutputStream
import java.io.InputStream
import java.nio.ByteBuffer
import java.util.zip.DeflaterOutputStream
import java.util.zip.InflaterInputStream

/**
 * A fixed sequence of header bytes that can be matched against, and stripped from, the start of a [ByteSequence].
 */
class CordaSerializationMagic(bytes: ByteArray) : OpaqueBytes(bytes) {
    // Buffer over this magic's own bytes, built once and reused for every comparison.
    private val bufferView = slice()

    /**
     * Compares the leading [size] bytes of [data] with this magic.
     *
     * @return the slice of [data] starting at offset [size] when the leading bytes equal the magic, otherwise null.
     */
    fun consume(data: ByteSequence): ByteBuffer? {
        val leading = data.slice(end = size)
        if (leading != bufferView) return null
        return data.slice(size)
    }
}

/**
 * Identifies what comes next in a serialized stream. Each entry is an [OrdinalWriter] whose [bits] hold its ordinal.
 */
enum class SectionId : OrdinalWriter {
    /** The serialized payload comes next; anything after it is discarded, because legacy data may carry trailing garbage. */
    DATA_AND_STOP,
    /** Behaves exactly like [DATA_AND_STOP]; kept for its historical use with Kryo and not to be used in new code. */
    ALT_DATA_AND_STOP,
    /** A [CordaSerializationEncoding] ordinal comes next, naming the encoding with which to decode the rest of the stream. */
    ENCODING;

    override val bits: OrdinalBits = OrdinalBits(ordinal)

    companion object {
        /** Turns an ordinal read from a stream back into the corresponding [SectionId]. */
        val reader: OrdinalReader<SectionId> = OrdinalReader(values())
    }
}

/**
 * The concrete [SerializationEncoding]s available here. Each entry is an [OrdinalWriter] whose [bits] hold its
 * ordinal, and each knows how to wrap a stream for encoding and for decoding.
 */
enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter {
    /** Backed by the `java.util.zip` deflater/inflater streams. */
    DEFLATE {
        override fun wrap(stream: OutputStream): DeflaterOutputStream = DeflaterOutputStream(stream)
        override fun wrap(stream: InputStream): InflaterInputStream = InflaterInputStream(stream)
    },
    /** Backed by the iq80 Snappy framed streams. */
    SNAPPY {
        override fun wrap(stream: OutputStream): SnappyFramedOutputStream = SnappyFramedOutputStream(stream)
        // NOTE(review): the `false` argument presumably turns off checksum verification - confirm against the iq80 snappy API.
        override fun wrap(stream: InputStream): SnappyFramedInputStream = SnappyFramedInputStream(stream, false)
    };

    /** Returns an output stream layered on top of [stream] that applies this encoding. */
    abstract fun wrap(stream: OutputStream): OutputStream

    /** Returns an input stream layered on top of [stream] that reverses this encoding. */
    abstract fun wrap(stream: InputStream): InputStream

    override val bits: OrdinalBits = OrdinalBits(ordinal)

    companion object {
        /** Turns an ordinal read from a stream back into the corresponding encoding. */
        val reader: OrdinalReader<CordaSerializationEncoding> = OrdinalReader(values())
    }
}

/** Format string for the "encoding rejected" error message; `%s` is replaced with the offending encoding. */
@VisibleForTesting
internal val encodingNotPermittedFormat: String = "Encoding not permitted: %s"
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@ import java.util.concurrent.ExecutionException

val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled"

data class SerializationContextImpl(override val preferredSerializationVersion: SerializationMagic,
override val deserializationClassLoader: ClassLoader,
override val whitelist: ClassWhitelist,
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val useCase: SerializationContext.UseCase) : SerializationContext {
/**
 * An [EncodingWhitelist] that accepts nothing: every encoding is rejected.
 */
internal object NullEncodingWhitelist : EncodingWhitelist {
    override fun acceptEncoding(encoding: SerializationEncoding): Boolean {
        return false
    }
}

data class SerializationContextImpl @JvmOverloads constructor(override val preferredSerializationVersion: SerializationMagic,
override val deserializationClassLoader: ClassLoader,
override val whitelist: ClassWhitelist,
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val useCase: SerializationContext.UseCase,
override val encoding: SerializationEncoding?,
override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : SerializationContext {
private val cache: Cache<List<SecureHash>, AttachmentsClassLoader> = CacheBuilder.newBuilder().weakValues().maximumSize(1024).build()

/**
Expand Down Expand Up @@ -70,6 +75,7 @@ data class SerializationContextImpl(override val preferredSerializationVersion:
}

override fun withPreferredSerializationVersion(magic: SerializationMagic) = copy(preferredSerializationVersion = magic)
override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding)
}

open class SerializationFactoryImpl : SerializationFactory() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,26 @@ val KRYO_RPC_SERVER_CONTEXT = SerializationContextImpl(kryoMagic,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.RPCServer)
SerializationContext.UseCase.RPCServer,
null)
val KRYO_STORAGE_CONTEXT = SerializationContextImpl(kryoMagic,
SerializationDefaults.javaClass.classLoader,
AllButBlacklisted,
emptyMap(),
true,
SerializationContext.UseCase.Storage)
SerializationContext.UseCase.Storage,
null)
val AMQP_STORAGE_CONTEXT = SerializationContextImpl(amqpMagic,
SerializationDefaults.javaClass.classLoader,
AllButBlacklisted,
emptyMap(),
true,
SerializationContext.UseCase.Storage)
SerializationContext.UseCase.Storage,
null)
val AMQP_RPC_SERVER_CONTEXT = SerializationContextImpl(amqpMagic,
SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.RPCServer)
SerializationContext.UseCase.RPCServer,
null)
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@ val KRYO_P2P_CONTEXT = SerializationContextImpl(kryoMagic,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.P2P)
SerializationContext.UseCase.P2P,
null)
val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(kryoMagic,
SerializationDefaults.javaClass.classLoader,
QuasarWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Checkpoint)
SerializationContext.UseCase.Checkpoint,
null)
val AMQP_P2P_CONTEXT = SerializationContextImpl(amqpMagic,
SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.P2P)


SerializationContext.UseCase.P2P,
null)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package net.corda.nodeapi.internal.serialization.amqp

import com.esotericsoftware.kryo.io.ByteBufferInputStream
import net.corda.nodeapi.internal.serialization.kryo.ByteBufferOutputStream
import net.corda.nodeapi.internal.serialization.kryo.serializeOutputStreamPool
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer

/**
 * Exposes the content of this stream as a [ByteBuffer].
 *
 * A [ByteBufferInputStream] hands back its underlying buffer directly. Any other stream is copied into an
 * output stream borrowed from [serializeOutputStreamPool], and the resulting bytes are wrapped in a new buffer.
 */
fun InputStream.asByteBuffer(): ByteBuffer {
    if (this is ByteBufferInputStream) {
        return byteBuffer // BBIS has no other state, so this is perfectly safe.
    }
    val copiedBytes = serializeOutputStreamPool.run { pooled ->
        copyTo(pooled)
        pooled.toByteArray()
    }
    return ByteBuffer.wrap(copiedBytes)
}

/**
 * Runs [task] against a [ByteBuffer] view obtained from a [ByteBufferOutputStream] and returns the task's result.
 *
 * If this stream is itself a [ByteBufferOutputStream], the call is forwarded to it directly. Otherwise a stream is
 * borrowed from [serializeOutputStreamPool], the task runs against that one, and its content is then copied into
 * this stream.
 *
 * NOTE(review): both inner calls presumably resolve to a member `alsoAsByteBuffer` on [ByteBufferOutputStream]
 * rather than to this extension - confirm against that class.
 *
 * @param remaining passed through unchanged to the [ByteBufferOutputStream] call.
 * @param task the work to perform on the buffer.
 */
fun <T> OutputStream.alsoAsByteBuffer(remaining: Int, task: (ByteBuffer) -> T): T {
    if (this is ByteBufferOutputStream) {
        return alsoAsByteBuffer(remaining, task)
    }
    return serializeOutputStreamPool.run { pooled ->
        val outcome = pooled.alsoAsByteBuffer(remaining, task)
        pooled.copyTo(this)
        outcome
    }
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
package net.corda.nodeapi.internal.serialization.amqp

import com.esotericsoftware.kryo.io.ByteBufferInputStream
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.getStackTraceAsString
import net.corda.core.serialization.EncodingWhitelist
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.internal.serialization.CordaSerializationEncoding
import net.corda.nodeapi.internal.serialization.NullEncodingWhitelist
import net.corda.nodeapi.internal.serialization.SectionId
import net.corda.nodeapi.internal.serialization.encodingNotPermittedFormat
import org.apache.qpid.proton.amqp.Binary
import org.apache.qpid.proton.amqp.DescribedType
import org.apache.qpid.proton.amqp.UnsignedByte
import org.apache.qpid.proton.amqp.UnsignedInteger
import org.apache.qpid.proton.codec.Data
import java.io.InputStream
import java.io.NotSerializableException
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.lang.reflect.TypeVariable
import java.lang.reflect.WildcardType
import java.nio.ByteBuffer

data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)

Expand All @@ -22,7 +31,8 @@ data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)
* @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple
* instances and threads.
*/
class DeserializationInput(internal val serializerFactory: SerializerFactory) {
class DeserializationInput @JvmOverloads constructor(private val serializerFactory: SerializerFactory,
private val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) {
private val objectHistory: MutableList<Any> = mutableListOf()

internal companion object {
Expand All @@ -47,6 +57,28 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
}
return size + BYTES_NEEDED_TO_PEEK
}

/**
 * Strips the AMQP magic from [byteSequence], processes any leading sections, and runs [task] on the data bytes.
 *
 * Section ids are read one at a time. An ENCODING section names an encoding, which must be accepted by
 * [encodingWhitelist] and is then used to wrap the stream for everything that follows. A data section ends
 * the loop: the rest of the stream is handed to [task] as a [ByteBuffer]. The current stream is closed on exit.
 *
 * @return whatever [task] returns.
 * @throws NotSerializableException if the header does not match or an encoding is not accepted by the whitelist.
 */
@VisibleForTesting
@Throws(NotSerializableException::class)
internal fun <T> withDataBytes(byteSequence: ByteSequence, encodingWhitelist: EncodingWhitelist, task: (ByteBuffer) -> T): T {
    // Check that the lead bytes match expected header
    val amqpSequence = amqpMagic.consume(byteSequence)
            ?: throw NotSerializableException("Serialization header does not match.")
    var stream: InputStream = ByteBufferInputStream(amqpSequence)
    try {
        while (true) {
            val section = SectionId.reader.readFrom(stream)
            when (section) {
                SectionId.ENCODING -> {
                    val encoding = CordaSerializationEncoding.reader.readFrom(stream)
                    if (!encodingWhitelist.acceptEncoding(encoding)) {
                        throw NotSerializableException(encodingNotPermittedFormat.format(encoding))
                    }
                    // Everything after this point is read through the decoding wrapper.
                    stream = encoding.wrap(stream)
                }
                SectionId.DATA_AND_STOP, SectionId.ALT_DATA_AND_STOP -> return task(stream.asByteBuffer())
            }
        }
    } finally {
        stream.close()
    }
}
}

@Throws(NotSerializableException::class)
Expand All @@ -58,12 +90,12 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {

@Throws(NotSerializableException::class)
internal fun getEnvelope(byteSequence: ByteSequence): Envelope {
// Check that the lead bytes match expected header
val dataBytes = amqpMagic.consume(byteSequence) ?: throw NotSerializableException("Serialization header does not match.")
val data = Data.Factory.create()
val expectedSize = dataBytes.remaining()
if (data.decode(dataBytes) != expectedSize.toLong()) throw NotSerializableException("Unexpected size of data")
return Envelope.get(data)
return withDataBytes(byteSequence, encodingWhitelist) { dataBytes ->
val data = Data.Factory.create()
val expectedSize = dataBytes.remaining()
if (data.decode(dataBytes) != expectedSize.toLong()) throw NotSerializableException("Unexpected size of data")
Envelope.get(data)
}
}

@Throws(NotSerializableException::class)
Expand Down
Loading

0 comments on commit c8672d3

Please sign in to comment.