forked from corda/corda
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CORDA-683] Enable
receiveAll()
from Flows.
- Loading branch information
1 parent
20a30b3
commit 29a101c
Showing
11 changed files
with
377 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
115 changes: 115 additions & 0 deletions
115
core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
package net.corda.core.flows | ||
|
||
import co.paralleluniverse.fibers.Suspendable | ||
import net.corda.core.utilities.UntrustworthyData | ||
import net.corda.core.utilities.unwrap | ||
import net.corda.node.internal.InitiatedFlowFactory | ||
import net.corda.node.internal.StartedNode | ||
import kotlin.reflect.KClass | ||
|
||
/** | ||
* Allows to simplify writing flows that simply rend a message back to an initiating flow. | ||
*/ | ||
class Answer<out R : Any>(session: FlowSession, override val answer: R, closure: (result: R) -> Unit = {}) : SimpleAnswer<R>(session, closure) | ||
|
||
/** | ||
* Allows to simplify writing flows that simply rend a message back to an initiating flow. | ||
*/ | ||
abstract class SimpleAnswer<out R : Any>(private val session: FlowSession, private val closure: (result: R) -> Unit = {}) : FlowLogic<Unit>() { | ||
@Suspendable | ||
override fun call() { | ||
val tmp = answer | ||
closure(tmp) | ||
session.send(tmp) | ||
} | ||
|
||
protected abstract val answer: R | ||
} | ||
|
||
/** | ||
* A flow that does not do anything when triggered. | ||
*/ | ||
class NoAnswer(private val closure: () -> Unit = {}) : FlowLogic<Unit>() { | ||
@Suspendable | ||
override fun call() = closure() | ||
} | ||
|
||
/** | ||
* Allows to register a flow of type [R] against an initiating flow of type [I]. | ||
*/ | ||
inline fun <I : FlowLogic<*>, reified R : FlowLogic<*>> StartedNode<*>.registerInitiatedFlow(initiatingFlowType: KClass<I>, crossinline construct: (session: FlowSession) -> R) { | ||
internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true) | ||
} | ||
|
||
/** | ||
* Allows to register a flow of type [Answer] against an initiating flow of type [I], returning a valure of type [R]. | ||
*/ | ||
inline fun <I : FlowLogic<*>, reified R : Any> StartedNode<*>.registerAnswer(initiatingFlowType: KClass<I>, value: R) { | ||
internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> Answer(session, value) }, Answer::class.javaObjectType, true) | ||
} | ||
|
||
/** | ||
* Extracts data from a [Map[FlowSession, UntrustworthyData<Any>]] without performing checks and casting to [R]. | ||
*/ | ||
@Suppress("UNCHECKED_CAST") | ||
infix fun <R : Any> Map<FlowSession, UntrustworthyData<Any>>.from(session: FlowSession): R = this[session]!!.unwrap { it as R } | ||
|
||
/** | ||
* Creates a [Pair([session], [Class])] from this [Class]. | ||
*/ | ||
infix fun <T : Class<out Any>> T.from(session: FlowSession): Pair<FlowSession, T> = session to this | ||
|
||
/** | ||
* Creates a [Pair([session], [Class])] from this [KClass]. | ||
*/ | ||
infix fun <T : Any> KClass<T>.from(session: FlowSession): Pair<FlowSession, Class<T>> = session to this.javaObjectType | ||
|
||
/** | ||
* Suspends until a message has been received for each session in the specified [sessions]. | ||
* | ||
* Consider [receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>>] when the same type is expected from all sessions. | ||
* | ||
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly | ||
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly | ||
* corrupted data in order to exploit your code. | ||
* | ||
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them. | ||
*/ | ||
@Suspendable | ||
fun FlowLogic<*>.receiveAll(session: Pair<FlowSession, Class<out Any>>, vararg sessions: Pair<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> { | ||
val allSessions = arrayOf(session, *sessions) | ||
allSessions.enforceNoDuplicates() | ||
return receiveAll(mapOf(*allSessions)) | ||
} | ||
|
||
/** | ||
* Suspends until a message has been received for each session in the specified [sessions]. | ||
* | ||
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types. | ||
* | ||
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly | ||
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly | ||
* corrupted data in order to exploit your code. | ||
* | ||
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions]. | ||
*/ | ||
@Suspendable | ||
fun <R : Any> FlowLogic<*>.receiveAll(receiveType: Class<R>, session: FlowSession, vararg sessions: FlowSession): List<UntrustworthyData<R>> = receiveAll(receiveType, listOf(session, *sessions)) | ||
|
||
/** | ||
* Suspends until a message has been received for each session in the specified [sessions]. | ||
* | ||
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types. | ||
* | ||
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly | ||
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly | ||
* corrupted data in order to exploit your code. | ||
* | ||
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions]. | ||
*/ | ||
@Suspendable | ||
inline fun <reified R : Any> FlowLogic<*>.receiveAll(session: FlowSession, vararg sessions: FlowSession): List<UntrustworthyData<R>> = receiveAll(R::class.javaObjectType, listOf(session, *sessions)) | ||
|
||
private fun Array<out Pair<FlowSession, Class<out Any>>>.enforceNoDuplicates() { | ||
require(this.size == this.toSet().size) { "A flow session can only appear once as argument." } | ||
} |
87 changes: 87 additions & 0 deletions
87
core/src/test/kotlin/net/corda/core/flows/ReceiveAllFlowTests.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package net.corda.core.flows | ||
|
||
import co.paralleluniverse.fibers.Suspendable | ||
import net.corda.core.identity.Party | ||
import net.corda.core.utilities.UntrustworthyData | ||
import net.corda.core.utilities.getOrThrow | ||
import net.corda.core.utilities.unwrap | ||
import net.corda.testing.chooseIdentity | ||
import net.corda.testing.node.network | ||
import org.assertj.core.api.Assertions.assertThat | ||
import org.junit.Test | ||
|
||
class ReceiveMultipleFlowTests { | ||
@Test | ||
fun `receive all messages in parallel using map style`() { | ||
network(3) { nodes, _ -> | ||
val doubleValue = 5.0 | ||
nodes[1].registerAnswer(AlgorithmDefinition::class, doubleValue) | ||
val stringValue = "Thriller" | ||
nodes[2].registerAnswer(AlgorithmDefinition::class, stringValue) | ||
|
||
val flow = nodes[0].services.startFlow(ParallelAlgorithmMap(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity())) | ||
runNetwork() | ||
|
||
val result = flow.resultFuture.getOrThrow() | ||
|
||
assertThat(result).isEqualTo(doubleValue * stringValue.length) | ||
} | ||
} | ||
|
||
@Test | ||
fun `receive all messages in parallel using list style`() { | ||
network(3) { nodes, _ -> | ||
val value1 = 5.0 | ||
nodes[1].registerAnswer(ParallelAlgorithmList::class, value1) | ||
val value2 = 6.0 | ||
nodes[2].registerAnswer(ParallelAlgorithmList::class, value2) | ||
|
||
val flow = nodes[0].services.startFlow(ParallelAlgorithmList(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity())) | ||
runNetwork() | ||
val data = flow.resultFuture.getOrThrow() | ||
|
||
assertThat(data[0]).isEqualTo(value1) | ||
assertThat(data[1]).isEqualTo(value2) | ||
assertThat(data.fold(1.0) { a, b -> a * b }).isEqualTo(value1 * value2) | ||
} | ||
} | ||
|
||
class ParallelAlgorithmMap(doubleMember: Party, stringMember: Party) : AlgorithmDefinition(doubleMember, stringMember) { | ||
@Suspendable | ||
override fun askMembersForData(doubleMember: Party, stringMember: Party): Data { | ||
val doubleSession = initiateFlow(doubleMember) | ||
val stringSession = initiateFlow(stringMember) | ||
val rawData = receiveAll(Double::class from doubleSession, String::class from stringSession) | ||
return Data(rawData from doubleSession, rawData from stringSession) | ||
} | ||
} | ||
|
||
@InitiatingFlow | ||
class ParallelAlgorithmList(private val member1: Party, private val member2: Party) : FlowLogic<List<Double>>() { | ||
@Suspendable | ||
override fun call(): List<Double> { | ||
val session1 = initiateFlow(member1) | ||
val session2 = initiateFlow(member2) | ||
val data = receiveAll<Double>(session1, session2) | ||
return computeAnswer(data) | ||
} | ||
|
||
private fun computeAnswer(data: List<UntrustworthyData<Double>>): List<Double> { | ||
return data.map { element -> element.unwrap { it } } | ||
} | ||
} | ||
|
||
@InitiatingFlow | ||
abstract class AlgorithmDefinition(private val doubleMember: Party, private val stringMember: Party) : FlowLogic<Double>() { | ||
protected data class Data(val double: Double, val string: String) | ||
|
||
@Suspendable | ||
protected abstract fun askMembersForData(doubleMember: Party, stringMember: Party): Data | ||
|
||
@Suspendable | ||
override fun call(): Double { | ||
val (double, string) = askMembersForData(doubleMember, stringMember) | ||
return double * string.length | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.