Skip to content

Commit

Permalink
CORDA-3936: Add a fallback mechanism for Enums incorrectly serialised…
Browse files Browse the repository at this point in the history
… using their toString() method. (corda#6603)

* CORDA-3936: Add a fallback mechanism for Enums incorrectly serialised using their toString() method.

* Backport missing piece of Enum serializer from Corda 4.6.
  • Loading branch information
chrisr3 authored Aug 10, 2020
1 parent 56b574c commit 83f8e00
Show file tree
Hide file tree
Showing 15 changed files with 339 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class SandboxSerializerFactoryFactory(
localSerializerFactory = localSerializerFactory,
classLoader = classLoader,
mustPreserveDataWhenEvolving = context.preventDataLoss,
primitiveTypes = primitiveTypes
primitiveTypes = primitiveTypes,
baseTypes = localTypes
)

val remoteSerializerFactory = DefaultRemoteSerializerFactory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ fun createSandboxSerializationEnv(
@Suppress("unchecked_cast")
val isEnumPredicate = predicateFactory.apply(CheckEnum::class.java) as Predicate<Class<*>>
@Suppress("unchecked_cast")
val enumConstants = taskFactory.apply(DescribeEnum::class.java)
.andThen(taskFactory.apply(GetEnumNames::class.java))
val enumConstants = taskFactory.apply(DescribeEnum::class.java) as Function<Class<*>, Array<out Any>>
@Suppress("unchecked_cast")
val enumConstantNames = enumConstants.andThen(taskFactory.apply(GetEnumNames::class.java))
.andThen { (it as Array<out Any>).map(Any::toString) } as Function<Class<*>, List<String>>

val sandboxLocalTypes = BaseLocalTypes(
Expand All @@ -72,7 +73,8 @@ fun createSandboxSerializationEnv(
mapClass = classLoader.toSandboxClass(Map::class.java),
stringClass = classLoader.toSandboxClass(String::class.java),
isEnum = isEnumPredicate,
enumConstants = enumConstants
enumConstants = enumConstants,
enumConstantNames = enumConstantNames
)
val schemeBuilder = SandboxSerializationSchemeBuilder(
classLoader = classLoader,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ private class ConcreteEnumSerializer(
declaredType,
TypeIdentifier.forGenericType(declaredType),
memberNames,
emptyMap(),
emptyList(),
EnumTransforms.empty
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package net.corda.serialization.djvm

import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.serialization.serialize
import net.corda.serialization.djvm.SandboxType.KOTLIN
import net.corda.serialization.internal.amqp.CompositeType
import net.corda.serialization.internal.amqp.DeserializationInput
import net.corda.serialization.internal.amqp.RestrictedType
import net.corda.serialization.internal.amqp.TypeNotation
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.extension.ExtendWith
import org.junit.jupiter.api.fail
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.EnumSource
import java.util.function.Function

/**
* Corda 4.4 briefly serialised [Enum] values using [Enum.toString] rather
* than [Enum.name]. We need to be able to deserialise these values now
* that the bug has been fixed.
*/
@ExtendWith(LocalSerialization::class)
class DeserializeRemoteCustomisedEnumTest : TestBase(KOTLIN) {
@ParameterizedTest
@EnumSource(Broken::class)
fun `test deserialize broken enum with custom toString`(broken: Broken) {
val workingData = broken.serialize().rewriteEnumAsWorking()

sandbox {
_contextSerializationEnv.set(createSandboxSerializationEnv(classLoader))

val sandboxWorkingClass = classLoader.toSandboxClass(Working::class.java)
val sandboxWorkingValue = workingData.deserializeFor(classLoader)
assertThat(sandboxWorkingValue::class.java).isSameAs(sandboxWorkingClass)
assertThat(sandboxWorkingValue.toString()).isEqualTo(broken.label)
}
}

/**
* This function rewrites the [SerializedBytes] for a naked [Broken] object
* into the [SerializedBytes] that Corda 4.4 would generate for an equivalent
* [Working] object.
*/
@Suppress("unchecked_cast")
private fun SerializedBytes<Broken>.rewriteEnumAsWorking(): SerializedBytes<Working> {
val envelope = DeserializationInput.getEnvelope(this).apply {
val restrictedType = schema.types[0] as RestrictedType
(schema.types as MutableList<TypeNotation>)[0] = restrictedType.copy(
name = toWorking(restrictedType.name)
)
}
return SerializedBytes(envelope.write())
}

@ParameterizedTest
@EnumSource(Broken::class)
fun `test deserialize composed broken enum with custom toString`(broken: Broken) {
val brokenContainer = BrokenContainer(broken)
val workingData = brokenContainer.serialize().rewriteContainerAsWorking()

sandbox {
_contextSerializationEnv.set(createSandboxSerializationEnv(classLoader))

val sandboxContainer = workingData.deserializeFor(classLoader)

val taskFactory = classLoader.createRawTaskFactory()
val showWorkingData = taskFactory.compose(classLoader.createSandboxFunction()).apply(ShowWorkingData::class.java)
val result = showWorkingData.apply(sandboxContainer) ?: fail("Result cannot be null")

assertEquals("Working: label='${broken.label}', ordinal='${broken.ordinal}'", result.toString())
assertEquals(SANDBOX_STRING, result::class.java.name)
}
}

class ShowWorkingData : Function<WorkingContainer, String> {
override fun apply(input: WorkingContainer): String {
return with(input) {
"Working: label='${value.label}', ordinal='${value.ordinal}'"
}
}
}

/**
* This function rewrites the [SerializedBytes] for a [Broken]
* property that has been composed inside a [BrokenContainer].
* It will generate the [SerializedBytes] that Corda 4.4 would
* generate for an equivalent [WorkingContainer].
*/
@Suppress("unchecked_cast")
private fun SerializedBytes<BrokenContainer>.rewriteContainerAsWorking(): SerializedBytes<WorkingContainer> {
val envelope = DeserializationInput.getEnvelope(this).apply {
val compositeType = schema.types[0] as CompositeType
(schema.types as MutableList<TypeNotation>)[0] = compositeType.copy(
name = toWorking(compositeType.name),
fields = compositeType.fields.map { it.copy(type = toWorking(it.type)) }
)
val restrictedType = schema.types[1] as RestrictedType
(schema.types as MutableList<TypeNotation>)[1] = restrictedType.copy(
name = toWorking(restrictedType.name)
)
}
return SerializedBytes(envelope.write())
}

private fun toWorking(oldName: String): String = oldName.replace("Broken", "Working")

/**
* This is the enumerated type, as it actually exist.
*/
@Suppress("unused")
enum class Working(val label: String) {
ZERO("None"),
ONE("Once"),
TWO("Twice");

@Override
override fun toString(): String = label
}

@CordaSerializable
data class WorkingContainer(val value: Working)

/**
* This represents a broken serializer's view of the [Working]
* enumerated type, which would serialize using [Enum.toString]
* rather than [Enum.name].
*/
@Suppress("unused")
@CordaSerializable
enum class Broken(val label: String) {
None("None"),
Once("Once"),
Twice("Twice");

@Override
override fun toString(): String = label
}

@CordaSerializable
data class BrokenContainer(val value: Broken)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
@file:JvmName("TestHelpers")
package net.corda.serialization.djvm

import net.corda.serialization.internal.SectionId
import net.corda.serialization.internal.amqp.Envelope
import net.corda.serialization.internal.amqp.alsoAsByteBuffer
import net.corda.serialization.internal.amqp.amqpMagic
import net.corda.serialization.internal.amqp.withDescribed
import net.corda.serialization.internal.amqp.withList
import org.apache.qpid.proton.codec.Data
import java.io.ByteArrayOutputStream

fun Envelope.write(): ByteArray {
val data = Data.Factory.create()
data.withDescribed(Envelope.DESCRIPTOR_OBJECT) {
withList {
putObject(obj)
putObject(schema)
putObject(transformsSchema)
}
}
return ByteArrayOutputStream().use {
amqpMagic.writeTo(it)
SectionId.DATA_AND_STOP.writeTo(it)
it.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode)
it.toByteArray()
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
package net.corda.serialization.internal.amqp

import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.SerializationContext
import net.corda.serialization.internal.model.LocalTypeInformation
import org.apache.qpid.proton.amqp.Symbol
import net.corda.serialization.internal.model.BaseLocalTypes
import org.apache.qpid.proton.codec.Data
import java.io.NotSerializableException
import java.lang.UnsupportedOperationException
import java.lang.reflect.Type
import java.util.*

/**
* Used whenever a deserialized enums fingerprint doesn't match the fingerprint of the generated
Expand Down Expand Up @@ -39,6 +35,7 @@ import java.util.*
class EnumEvolutionSerializer(
override val type: Type,
factory: LocalSerializerFactory,
private val baseLocalTypes: BaseLocalTypes,
private val conversions: Map<String, String>,
private val ordinals: Map<String, Int>) : AMQPSerializer<Any> {
override val typeDescriptor = factory.createDescriptor(type)
Expand All @@ -51,7 +48,7 @@ class EnumEvolutionSerializer(
val converted = conversions[enumName] ?: throw AMQPNotSerializableException(type, "No rule to evolve enum constant $type::$enumName")
val ordinal = ordinals[converted] ?: throw AMQPNotSerializableException(type, "Ordinal not found for enum value $type::$converted")

return type.asClass().enumConstants[ordinal]
return baseLocalTypes.enumConstants.apply(type.asClass())[ordinal]
}

override fun writeClassInfo(output: SerializationOutput) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class DefaultEvolutionSerializerFactory(
private val localSerializerFactory: LocalSerializerFactory,
private val classLoader: ClassLoader,
private val mustPreserveDataWhenEvolving: Boolean,
override val primitiveTypes: Map<Class<*>, Class<*>>
override val primitiveTypes: Map<Class<*>, Class<*>>,
private val baseTypes: BaseLocalTypes
): EvolutionSerializerFactory {
// Invert the "primitive -> boxed primitive" mapping.
private val primitiveBoxedTypes: Map<Class<*>, Class<*>>
Expand Down Expand Up @@ -154,16 +155,16 @@ class DefaultEvolutionSerializerFactory(
val localTransforms = localTypeInformation.transforms
val transforms = if (remoteTransforms.size > localTransforms.size) remoteTransforms else localTransforms

val localOrdinals = localTypeInformation.members.asSequence().mapIndexed { ord, member -> member to ord }.toMap()
val remoteOrdinals = members.asSequence().mapIndexed { ord, member -> member to ord }.toMap()
val localOrdinals = localTypeInformation.members.mapIndexed { ord, member -> member to ord }.toMap()
val remoteOrdinals = members.mapIndexed { ord, member -> member to ord }.toMap()
val rules = transforms.defaults + transforms.renames

// We just trust our transformation rules not to contain cycles here.
tailrec fun findLocal(remote: String): String =
if (remote in localOrdinals) remote
else findLocal(rules[remote] ?: throw EvolutionSerializationException(
this,
"Cannot resolve local enum member $remote to a member of ${localOrdinals.keys} using rules $rules"
if (remote in localOrdinals.keys) remote
else localTypeInformation.fallbacks[remote] ?: findLocal(rules[remote] ?: throw EvolutionSerializationException(
this,
"Cannot resolve local enum member $remote to a member of ${localOrdinals.keys} using rules $rules"
))

val conversions = members.associate { it to findLocal(it) }
Expand All @@ -172,7 +173,7 @@ class DefaultEvolutionSerializerFactory(
if (constantsAreReordered(localOrdinals, convertedOrdinals)) throw EvolutionSerializationException(this,
"Constants have been reordered, additions must be appended to the end")

return EnumEvolutionSerializer(localTypeInformation.observedType, localSerializerFactory, conversions, localOrdinals)
return EnumEvolutionSerializer(localTypeInformation.observedType, localSerializerFactory, baseTypes, conversions, localOrdinals)
}

private fun constantsAreReordered(localOrdinals: Map<String, Int>, convertedOrdinals: Map<Int, String>): Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ object SerializerFactoryBuilder {
mustPreserveDataWhenEvolving: Boolean): SerializerFactory {
val customSerializerRegistry = CachingCustomSerializerRegistry(descriptorBasedSerializerRegistry)

val localTypeModel = ConfigurableLocalTypeModel(
WhitelistBasedTypeModelConfiguration(
whitelist,
customSerializerRegistry))
val typeModelConfiguration = WhitelistBasedTypeModelConfiguration(whitelist, customSerializerRegistry)
val localTypeModel = ConfigurableLocalTypeModel(typeModelConfiguration)

val fingerPrinter = overrideFingerPrinter ?:
TypeModellingFingerPrinter(customSerializerRegistry, classCarpenter.classloader)
Expand All @@ -124,7 +122,8 @@ object SerializerFactoryBuilder {
localSerializerFactory,
classCarpenter.classloader,
mustPreserveDataWhenEvolving,
javaPrimitiveTypes
javaPrimitiveTypes,
typeModelConfiguration.baseTypes
) else NoEvolutionSerializerFactory

val remoteSerializerFactory = DefaultRemoteSerializerFactory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ private val DEFAULT_BASE_TYPES = BaseLocalTypes(
mapClass = Map::class.java,
stringClass = String::class.java,
isEnum = Predicate { clazz -> clazz.isEnum },
enumConstants = Function { clazz ->
enumConstants = Function { clazz -> clazz.enumConstants },
enumConstantNames = Function { clazz ->
(clazz as Class<out Enum<*>>).enumConstants.map(Enum<*>::name)
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ typealias PropertyName = String
* If a concrete type does not have a unique deserialization constructor, it is represented by [NonComposable], meaning
* that we know how to take it apart but do not know how to put it back together again.
*
* An array of any type is represented by [ArrayOf]. Enums are represented by [AnEnum].
* An array of any type is represented by [AnArray]. Enums are represented by [AnEnum].
*
* The type of [Any]/[java.lang.Object] is represented by [Top]. Unbounded wildcards, or wildcards whose upper bound is
* [Top], are represented by [Unknown]. Bounded wildcards are always resolved to their upper bounds, e.g.
Expand Down Expand Up @@ -178,6 +178,7 @@ sealed class LocalTypeInformation {
override val observedType: Class<*>,
override val typeIdentifier: TypeIdentifier,
val members: List<String>,
val fallbacks: Map<String, String>,
val interfaces: List<LocalTypeInformation>,
val transforms: EnumTransforms): LocalTypeInformation()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,22 @@ internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup,
baseTypes.mapClass.isAssignableFrom(type) -> AMap(type, typeIdentifier, Unknown, Unknown)
type === baseTypes.stringClass -> Atomic(type, typeIdentifier)
type.kotlin.javaPrimitiveType != null -> Atomic(type, typeIdentifier)
baseTypes.isEnum.test(type) -> baseTypes.enumConstants.apply(type).let { enumConstants ->
baseTypes.isEnum.test(type) -> baseTypes.enumConstantNames.apply(type).let { enumConstantNames ->
AnEnum(
type,
typeIdentifier,
enumConstants,
enumConstantNames,
/**
* Calculate "fallbacks" for any [Enum] incorrectly serialised
* as its [Enum.toString] value. We are only interested in the
* cases where these are different from [Enum.name].
* These fallbacks DO NOT contribute to this type's fingerprint.
*/
baseTypes.enumConstants.apply(type).map(Any::toString).mapIndexed { ord, fallback ->
fallback to enumConstantNames[ord]
}.filterNot { it.first == it.second }.toMap(),
buildInterfaceInformation(type),
getEnumTransforms(type, enumConstants)
getEnumTransforms(type, enumConstantNames)
)
}
type.kotlinObjectInstance != null -> Singleton(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,6 @@ class BaseLocalTypes(
val mapClass: Class<*>,
val stringClass: Class<*>,
val isEnum: Predicate<Class<*>>,
val enumConstants: Function<Class<*>, List<String>>
val enumConstants: Function<Class<*>, Array<out Any>>,
val enumConstantNames: Function<Class<*>, List<String>>
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package net.corda.serialization.internal.amqp
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.serialization.internal.EmptyWhitelist
import net.corda.serialization.internal.amqp.testutils.TestSerializationOutput
import net.corda.serialization.internal.amqp.testutils.deserialize
Expand Down Expand Up @@ -184,7 +183,7 @@ class EnumTests {
data class C(val a: OldBras2)

// DO NOT CHANGE THIS, it's important we serialise with a value that doesn't
// change position in the upated enum class
// change position in the updated enum class

// Original version of the class for the serialised version of this class
//
Expand Down
Loading

0 comments on commit 83f8e00

Please sign in to comment.