Skip to content

Commit

Permalink
CORDA-3188: Ignore synthetic and static fields when searching for sta…
Browse files Browse the repository at this point in the history
…te pointers (corda#5439)
  • Loading branch information
shamsasari authored Sep 6, 2019
1 parent 7ef9a8d commit cedb290
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
32 changes: 14 additions & 18 deletions core/src/main/kotlin/net/corda/core/internal/StatePointerSearch.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import java.util.*
* TODO: Doesn't handle calculated properties. Add support for this.
*/
class StatePointerSearch(val state: ContractState) {
// Classes in these packages should not be part of a search.
private val blackListedPackages = setOf("java.", "javax.", "org.bouncycastle.", "net.i2p.crypto.")
private companion object {
// Classes in these packages should not be part of a search.
private val blackListedPackages = setOf("java.", "javax.", "org.bouncycastle.", "net.i2p.crypto.")
}

// Type required for traversal.
private data class FieldWithObject(val obj: Any, val field: Field)
Expand All @@ -21,16 +23,17 @@ class StatePointerSearch(val state: ContractState) {
private val statePointers = mutableSetOf<StatePointer<*>>()

// Record seen objects to avoid getting stuck in loops.
private val seenObjects = Collections.newSetFromMap(IdentityHashMap<Any, Boolean>()).apply { add(state) }
private val seenObjects = Collections.newSetFromMap(IdentityHashMap<Any, Boolean>())

// Queue of fields to search.
private val fieldQueue = ArrayDeque<FieldWithObject>().apply { addAllFields(state) }
private val fieldQueue = ArrayDeque<FieldWithObject>()

// Helper for adding all fields to the queue.
private fun ArrayDeque<FieldWithObject>.addAllFields(obj: Any) {
private fun addAllFields(obj: Any) {
val fields = FieldUtils.getAllFieldsList(obj::class.java)

val fieldsWithObjects = fields.mapNotNull { field ->
fields.mapNotNullTo(fieldQueue) { field ->
if (field.isSynthetic || field.isStatic) return@mapNotNullTo null
// Ignore classes which have not been loaded.
// Assumption: all required state classes are already loaded.
val packageName = field.type.packageNameOrNull
Expand All @@ -40,11 +43,10 @@ class StatePointerSearch(val state: ContractState) {
FieldWithObject(obj, field)
}
}
addAll(fieldsWithObjects)
}

private fun handleIterable(iterable: Iterable<*>) {
iterable.forEach { obj -> handleObject(obj) }
iterable.forEach(::handleObject)
}

private fun handleMap(map: Map<*, *>) {
Expand All @@ -55,31 +57,25 @@ class StatePointerSearch(val state: ContractState) {
}

private fun handleObject(obj: Any?) {
if (obj == null) return
seenObjects.add(obj)
if (obj == null || !seenObjects.add(obj)) return
when (obj) {
is Map<*, *> -> handleMap(obj)
is StatePointer<*> -> statePointers.add(obj)
is Iterable<*> -> handleIterable(obj)
else -> {
val packageName = obj.javaClass.packageNameOrNull ?: ""
val isBlackListed = blackListedPackages.any { packageName.startsWith(it) }
if (isBlackListed.not()) fieldQueue.addAllFields(obj)
if (!isBlackListed) addAllFields(obj)
}
}
}

private fun handleField(obj: Any, field: Field) {
val newObj = field.get(obj) ?: return
if (newObj in seenObjects) return
handleObject(newObj)
}

fun search(): Set<StatePointer<*>> {
handleObject(state)
while (fieldQueue.isNotEmpty()) {
val (obj, field) = fieldQueue.pop()
field.isAccessible = true
handleField(obj, field)
handleObject(field.get(obj))
}
return statePointers
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import net.corda.core.crypto.NullKeys
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.AnonymousParty
import net.corda.core.utilities.OpaqueBytes
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import kotlin.test.assertEquals

Expand Down Expand Up @@ -32,6 +33,15 @@ class StatePointerSearchTests {
override val participants: List<AbstractParty> get() = listOf()
}

private data class StateWithStaticField(val blah: Int) : ContractState {
companion object {
@JvmStatic
val pointer = LinearPointer(UniqueIdentifier(), LinearState::class.java)
}

override val participants: List<AbstractParty> get() = listOf()
}

@Test
fun `find pointer in state with generic type`() {
val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java)
Expand Down Expand Up @@ -74,4 +84,9 @@ class StatePointerSearchTests {
assertEquals(results, setOf(linearPointer))
}

@Test
fun `ignore static fields`() {
val results = StatePointerSearch(StateWithStaticField(1)).search()
assertThat(results).isEmpty()
}
}

0 comments on commit cedb290

Please sign in to comment.