Skip to content

Commit

Permalink
CORDA-1839 - Remove race condition between trackBy and notifyAll (cor…
Browse files Browse the repository at this point in the history
…da#4412)

* CORDA-1839 - Remove race condition between trackBy and notifyAll

* Fix null check

* Improve filtering

* Switch equality test to refs

* Refine filtering of seen updates

* Add entry in the changelog

* Address comments
  • Loading branch information
dimosr authored and rick-r3 committed Dec 21, 2018
1 parent 39e5dc5 commit 8ac32f5
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 29 deletions.
18 changes: 18 additions & 0 deletions core/src/test/kotlin/net/corda/core/internal/InternalUtilsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import net.corda.core.crypto.SecureHash
import org.assertj.core.api.Assertions.assertThat
import org.junit.Assert.assertArrayEquals
import org.junit.Test
import rx.subjects.PublishSubject
import java.util.*
import java.util.stream.IntStream
import java.util.stream.Stream
Expand Down Expand Up @@ -101,6 +102,23 @@ open class InternalUtilsTest {
assertThat(PrivateClass::class.java.kotlinObjectInstance).isNull()
}

@Test
fun `bufferUntilSubscribed delays emission until the first subscription`() {
val sourceSubject: PublishSubject<Int> = PublishSubject.create<Int>()
val bufferedObservable: rx.Observable<Int> = uncheckedCast(sourceSubject.bufferUntilSubscribed())

sourceSubject.onNext(1)

var itemsFromBufferedObservable = mutableSetOf<Int>()
bufferedObservable.subscribe{itemsFromBufferedObservable.add(it)}

var itemsFromNonBufferedObservable = mutableSetOf<Int>()
sourceSubject.subscribe{itemsFromNonBufferedObservable.add(it)}

assertThat(itemsFromBufferedObservable.contains(1))
assertThat(itemsFromNonBufferedObservable).doesNotContain(1)
}

@Test
fun `test SHA-256 hash for InputStream`() {
val contents = arrayOfJunk(DEFAULT_BUFFER_SIZE * 2 + DEFAULT_BUFFER_SIZE / 2)
Expand Down
2 changes: 2 additions & 0 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ release, see :doc:`upgrade-notes`.

Unreleased
----------
* Fixed race condition between ``NodeVaultService.trackBy`` and ``NodeVaultService.notifyAll``, where there could be states that were not reflected
in the data feed returned from ``trackBy`` (either in the query's result snapshot or the observable).

* TimedFlows (only used by the notary client flow) will never give up trying to reach the notary, as this would leave the states
in the notarisation request in an undefined state (unknown whether the spend has been notarised, i.e. has happened, or not). Also,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,35 +49,33 @@ class CashPaymentFlowTests {
val expectedPayment = 500.DOLLARS
val expectedChange = 1500.DOLLARS

bankOfCordaNode.transaction {
// Register for vault updates
val criteria = QueryCriteria.VaultQueryCriteria(status = Vault.StateStatus.ALL)
val (_, vaultUpdatesBoc) = bankOfCordaNode.services.vaultService.trackBy<Cash.State>(criteria)
val (_, vaultUpdatesBankClient) = aliceNode.services.vaultService.trackBy<Cash.State>(criteria)
// Register for vault updates
val criteria = QueryCriteria.VaultQueryCriteria(status = Vault.StateStatus.ALL)
val (_, vaultUpdatesBoc) = bankOfCordaNode.services.vaultService.trackBy<Cash.State>(criteria)
val (_, vaultUpdatesBankClient) = aliceNode.services.vaultService.trackBy<Cash.State>(criteria)

val future = bankOfCordaNode.startFlow(CashPaymentFlow(expectedPayment, payTo))
mockNet.runNetwork()
future.getOrThrow()
val future = bankOfCordaNode.startFlow(CashPaymentFlow(expectedPayment, payTo))
mockNet.runNetwork()
future.getOrThrow()

// Check Bank of Corda vault updates - we take in some issued cash and split it into $500 to the notary
// and $1,500 back to us, so we expect to consume one state, produce one state for our own vault
vaultUpdatesBoc.expectEvents {
expect { (consumed, produced) ->
assertThat(consumed).hasSize(1)
assertThat(produced).hasSize(1)
val changeState = produced.single().state.data
assertEquals(expectedChange.`issued by`(bankOfCorda.ref(ref)), changeState.amount)
}
// Check Bank of Corda vault updates - we take in some issued cash and split it into $500 to the notary
// and $1,500 back to us, so we expect to consume one state, produce one state for our own vault
vaultUpdatesBoc.expectEvents {
expect { (consumed, produced) ->
assertThat(consumed).hasSize(1)
assertThat(produced).hasSize(1)
val changeState = produced.single().state.data
assertEquals(expectedChange.`issued by`(bankOfCorda.ref(ref)), changeState.amount)
}
}

// Check notary node vault updates
vaultUpdatesBankClient.expectEvents {
expect { (consumed, produced) ->
assertThat(consumed).isEmpty()
assertThat(produced).hasSize(1)
val paymentState = produced.single().state.data
assertEquals(expectedPayment.`issued by`(bankOfCorda.ref(ref)), paymentState.amount)
}
// Check notary node vault updates
vaultUpdatesBankClient.expectEvents {
expect { (consumed, produced) ->
assertThat(consumed).isEmpty()
assertThat(produced).hasSize(1)
val paymentState = produced.single().state.data
assertEquals(expectedPayment.`issued by`(bankOfCorda.ref(ref)), paymentState.amount)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
import net.corda.nodeapi.internal.persistence.currentDBSession
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import org.hibernate.Session
import rx.Observable
import rx.subjects.PublishSubject
Expand Down Expand Up @@ -559,14 +560,37 @@ class NodeVaultService(
}
}

/**
* Returns a [DataFeed] containing the results of the provided query, along with the associated observable, containing any subsequent updates.
*
* Note that this method can be invoked concurrently with [NodeVaultService.notifyAll], which means there could be race conditions between reads
* performed here and writes performed there. These are prevented, using the following approach:
* - Observable updates emitted by [NodeVaultService.notifyAll] are buffered until the transaction's commit point
* This means that it's as if publication is performed, after the transaction is committed.
* - Observable updates tracked by [NodeVaultService._trackBy] are buffered before the transaction (for the provided query) is open
* and until the client's subscription. So, it's as if the customer is subscribed to the observable before the read's transaction is open.
*
* The combination of the 2 conditions described above guarantee that there can be no possible interleaving, where some states are not observed in the query
* (i.e. because read transaction opens, before write transaction is closed) and at the same time not included in the observable (i.e. because subscription
* is done before the publication of updates). However, this guarantee cannot be provided, in cases where the client invokes [VaultService.trackBy] with an open
* transaction.
*/
@Throws(VaultQueryException::class)
override fun <T : ContractState> _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): DataFeed<Vault.Page<T>, Vault.Update<T>> {
return mutex.locked {
val updates: Observable<Vault.Update<T>> = uncheckedCast(_updatesPublisher.bufferUntilSubscribed())
if (contextTransactionOrNull != null) {
log.warn("trackBy is called with an already existing, open DB transaction. As a result, there might be states missing from both the snapshot and observable, included in the returned data feed, because of race conditions.")
}
val snapshotResults = _queryBy(criteria, paging, sorting, contractStateType)
val updates: Observable<Vault.Update<T>> = uncheckedCast(_updatesPublisher.bufferUntilSubscribed()
.filter { it.containsType(contractStateType, snapshotResults.stateTypes) }
.map { filterContractStates(it, contractStateType) })
DataFeed(snapshotResults, updates)
val snapshotStatesRefs = snapshotResults.statesMetadata.map { it.ref }.toSet()
val snapshotConsumedStatesRefs = snapshotResults.statesMetadata.filter { it.consumedTime != null }
.map { it.ref }.toSet()
val filteredUpdates = updates.filter { it.containsType(contractStateType, snapshotResults.stateTypes) }
.map { filterContractStates(it, contractStateType) }
.filter { !hasBeenSeen(it, snapshotStatesRefs, snapshotConsumedStatesRefs) }

DataFeed(snapshotResults, filteredUpdates)
}
}

Expand All @@ -577,6 +601,25 @@ class NodeVaultService(
private fun <T : ContractState> filterByContractState(contractStateType: Class<out T>, stateAndRefs: Set<StateAndRef<T>>) =
stateAndRefs.filter { contractStateType.isAssignableFrom(it.state.data.javaClass) }.toSet()

/**
* Filters out updates that have been seen, aka being reflected in the query's result snapshot.
*
* An update is reflected in the snapshot, if both of the following conditions hold:
* - all the states produced by the update are included in the snapshot (regardless of whether they are consumed).
* - all the states consumed by the update are included in the snapshot, AND they are consumed.
*
* Note: An update can contain multiple transactions (with netting performed on them). As a result, some of these transactions
* can be included in the snapshot result, while some are not. In this case, since we are not capable of reverting the netting and doing
* partial exclusion, we decide to return some more updates, instead of losing them completely (not returning them either in
* the snapshot or in the observable).
*/
private fun <T: ContractState> hasBeenSeen(update: Vault.Update<T>, snapshotStatesRefs: Set<StateRef>, snapshotConsumedStatesRefs: Set<StateRef>): Boolean {
val updateProducedStatesRefs = update.produced.map { it.ref }.toSet()
val updateConsumedStatesRefs = update.consumed.map { it.ref }.toSet()

return snapshotStatesRefs.containsAll(updateProducedStatesRefs) && snapshotConsumedStatesRefs.containsAll(updateConsumedStatesRefs)
}

private fun getSession() = database.currentOrNew().session
/**
* Derive list from existing vault states and then incrementally update using vault observables
Expand Down

0 comments on commit 8ac32f5

Please sign in to comment.