Skip to content

Commit

Permalink
[CORDA-2390] - Add whitelists and custom serializers from cordapps to…
Browse files Browse the repository at this point in the history
… serialization … (corda#4551)

* Add whitelists and custom serializers from cordapps to serialization context

* Remove changes in TransactionBuilder, add caching

* Add whitelists and custom serializers from cordapps to serialization context

* Remove changes in TransactionBuilder, add caching

* Address comments

* Increase node memory for SIMM integration test

* Cache only serialization context

* Increase integ test timeout

* Fix API breakage

* Increase max heap size for web server integ test

* Move classloading utils from separate module to core.internal

* Adjust heap size for more integ tests

* Increase time window for IRS demo transactions

* Fix determinator

* Add parameter in core-deterministic

* Stub out class-loading method for DJVM
  • Loading branch information
dimosr authored and Gavin Thomas committed Jan 13, 2019
1 parent 36cd9b9 commit 5b34020
Show file tree
Hide file tree
Showing 37 changed files with 278 additions and 124 deletions.
21 changes: 17 additions & 4 deletions client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ package net.corda.client.rpc
import com.github.benmanes.caffeine.cache.Caffeine
import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.core.internal.loadClassesImplementing
import net.corda.core.context.Actor
import net.corda.core.context.Trace
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.PLATFORM_VERSION
import net.corda.core.messaging.ClientRpcSslOptions
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.days
Expand All @@ -19,6 +22,7 @@ import net.corda.serialization.internal.AMQP_RPC_CLIENT_CONTEXT
import net.corda.serialization.internal.amqp.SerializationFactoryCacheKey
import net.corda.serialization.internal.amqp.SerializerFactory
import java.time.Duration
import java.util.ServiceLoader

/**
* This class is essentially just a wrapper for an RPCConnection<CordaRPCOps> and can be treated identically.
Expand Down Expand Up @@ -240,6 +244,7 @@ open class CordaRPCClientConfiguration @JvmOverloads constructor(
* @param configuration An optional configuration used to tweak client behaviour.
* @param sslConfiguration An optional [ClientRpcSslOptions] used to enable secure communication with the server.
* @param haAddressPool A list of [NetworkHostAndPort] representing the addresses of servers in HA mode.
* @param classLoader a classloader, which will be used (if provided) to discover available [SerializationCustomSerializer]s and [SerializationWhitelist]s
* The client will attempt to connect to a live server by trying each address in the list. If the servers are not in
* HA mode, the client will round-robin from the beginning of the list and try all servers.
*/
Expand All @@ -252,8 +257,9 @@ class CordaRPCClient private constructor(
) {
@JvmOverloads
constructor(hostAndPort: NetworkHostAndPort,
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT)
: this(hostAndPort, configuration, null)
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT,
classLoader: ClassLoader? = null)
: this(hostAndPort, configuration, null, classLoader = classLoader)

/**
* @param haAddressPool A list of [NetworkHostAndPort] representing the addresses of servers in HA mode.
Expand Down Expand Up @@ -287,7 +293,7 @@ class CordaRPCClient private constructor(
sslConfiguration: ClientRpcSslOptions? = null,
classLoader: ClassLoader? = null
): CordaRPCClient {
return CordaRPCClient(hostAndPort, configuration, sslConfiguration, classLoader)
return CordaRPCClient(hostAndPort, configuration, sslConfiguration, classLoader = classLoader)
}
}

Expand All @@ -296,7 +302,14 @@ class CordaRPCClient private constructor(
effectiveSerializationEnv
} catch (e: IllegalStateException) {
try {
AMQPClientSerializationScheme.initialiseSerialization(classLoader, Caffeine.newBuilder().maximumSize(128).build<SerializationFactoryCacheKey, SerializerFactory>().asMap())
// If the client has provided a classloader, the associated classpath is checked for available custom serializers and serialization whitelists.
if (classLoader != null) {
val customSerializers = loadClassesImplementing(classLoader, SerializationCustomSerializer::class.java)
val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet()
AMQPClientSerializationScheme.initialiseSerialization(classLoader, customSerializers, serializationWhitelists, Caffeine.newBuilder().maximumSize(128).build<SerializationFactoryCacheKey, SerializerFactory>().asMap())
} else {
AMQPClientSerializationScheme.initialiseSerialization(classLoader, serializerFactoriesForContexts = Caffeine.newBuilder().maximumSize(128).build<SerializationFactoryCacheKey, SerializerFactory>().asMap())
}
} catch (e: IllegalStateException) {
// Race e.g. two of these constructed in parallel, ignore.
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import net.corda.core.internal.toSynchronised
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase
import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.serialization.internal.*
Expand All @@ -17,24 +18,25 @@ import net.corda.serialization.internal.amqp.custom.RxNotificationSerializer
*/
class AMQPClientSerializationScheme(
cordappCustomSerializers: Set<SerializationCustomSerializer<*,*>>,
cordappSerializationWhitelists: Set<SerializationWhitelist>,
serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory>
) : AbstractAMQPSerializationScheme(cordappCustomSerializers, serializerFactoriesForContexts) {
constructor(cordapps: List<Cordapp>) : this(cordapps.customSerializers, AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised())
constructor(cordapps: List<Cordapp>, serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory>) : this(cordapps.customSerializers, serializerFactoriesForContexts)
) : AbstractAMQPSerializationScheme(cordappCustomSerializers, cordappSerializationWhitelists, serializerFactoriesForContexts) {
constructor(cordapps: List<Cordapp>) : this(cordapps.customSerializers, cordapps.serializationWhitelists, AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised())
constructor(cordapps: List<Cordapp>, serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory>) : this(cordapps.customSerializers, cordapps.serializationWhitelists, serializerFactoriesForContexts)

@Suppress("UNUSED")
constructor() : this(emptySet(), AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised())
constructor() : this(emptySet(), emptySet(), AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised())

companion object {
/** Call from main only. */
fun initialiseSerialization(classLoader: ClassLoader? = null, serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory> = AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised()) {
nodeSerializationEnv = createSerializationEnv(classLoader, serializerFactoriesForContexts)
fun initialiseSerialization(classLoader: ClassLoader? = null, customSerializers: Set<SerializationCustomSerializer<*, *>> = emptySet(), serializationWhitelists: Set<SerializationWhitelist> = emptySet(), serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory> = AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised()) {
nodeSerializationEnv = createSerializationEnv(classLoader, customSerializers, serializationWhitelists, serializerFactoriesForContexts)
}

fun createSerializationEnv(classLoader: ClassLoader? = null, serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory> = AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised()): SerializationEnvironment {
fun createSerializationEnv(classLoader: ClassLoader? = null, customSerializers: Set<SerializationCustomSerializer<*, *>> = emptySet(), serializationWhitelists: Set<SerializationWhitelist> = emptySet(), serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory> = AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised()): SerializationEnvironment {
return SerializationEnvironment.with(
SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList(), serializerFactoriesForContexts))
registerScheme(AMQPClientSerializationScheme(customSerializers, serializationWhitelists, serializerFactoriesForContexts))
},
storageContext = AMQP_STORAGE_CONTEXT,
p2pContext = if (classLoader != null) AMQP_P2P_CONTEXT.withClassLoader(classLoader) else AMQP_P2P_CONTEXT,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package net.corda.core.internal

/**
* Stubbing out non-deterministic method.
*/
fun <T: Any> loadClassesImplementing(classloader: ClassLoader, clazz: Class<T>): Set<T> {
return emptySet()
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.P2P
import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.serialization.internal.*
Expand Down Expand Up @@ -57,15 +58,16 @@ class LocalSerializationRule(private val label: String) : TestRule {

private fun createTestSerializationEnv(): SerializationEnvironment {
val factory = SerializationFactoryImpl(mutableMapOf()).apply {
registerScheme(AMQPSerializationScheme(emptySet(), AccessOrderLinkedHashMap(128)))
registerScheme(AMQPSerializationScheme(emptySet(), emptySet(), AccessOrderLinkedHashMap(128)))
}
return SerializationEnvironment.with(factory, AMQP_P2P_CONTEXT)
}

private class AMQPSerializationScheme(
cordappCustomSerializers: Set<SerializationCustomSerializer<*, *>>,
cordappSerializationWhitelists: Set<SerializationWhitelist>,
serializerFactoriesForContexts: AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>
) : AbstractAMQPSerializationScheme(cordappCustomSerializers, serializerFactoriesForContexts) {
) : AbstractAMQPSerializationScheme(cordappCustomSerializers, cordappSerializationWhitelists, serializerFactoriesForContexts) {
override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory {
throw UnsupportedOperationException()
}
Expand Down
2 changes: 2 additions & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ dependencies {

// required to use @Type annotation
compile "org.hibernate:hibernate-core:$hibernate_version"

compile group: "io.github.classgraph", name: "classgraph", version: class_graph_version
}

// TODO Consider moving it to quasar-utils in the future (introduced with PR-1388)
Expand Down
31 changes: 31 additions & 0 deletions core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package net.corda.core.internal

import io.github.classgraph.ClassGraph
import net.corda.core.CordaInternal
import net.corda.core.DeleteForDJVM
import net.corda.core.StubOutForDJVM
import kotlin.reflect.full.createInstance

/**
* Creates instances of all the classes in the classpath of the provided classloader, which implement the interface of the provided class.
* @param classloader the classloader, which will be searched for the classes.
* @param clazz the class of the interface, which the classes - to be returned - must implement.
*
* @return instances of the identified classes.
* @throws IllegalArgumentException if the classes found do not have proper constructors.
*
* Note: In order to be instantiated, the associated classes must:
* - be non-abstract
* - either be a Kotlin object or have a constructor with no parameters (or only optional ones)
*/
@StubOutForDJVM
fun <T: Any> loadClassesImplementing(classloader: ClassLoader, clazz: Class<T>): Set<T> {
return ClassGraph().addClassLoader(classloader)
.enableAllInfo()
.scan()
.getClassesImplementing(clazz.name)
.filterNot { it.isAbstract }
.mapNotNull { classloader.loadClass(it.name).asSubclass(clazz) }
.map { it.kotlin.objectInstance ?: it.kotlin.createInstance() }
.toSet()
}
4 changes: 4 additions & 0 deletions core/src/main/kotlin/net/corda/core/internal/CordaUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ import net.corda.core.node.services.vault.Builder
import net.corda.core.node.services.vault.Sort
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.internal.AttachmentsClassLoaderBuilder
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.transactions.WireTransaction
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.slf4j.MDC
import java.security.PublicKey
import java.util.jar.JarEntry
import java.util.jar.JarInputStream

// *Internal* Corda-specific utilities.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@file:KeepForDJVM
package net.corda.core.serialization

import co.paralleluniverse.io.serialization.Serialization
import net.corda.core.CordaInternal
import net.corda.core.DeleteForDJVM
import net.corda.core.DoNotImplement
Expand Down Expand Up @@ -160,6 +161,10 @@ interface SerializationContext {
* The use case we are serializing or deserializing for. See [UseCase].
*/
val useCase: UseCase
/**
* Additional custom serializers that will be made available during (de)serialization.
*/
val customSerializers: Set<SerializationCustomSerializer<*, *>>

/**
* Helper method to return a new context based on this context with the property added.
Expand Down Expand Up @@ -200,6 +205,11 @@ interface SerializationContext {
*/
fun withWhitelisted(clazz: Class<*>): SerializationContext

/**
* Helper method to return a new context based on this context with the given serializers added.
*/
fun withCustomSerializers(serializers: Set<SerializationCustomSerializer<*, *>>): SerializationContext

/**
* Helper method to return a new context based on this context but with serialization using the format this header sequence represents.
*/
Expand Down Expand Up @@ -335,3 +345,15 @@ interface ClassWhitelist {
interface EncodingWhitelist {
fun acceptEncoding(encoding: SerializationEncoding): Boolean
}

/**
* Helper method to return a new context based on this context with the given list of classes specifically whitelisted.
*/
fun SerializationContext.withWhitelist(classes: List<Class<*>>): SerializationContext {
var currentContext = this
classes.forEach {
clazz -> currentContext = currentContext.withWhitelisted(clazz)
}

return currentContext
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package net.corda.core.serialization.internal

import net.corda.core.CordaException
import net.corda.core.KeepForDJVM
import net.corda.core.internal.loadClassesImplementing
import net.corda.core.contracts.Attachment
import net.corda.core.contracts.ContractAttachment
import net.corda.core.contracts.TransactionVerificationException.OverlappingAttachmentsException
Expand All @@ -10,14 +11,19 @@ import net.corda.core.crypto.sha256
import net.corda.core.internal.*
import net.corda.core.internal.cordapp.targetPlatformVersion
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.MissingAttachmentsException
import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory.toUrl
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.net.*
import java.util.*

/**
* A custom ClassLoader that knows how to load classes from a set of attachments. The attachments themselves only
Expand Down Expand Up @@ -174,34 +180,38 @@ class AttachmentsClassLoader(attachments: List<Attachment>, parent: ClassLoader
}

/**
* This is just a factory that provides a cache to avoid constructing expensive [AttachmentsClassLoader]s.
* This is just a factory that provides caches to optimise expensive construction/loading of classloaders, serializers, whitelisted classes.
*/
@VisibleForTesting
internal object AttachmentsClassLoaderBuilder {

private const val ATTACHMENT_CLASSLOADER_CACHE_SIZE = 1000
private const val CACHE_SIZE = 1000

// This runs in the DJVM so it can't use caffeine.
private val cache: MutableMap<List<SecureHash>, AttachmentsClassLoader> = createSimpleCache<List<SecureHash>, AttachmentsClassLoader>(ATTACHMENT_CLASSLOADER_CACHE_SIZE)
.toSynchronised()

fun build(attachments: List<Attachment>): AttachmentsClassLoader {
return cache.computeIfAbsent(attachments.map { it.id }.sorted()) {
AttachmentsClassLoader(attachments)
}
}
private val cache: MutableMap<Set<SecureHash>, SerializationContext> = createSimpleCache(CACHE_SIZE)

fun <T> withAttachmentsClassloaderContext(attachments: List<Attachment>, block: (ClassLoader) -> T): T {
val attachmentIds = attachments.map { it.id }.toSet()

// Create classloader from the attachments.
val transactionClassLoader = AttachmentsClassLoaderBuilder.build(attachments)
val serializationContext = cache.computeIfAbsent(attachmentIds) {
// Create classloader and load serializers, whitelisted classes
val transactionClassLoader = AttachmentsClassLoader(attachments)
val serializers = loadClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java)
val whitelistedClasses = ServiceLoader.load(SerializationWhitelist::class.java, transactionClassLoader)
.flatMap { it.whitelist }
.toList()

// Create a new serializationContext for the current Transaction.
val transactionSerializationContext = SerializationFactory.defaultFactory.defaultContext.withPreventDataLoss().withClassLoader(transactionClassLoader)
// Create a new serializationContext for the current Transaction.
SerializationFactory.defaultFactory.defaultContext
.withPreventDataLoss()
.withClassLoader(transactionClassLoader)
.withWhitelist(whitelistedClasses)
.withCustomSerializers(serializers)
}

// Deserialize all relevant classes in the transaction classloader.
return SerializationFactory.defaultFactory.withCurrentContext(transactionSerializationContext) {
block(transactionClassLoader)
return SerializationFactory.defaultFactory.withCurrentContext(serializationContext) {
block(serializationContext.deserializationClassLoader)
}
}
}
Expand Down
Loading

0 comments on commit 5b34020

Please sign in to comment.