Skip to content

Commit

Permalink
CORDA-806 Remove initialiseSerialization from rpcDriver (corda#2084)
Browse files Browse the repository at this point in the history
and fix a leak or two
andr3ej authored Nov 29, 2017
1 parent 2525fb5 commit 3c31fdf
Showing 12 changed files with 117 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -12,11 +12,14 @@ import net.corda.core.serialization.serialize
import net.corda.core.utilities.*
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.RPCApi
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.driver.poll
import net.corda.testing.internal.*
import org.apache.activemq.artemis.api.core.SimpleString
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Rule
import org.junit.Test
import rx.Observable
import rx.subjects.PublishSubject
@@ -26,6 +29,14 @@ import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger

class RPCStabilityTests {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule(true)
private val pool = Executors.newFixedThreadPool(10, testThreadFactory())
@After
fun shutdown() {
pool.shutdown()
}

object DummyOps : RPCOps {
override val protocolVersion = 0
@@ -197,9 +208,9 @@ class RPCStabilityTests {
val proxy = startRpcClient<LeakObservableOps>(server.get().broker.hostAndPort!!).get()
// Leak many observables
val N = 200
(1..N).toList().parallelStream().forEach {
proxy.leakObservable()
}
(1..N).map {
pool.fork { proxy.leakObservable(); Unit }
}.transpose().getOrThrow()
// In a loop force GC and check whether the server is notified
while (true) {
System.gc()
@@ -231,7 +242,7 @@ class RPCStabilityTests {
assertEquals("pong", client.ping())
serverFollower.shutdown()
startRpcServer<ReconnectOps>(ops = ops, customPort = serverPort).getOrThrow()
val pingFuture = ForkJoinPool.commonPool().fork(client::ping)
val pingFuture = pool.fork(client::ping)
assertEquals("pong", pingFuture.getOrThrow(10.seconds))
clientFollower.shutdown() // Driver would do this after the new server, causing hang.
}
Original file line number Diff line number Diff line change
@@ -6,14 +6,20 @@ import net.corda.core.internal.concurrent.map
import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.User
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.internal.RPCDriverExposedDSLInterface
import net.corda.testing.internal.rpcTestUser
import net.corda.testing.internal.startInVmRpcClient
import net.corda.testing.internal.startRpcClient
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.junit.Rule
import org.junit.runners.Parameterized

open class AbstractRPCTest {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule(true)

enum class RPCTestMode {
InVm,
Netty
Original file line number Diff line number Diff line change
@@ -5,19 +5,22 @@ import net.corda.core.messaging.RPCOps
import net.corda.core.utilities.millis
import net.corda.core.crypto.random63BitValue
import net.corda.core.internal.concurrent.fork
import net.corda.core.internal.concurrent.transpose
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.getOrThrow
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.testing.internal.RPCDriverExposedDSLInterface
import net.corda.testing.internal.rpcDriver
import net.corda.testing.internal.testThreadFactory
import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet
import org.junit.After
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import rx.Observable
import rx.subjects.UnicastSubject
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ForkJoinPool
import java.util.concurrent.*

@RunWith(Parameterized::class)
class RPCConcurrencyTests : AbstractRPCTest() {
@@ -36,7 +39,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
}

class TestOpsImpl : TestOps {
class TestOpsImpl(private val pool: Executor) : TestOps {
private val latches = ConcurrentHashMap<Long, CountDownLatch>()
override val protocolVersion = 0

@@ -68,24 +71,22 @@ class RPCConcurrencyTests : AbstractRPCTest() {
val branches = if (depth == 0) {
Observable.empty<ObservableRose<Int>>()
} else {
val publish = UnicastSubject.create<ObservableRose<Int>>()
ForkJoinPool.commonPool().fork {
(1..branchingFactor).toList().parallelStream().forEach {
publish.onNext(getParallelObservableTree(depth - 1, branchingFactor))
UnicastSubject.create<ObservableRose<Int>>().also { publish ->
(1..branchingFactor).map {
pool.fork { publish.onNext(getParallelObservableTree(depth - 1, branchingFactor)) }
}.transpose().then {
it.getOrThrow()
publish.onCompleted()
}
publish.onCompleted()
}
publish
}
return ObservableRose(depth, branches)
}
}

private lateinit var testOpsImpl: TestOpsImpl
private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy<TestOps> {
testOpsImpl = TestOpsImpl()
return testProxy<TestOps>(
testOpsImpl,
TestOpsImpl(pool),
clientConfiguration = RPCClientConfiguration.default.copy(
reapInterval = 100.millis,
cacheConcurrencyLevel = 16
@@ -96,26 +97,30 @@ class RPCConcurrencyTests : AbstractRPCTest() {
)
}

private val pool = Executors.newFixedThreadPool(10, testThreadFactory())
@After
fun shutdown() {
pool.shutdown()
}

@Test
fun `call multiple RPCs in parallel`() {
rpcDriver {
val proxy = testProxy()
val numberOfBlockedCalls = 2
val numberOfDownsRequired = 100
val id = proxy.ops.newLatch(numberOfDownsRequired)
val done = CountDownLatch(numberOfBlockedCalls)
// Start a couple of blocking RPC calls
(1..numberOfBlockedCalls).forEach {
ForkJoinPool.commonPool().fork {
val done = (1..numberOfBlockedCalls).map {
pool.fork {
proxy.ops.waitLatch(id)
done.countDown()
}
}
}.transpose()
// Down the latch that the others are waiting for concurrently
(1..numberOfDownsRequired).toList().parallelStream().forEach {
proxy.ops.downLatch(id)
}
done.await()
(1..numberOfDownsRequired).map {
pool.fork { proxy.ops.downLatch(id) }
}.transpose().getOrThrow()
done.getOrThrow()
}
}

@@ -146,7 +151,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
fun ObservableRose<Int>.subscribeToAll() {
remainingLatch.countDown()
this.branches.subscribe { tree ->
(tree.value + 1..treeDepth - 1).forEach {
(tree.value + 1 until treeDepth).forEach {
require(it in depthsSeen) { "Got ${tree.value} before $it" }
}
depthsSeen.add(tree.value)
@@ -165,11 +170,11 @@ class RPCConcurrencyTests : AbstractRPCTest() {
val treeDepth = 2
val treeBranchingFactor = 10
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
val depthsSeen = Collections.synchronizedSet(HashSet<Int>())
val depthsSeen = ConcurrentHashSet<Int>()
fun ObservableRose<Int>.subscribeToAll() {
remainingLatch.countDown()
branches.subscribe { tree ->
(tree.value + 1..treeDepth - 1).forEach {
(tree.value + 1 until treeDepth).forEach {
require(it in depthsSeen) { "Got ${tree.value} before $it" }
}
depthsSeen.add(tree.value)
Original file line number Diff line number Diff line change
@@ -5,12 +5,18 @@ import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.*
import net.corda.core.utilities.getOrThrow
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.internal.rpcDriver
import net.corda.testing.internal.startRpcClient
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Rule
import org.junit.Test

class RPCFailureTests {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule(true)

class Unserializable
interface Ops : RPCOps {
fun getUnserializable(): Unserializable
9 changes: 9 additions & 0 deletions core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt
Original file line number Diff line number Diff line change
@@ -31,6 +31,8 @@ import java.time.Duration
import java.time.temporal.Temporal
import java.util.*
import java.util.Spliterator.*
import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit
import java.util.stream.IntStream
import java.util.stream.Stream
import java.util.stream.StreamSupport
@@ -307,3 +309,10 @@ fun TransactionBuilder.toLedgerTransaction(services: ServiceHub, serializationCo
val KClass<*>.packageName: String get() = java.`package`.name

fun URL.openHttpConnection(): HttpURLConnection = openConnection() as HttpURLConnection
/** Analogous to [Thread.join]. */
fun ExecutorService.join() {
shutdown() // Do not change to shutdownNow, tests use this method to assert the executor has no more tasks.
while (!awaitTermination(1, TimeUnit.SECONDS)) {
// Try forever. Do not give up, tests use this method to assert the executor has no more tasks.
}
}
Original file line number Diff line number Diff line change
@@ -133,7 +133,7 @@ class ContractUpgradeFlowTest {

@Test
fun `2 parties contract upgrade using RPC`() {
rpcDriver(initialiseSerialization = false) {
rpcDriver {
// Create dummy contract.
val twoPartyDummyContract = DummyContract.generateInitial(0, notary, alice.ref(1), bob.ref(1))
val signedByA = aliceNode.services.signInitialTransaction(twoPartyDummyContract)
Original file line number Diff line number Diff line change
@@ -14,7 +14,6 @@ import org.junit.runners.model.Statement
import org.slf4j.Logger
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
import kotlin.test.assertNull

@@ -23,10 +22,7 @@ private fun <T> withSingleThreadExecutor(callable: ExecutorService.() -> T) = Ex
fork {}.getOrThrow() // Start the thread.
callable()
} finally {
shutdown()
while (!awaitTermination(1, TimeUnit.SECONDS)) {
// Do nothing.
}
join()
}
}

Original file line number Diff line number Diff line change
@@ -2,13 +2,13 @@ package net.corda.core.internal.concurrent

import com.nhaarman.mockito_kotlin.*
import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.join
import net.corda.core.utilities.getOrThrow
import net.corda.testing.rigorousMock
import org.assertj.core.api.Assertions
import org.junit.Test
import org.slf4j.Logger
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.test.assertEquals
import kotlin.test.assertFalse
@@ -108,10 +108,7 @@ class CordaFutureTest {
val throwable = Exception("Boom")
val executor = Executors.newSingleThreadExecutor()
executor.fork { throw throwable }.andForget(log)
executor.shutdown()
while (!executor.awaitTermination(1, TimeUnit.SECONDS)) {
// Do nothing.
}
executor.join()
verify(log).error(any(), same(throwable))
}

Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ import net.corda.core.context.Trace.InvocationId
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.LazyStickyPool
import net.corda.core.internal.LifeCycle
import net.corda.core.internal.join
import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT
@@ -207,6 +208,7 @@ class RPCServer(
}

fun close() {
observationSendExecutor?.join()
reaperScheduledFuture?.cancel(false)
rpcExecutor?.shutdownNow()
reaperExecutor?.shutdownNow()
Original file line number Diff line number Diff line change
@@ -230,7 +230,6 @@ fun <A> rpcDriver(
debugPortAllocation: PortAllocation = globalDebugPortAllocation,
systemProperties: Map<String, String> = emptyMap(),
useTestClock: Boolean = false,
initialiseSerialization: Boolean = true,
startNodesInProcess: Boolean = false,
waitForNodesToFinish: Boolean = false,
extraCordappPackagesToScan: List<String> = emptyList(),
@@ -254,7 +253,7 @@ fun <A> rpcDriver(
),
coerce = { it },
dsl = dsl,
initialiseSerialization = initialiseSerialization
initialiseSerialization = false
)

private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 {
Original file line number Diff line number Diff line change
@@ -1,26 +1,50 @@
package net.corda.testing

import com.nhaarman.mockito_kotlin.doNothing
import com.nhaarman.mockito_kotlin.whenever
import com.nhaarman.mockito_kotlin.*
import net.corda.client.rpc.internal.KryoClientSerializationScheme
import net.corda.core.internal.staticField
import net.corda.core.serialization.internal.*
import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.nodeapi.internal.serialization.*
import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.testing.common.internal.asContextEnv
import net.corda.testing.internal.testThreadFactory
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnector
import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors

private val inVMExecutors = ConcurrentHashMap<SerializationEnvironment, ExecutorService>()

/** @param inheritable whether new threads inherit the environment, use sparingly. */
class SerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
companion object {
init {
// Can't turn it off, and it creates threads that do serialization, so hack it:
InVMConnector::class.staticField<ExecutorService>("threadPoolExecutor").value = rigorousMock<ExecutorService>().also {
doAnswer {
inVMExecutors.computeIfAbsent(effectiveSerializationEnv) {
Executors.newCachedThreadPool(testThreadFactory(true)) // Close enough to what InVMConnector makes normally.
}.execute(it.arguments[0] as Runnable)
}.whenever(it).execute(any())
}
}
}

lateinit var env: SerializationEnvironment
override fun apply(base: Statement, description: Description): Statement {
env = createTestSerializationEnv(description.toString())
return object : Statement() {
override fun evaluate() = env.asContextEnv(inheritable) {
base.evaluate()
override fun evaluate() {
try {
env.asContextEnv(inheritable) { base.evaluate() }
} finally {
inVMExecutors.remove(env)
}
}
}
}
@@ -59,6 +83,7 @@ fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment {
object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv("<global>") {
override fun unset() {
_globalSerializationEnv.set(null)
inVMExecutors.remove(this)
}
}.also {
_globalSerializationEnv.set(it)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package net.corda.testing.internal

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicInteger

private val familyToNextPoolNumber = ConcurrentHashMap<String, AtomicInteger>()
fun Any.testThreadFactory(useEnclosingClassName: Boolean = false): ThreadFactory {
val poolFamily = javaClass.let { (if (useEnclosingClassName) it.enclosingClass else it).simpleName }
val poolNumber = familyToNextPoolNumber.computeIfAbsent(poolFamily) { AtomicInteger(1) }.getAndIncrement()
val nextThreadNumber = AtomicInteger(1)
return ThreadFactory { task ->
Thread(task, "$poolFamily-$poolNumber-${nextThreadNumber.getAndIncrement()}")
}
}

0 comments on commit 3c31fdf

Please sign in to comment.