Skip to content

Commit

Permalink
Introduce current context concept for serialization in preparation fo…
Browse files Browse the repository at this point in the history
…r WireTransaction changes (corda#1448)
  • Loading branch information
rick-r3 authored Sep 8, 2017
1 parent 6bf2871 commit 79f1e1a
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package net.corda.core.serialization
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256
import net.corda.core.internal.WriteOnceProperty
import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.sequence
Expand All @@ -13,23 +11,79 @@ import net.corda.core.utilities.sequence
* An abstraction for serializing and deserializing objects, with support for versioning of the wire format via
* a header / prefix in the bytes.
*/
interface SerializationFactory {
abstract class SerializationFactory {
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T
abstract fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T

/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T>
abstract fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T>

/**
* If there is a need to nest serialization/deserialization with a modified context during serialization or deserialization,
* this will return the current context used to start serialization/deserialization.
*/
val currentContext: SerializationContext? get() = _currentContext.get()

/**
* A context to use as a default if you do not require a specially configured context. It will be the current context
* if the use is somehow nested (see [currentContext]).
*/
val defaultContext: SerializationContext get() = currentContext ?: SerializationDefaults.P2P_CONTEXT

private val _currentContext = ThreadLocal<SerializationContext?>()

/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: SerializationContext?, block: () -> T): T {
val priorContext = _currentContext.get()
if (context != null) _currentContext.set(context)
try {
return block()
} finally {
if (context != null) _currentContext.set(priorContext)
}
}

/**
* Allow subclasses to temporarily mark themselves as the current factory for the current thread during serialization/deserialization.
* Will restore the prior context on exiting the block.
*/
protected fun <T> asCurrent(block: SerializationFactory.() -> T): T {
val priorContext = _currentFactory.get()
_currentFactory.set(this)
try {
return block()
} finally {
_currentFactory.set(priorContext)
}
}

companion object {
private val _currentFactory = ThreadLocal<SerializationFactory?>()

/**
* A default factory for serialization/deserialization, taking into account the [currentFactory] if set.
*/
val defaultFactory: SerializationFactory get() = currentFactory ?: SerializationDefaults.SERIALIZATION_FACTORY

/**
* If there is a need to nest serialization/deserialization with a modified context during serialization or deserialization,
* this will return the current factory used to start serialization/deserialization.
*/
val currentFactory: SerializationFactory? get() = _currentFactory.get()
}
}

/**
Expand Down Expand Up @@ -76,6 +130,12 @@ interface SerializationContext {
*/
fun withClassLoader(classLoader: ClassLoader): SerializationContext

/**
* Helper method to return a new context based on this context with the appropriate class loader constructed from the passed attachment identifiers.
* (Requires the attachment storage to have been enabled).
*/
fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): SerializationContext

/**
* Helper method to return a new context based on this context with the given class specifically whitelisted.
*/
Expand Down Expand Up @@ -107,26 +167,26 @@ object SerializationDefaults {
/**
* Convenience extension method for deserializing a ByteSequence, utilising the defaults.
*/
inline fun <reified T : Any> ByteSequence.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T {
inline fun <reified T : Any> ByteSequence.deserialize(serializationFactory: SerializationFactory = SerializationFactory.defaultFactory, context: SerializationContext = serializationFactory.defaultContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}

/**
* Convenience extension method for deserializing SerializedBytes with type matching, utilising the defaults.
*/
inline fun <reified T : Any> SerializedBytes<T>.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T {
inline fun <reified T : Any> SerializedBytes<T>.deserialize(serializationFactory: SerializationFactory = SerializationFactory.defaultFactory, context: SerializationContext = serializationFactory.defaultContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}

/**
* Convenience extension method for deserializing a ByteArray, utilising the defaults.
*/
inline fun <reified T : Any> ByteArray.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T = this.sequence().deserialize(serializationFactory, context)
inline fun <reified T : Any> ByteArray.deserialize(serializationFactory: SerializationFactory = SerializationFactory.defaultFactory, context: SerializationContext = serializationFactory.defaultContext): T = this.sequence().deserialize(serializationFactory, context)

/**
* Convenience extension method for serializing an object of type T, utilising the defaults.
*/
fun <T : Any> T.serialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): SerializedBytes<T> {
fun <T : Any> T.serialize(serializationFactory: SerializationFactory = SerializationFactory.defaultFactory, context: SerializationContext = serializationFactory.defaultContext): SerializedBytes<T> {
return serializationFactory.serialize(this, context)
}

Expand All @@ -142,4 +202,4 @@ class SerializedBytes<T : Any>(bytes: ByteArray) : OpaqueBytes(bytes) {

interface ClassWhitelist {
fun hasListed(type: Class<*>): Boolean
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import net.corda.core.contracts.*
import net.corda.core.crypto.*
import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.serialize
import java.nio.ByteBuffer
import java.util.function.Predicate
Expand All @@ -22,12 +22,12 @@ fun <T : Any> serializedHash(x: T, privacySalt: PrivacySalt?, index: Int): Secur

fun <T : Any> serializedHash(x: T, nonce: SecureHash): SecureHash {
return if (x !is PrivacySalt) // PrivacySalt is not required to have an accompanied nonce.
(x.serialize(context = P2P_CONTEXT.withoutReferences()).bytes + nonce.bytes).sha256()
(x.serialize(context = SerializationFactory.defaultFactory.defaultContext.withoutReferences()).bytes + nonce.bytes).sha256()
else
serializedHash(x)
}

fun <T : Any> serializedHash(x: T): SecureHash = x.serialize(context = P2P_CONTEXT.withoutReferences()).bytes.sha256()
fun <T : Any> serializedHash(x: T): SecureHash = x.serialize(context = SerializationFactory.defaultFactory.defaultContext.withoutReferences()).bytes.sha256()

/** The nonce is computed as Hash(privacySalt || index). */
fun computeNonce(privacySalt: PrivacySalt, index: Int) = (privacySalt.bytes + ByteBuffer.allocate(4).putInt(index).array()).sha256()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ import com.esotericsoftware.kryo.serializers.CompatibleFieldSerializer
import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.MapReferenceResolver
import net.corda.core.contracts.*
import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.crypto.CompositeKey
import net.corda.core.identity.Party
import net.corda.core.internal.VisibleForTesting
import net.corda.core.serialization.AttachmentsClassLoader
import net.corda.core.serialization.MissingAttachmentsException
import net.corda.core.serialization.SerializeAsTokenContext
Expand Down Expand Up @@ -241,9 +240,6 @@ fun Input.readBytesWithLength(): ByteArray {
/** A serialisation engine that knows how to deserialise code inside a sandbox */
@ThreadSafe
object WireTransactionSerializer : Serializer<WireTransaction>() {
@VisibleForTesting
internal val attachmentsClassLoaderEnabled = "attachments.class.loader.enabled"

override fun write(kryo: Kryo, output: Output, obj: WireTransaction) {
kryo.writeClassAndObject(output, obj.inputs)
kryo.writeClassAndObject(output, obj.attachments)
Expand All @@ -255,7 +251,7 @@ object WireTransactionSerializer : Serializer<WireTransaction>() {
}

private fun attachmentsClassLoader(kryo: Kryo, attachmentHashes: List<SecureHash>): ClassLoader? {
kryo.context[attachmentsClassLoaderEnabled] as? Boolean ?: false || return null
kryo.context[attachmentsClassLoaderEnabledPropertyName] as? Boolean ?: false || return null
val serializationContext = kryo.serializationContext() ?: return null // Some tests don't set one.
val missing = ArrayList<SecureHash>()
val attachments = ArrayList<Attachment>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.cache.Cache
import com.google.common.cache.CacheBuilder
import net.corda.core.contracts.Attachment
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.LazyPool
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
Expand All @@ -17,6 +21,8 @@ import java.io.NotSerializableException
import java.util.*
import java.util.concurrent.ConcurrentHashMap

val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled"

object NotSupportedSeralizationScheme : SerializationScheme {
private fun doThrow(): Nothing = throw UnsupportedOperationException("Serialization scheme not supported.")

Expand All @@ -33,6 +39,24 @@ data class SerializationContextImpl(override val preferredSerializationVersion:
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val useCase: SerializationContext.UseCase) : SerializationContext {

private val cache: Cache<List<SecureHash>, AttachmentsClassLoader> = CacheBuilder.newBuilder().weakValues().maximumSize(1024).build()

// We need to cache the AttachmentClassLoaders to avoid too many contexts, since the class loader is part of cache key for the context.
override fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): SerializationContext {
properties[attachmentsClassLoaderEnabledPropertyName] as? Boolean ?: false || return this
val serializationContext = properties[serializationContextKey] as? SerializeAsTokenContextImpl ?: return this // Some tests don't set one.
return withClassLoader(cache.get(attachmentHashes) {
val missing = ArrayList<SecureHash>()
val attachments = ArrayList<Attachment>()
attachmentHashes.forEach { id ->
serializationContext.serviceHub.attachments.openAttachment(id)?.let { attachments += it } ?: run { missing += id }
}
missing.isNotEmpty() && throw MissingAttachmentsException(missing)
AttachmentsClassLoader(attachments)
})
}

override fun withProperty(property: Any, value: Any): SerializationContext {
return copy(properties = properties + (property to value))
}
Expand All @@ -56,7 +80,7 @@ data class SerializationContextImpl(override val preferredSerializationVersion:

private const val HEADER_SIZE: Int = 8

open class SerializationFactoryImpl : SerializationFactory {
open class SerializationFactoryImpl : SerializationFactory() {
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()

private val registeredSchemes: MutableCollection<SerializationScheme> = Collections.synchronizedCollection(mutableListOf())
Expand All @@ -75,10 +99,12 @@ open class SerializationFactoryImpl : SerializationFactory {
}

@Throws(NotSerializableException::class)
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T = schemeFor(byteSequence, context.useCase).deserialize(byteSequence, clazz, context)
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
return asCurrent { withCurrentContext(context) { schemeFor(byteSequence, context.useCase).deserialize(byteSequence, clazz, context) } }
}

override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
return schemeFor(context.preferredSerializationVersion, context.useCase).serialize(obj, context)
return asCurrent { withCurrentContext(context) { schemeFor(context.preferredSerializationVersion, context.useCase).serialize(obj, context) } }
}

fun registerScheme(scheme: SerializationScheme) {
Expand Down
Loading

0 comments on commit 79f1e1a

Please sign in to comment.