Skip to content

Commit

Permalink
Correctly unwrap MissingAttachmentException. (corda#1454)
Browse files Browse the repository at this point in the history
  • Loading branch information
rick-r3 authored Sep 8, 2017
1 parent 88a6002 commit 691d9ea
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ interface SerializationContext {
* Helper method to return a new context based on this context with the appropriate class loader constructed from the passed attachment identifiers.
* (Requires the attachment storage to have been enabled).
*/
@Throws(MissingAttachmentsException::class)
fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): SerializationContext

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.io.ByteArrayOutputStream
import java.io.NotSerializableException
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutionException

val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled"

Expand All @@ -42,19 +43,28 @@ data class SerializationContextImpl(override val preferredSerializationVersion:

private val cache: Cache<List<SecureHash>, AttachmentsClassLoader> = CacheBuilder.newBuilder().weakValues().maximumSize(1024).build()

// We need to cache the AttachmentClassLoaders to avoid too many contexts, since the class loader is part of cache key for the context.
/**
* {@inheritDoc}
*
* We need to cache the AttachmentClassLoaders to avoid too many contexts, since the class loader is part of cache key for the context.
*/
override fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): SerializationContext {
properties[attachmentsClassLoaderEnabledPropertyName] as? Boolean ?: false || return this
val serializationContext = properties[serializationContextKey] as? SerializeAsTokenContextImpl ?: return this // Some tests don't set one.
return withClassLoader(cache.get(attachmentHashes) {
val missing = ArrayList<SecureHash>()
val attachments = ArrayList<Attachment>()
attachmentHashes.forEach { id ->
serializationContext.serviceHub.attachments.openAttachment(id)?.let { attachments += it } ?: run { missing += id }
}
missing.isNotEmpty() && throw MissingAttachmentsException(missing)
AttachmentsClassLoader(attachments)
})
try {
return withClassLoader(cache.get(attachmentHashes) {
val missing = ArrayList<SecureHash>()
val attachments = ArrayList<Attachment>()
attachmentHashes.forEach { id ->
serializationContext.serviceHub.attachments.openAttachment(id)?.let { attachments += it } ?: run { missing += id }
}
missing.isNotEmpty() && throw MissingAttachmentsException(missing)
AttachmentsClassLoader(attachments)
})
} catch (e: ExecutionException) {
// Caught from within the cache get, so unwrap.
throw e.cause!!
}
}

override fun withProperty(property: Any, value: Any): SerializationContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import net.corda.core.internal.declaredField
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationFactory
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.ByteSequence
Expand Down Expand Up @@ -377,4 +376,24 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
// Then deserialize with the attachment class loader associated with the attachment
serialized.deserialize(context = inboundContext)
}

@Test
fun `test loading a class with attachment missing during deserialization`() {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val storage = MockAttachmentStorage()
val attachmentRef = SecureHash.randomSHA256()
val outboundContext = SerializationFactory.defaultFactory.defaultContext.withClassLoader(child)
// Serialize with custom context to avoid populating the default context with the specially loaded class
val serialized = contract.serialize(context = outboundContext)

// Then deserialize with the attachment class loader associated with the attachment
val e = assertFailsWith(MissingAttachmentsException::class) {
// We currently ignore annotations in attachments, so manually whitelist.
val inboundContext = SerializationFactory.defaultFactory.defaultContext.withWhitelisted(contract.javaClass).withAttachmentStorage(storage).withAttachmentsClassLoader(listOf(attachmentRef))
serialized.deserialize(context = inboundContext)
}
assertEquals(attachmentRef, e.ids.single())
}
}

0 comments on commit 691d9ea

Please sign in to comment.