Skip to content

Commit

Permalink
Initial checkpoint when protocol is first added
Browse files Browse the repository at this point in the history
  • Loading branch information
shamsasari committed Jun 16, 2016
1 parent eb4c24a commit 860353c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package com.r3corda.contracts.testing

import com.r3corda.contracts.*
import com.r3corda.contracts.cash.Cash
import com.r3corda.contracts.cash.CASH_PROGRAM_ID
import com.r3corda.contracts.cash.Cash
import com.r3corda.core.contracts.Amount
import com.r3corda.core.contracts.Contract
import com.r3corda.core.contracts.DUMMY_PROGRAM_ID
import com.r3corda.core.contracts.DummyContract
import com.r3corda.core.crypto.NullPublicKey
import com.r3corda.core.crypto.Party
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.slf4j.LoggerFactory
class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberScheduler, private val loggerName: String) : Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {

// These fields shouldn't be serialised, so they are marked @Transient.
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest, fiber: ProtocolStateMachineImpl<*>) -> Unit)? = null
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest) -> Unit)? = null
@Transient private var receivedPayload: Any? = null
@Transient lateinit override var serviceHub: ServiceHubInternal
@Transient internal lateinit var actionOnEnd: () -> Unit
Expand Down Expand Up @@ -54,7 +54,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS

fun prepareForResumeWith(serviceHub: ServiceHubInternal,
receivedPayload: Any?,
suspendAction: (StateMachineManager.FiberRequest, ProtocolStateMachineImpl<*>) -> Unit) {
suspendAction: (StateMachineManager.FiberRequest) -> Unit) {
this.serviceHub = serviceHub
this.receivedPayload = receivedPayload
this.suspendAction = suspendAction
Expand Down Expand Up @@ -108,7 +108,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
private fun suspend(with: StateMachineManager.FiberRequest) {
parkAndSerialize { fiber, serializer ->
try {
suspendAction!!(with, this)
suspendAction!!(with)
} catch (t: Throwable) {
logger.warn("Captured exception which was swallowed by Quasar", t)
// TODO to throw or not to throw, that is the question
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
}
}

private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint?) {
private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
stateMachines[psm] = checkpoint
psm.actionOnEnd = {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
Expand All @@ -199,9 +199,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
try {
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null)
checkpointStorage.addCheckpoint(checkpoint)
// Need to add before iterating in case of immediate completion
// TODO: create an initial checkpoint here
initFiber(fiber, null)
initFiber(fiber, checkpoint)
executor.executeASAP {
iterateStateMachine(fiber, null) {
fiber.start()
Expand Down Expand Up @@ -233,21 +234,19 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
receivedPayload: Any?,
resumeAction: (Any?) -> Unit) {
executor.checkOnThread()
psm.prepareForResumeWith(serviceHub, receivedPayload) { request, serialisedFiber ->
psm.prepareForResumeWith(serviceHub, receivedPayload) { request ->
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
onNextSuspend(psm, request, serialisedFiber)
onNextSuspend(psm, request)
}
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
resumeAction(receivedPayload)
}

private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>,
request: FiberRequest,
fiber: ProtocolStateMachineImpl<*>) {
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) {
// We have a request to do something: send, receive, or send-and-receive.
if (request is FiberRequest.ExpectingResponse<*>) {
// Prepare a listener on the network that runs in the background thread when we receive a message.
checkpointOnExpectingResponse(psm, request, serializeFiber(fiber))
checkpointOnExpectingResponse(psm, request)
}
// If a non-null payload to send was provided, send it now.
request.payload?.let {
Expand All @@ -267,11 +266,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
}
}

private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>,
request: FiberRequest.ExpectingResponse<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>) {
private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>, request: FiberRequest.ExpectingResponse<*>) {
executor.checkOnThread()
val topic = "${request.topic}.${request.sessionIDForReceive}"
val serialisedFiber = serializeFiber(psm)
updateCheckpoint(psm, serialisedFiber, topic, request.responseType, null)
psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" }
iterateOnResponse(psm, request.responseType, serialisedFiber, topic) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import com.r3corda.node.services.persistence.DataVendingService
import com.r3corda.node.services.wallet.NodeWalletService
import java.time.Clock

class MockServices(
open class MockServices(
customWallet: WalletService? = null,
val keyManagement: KeyManagementService? = null,
val net: MessagingService? = null,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package com.r3corda.node.services.statemachine

import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.node.services.MockServices
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.network.InMemoryMessagingNetwork
import com.r3corda.node.utilities.AffinityExecutor
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Test
import java.util.*

class StateMachineManagerTests {

val checkpointStorage = RecordingCheckpointStorage()
val network = InMemoryMessagingNetwork().InMemoryMessaging(true, InMemoryMessagingNetwork.Handle(1, "mock"))
val smm = createManager()

@After
fun cleanUp() {
network.stop()
}

@Test
fun `newly added protocol is preserved on restart`() {
smm.add("mock", ProtocolWithoutCheckpoints())
// Ensure we're restoring from the original add checkpoint
assertThat(checkpointStorage.allCheckpoints).hasSize(1)
val restoredProtocol = createManager().run {
start()
findStateMachines(ProtocolWithoutCheckpoints::class.java).single().first
}
assertThat(restoredProtocol.protocolStarted).isTrue()
}

private fun createManager() = StateMachineManager(object : MockServices() {
override val networkService: MessagingService get() = network
}, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD)


private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() {

@Transient var protocolStarted = false

@Suspendable
override fun call() {
protocolStarted = true
Fiber.park()
}
}


class RecordingCheckpointStorage : CheckpointStorage {

private val _checkpoints = ArrayList<Checkpoint>()
val allCheckpoints = ArrayList<Checkpoint>()

override fun addCheckpoint(checkpoint: Checkpoint) {
_checkpoints.add(checkpoint)
allCheckpoints.add(checkpoint)
}

override fun removeCheckpoint(checkpoint: Checkpoint) {
_checkpoints.remove(checkpoint)
}

override val checkpoints: Iterable<Checkpoint> get() = _checkpoints
}

}

0 comments on commit 860353c

Please sign in to comment.