CORDA-1217 Replace Guava caches with Caffeine (corda#2818)
rick-r3 authored Mar 14, 2018
1 parent 8f750c0 commit a24a210
Showing 21 changed files with 112 additions and 117 deletions.
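Most hunks below follow the same mechanical pattern. As a hedged orientation sketch (not part of the commit): Caffeine mirrors Guava's builder API closely, so most call sites change only their imports and the entry point.

    import com.github.benmanes.caffeine.cache.Caffeine
    import java.util.concurrent.TimeUnit

    fun main() {
        // Same builder shape as Guava's CacheBuilder, different entry point.
        val cache = Caffeine.newBuilder()
                .maximumSize(1024)
                .expireAfterAccess(30, TimeUnit.SECONDS)
                .build<String, String>()
        cache.put("key", "value")
        println(cache.getIfPresent("key"))  // prints: value
    }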
1 change: 1 addition & 0 deletions build.gradle
@@ -46,6 +46,7 @@ buildscript {
     ext.log4j_version = '2.9.1'
     ext.bouncycastle_version = constants.getProperty("bouncycastleVersion")
     ext.guava_version = constants.getProperty("guavaVersion")
+    ext.caffeine_version = constants.getProperty("caffeineVersion")
     ext.okhttp_version = '3.5.0'
     ext.netty_version = '4.1.9.Final'
     ext.typesafe_config_version = constants.getProperty("typesafeConfigVersion")
@@ -1,7 +1,6 @@
 package net.corda.client.jfx.model
 
-import com.google.common.cache.CacheBuilder
-import com.google.common.cache.CacheLoader
+import com.github.benmanes.caffeine.cache.Caffeine
 import javafx.beans.value.ObservableValue
 import javafx.collections.FXCollections
 import javafx.collections.ObservableList
@@ -32,8 +31,8 @@ class NetworkIdentityModel {
 
     private val rpcProxy by observableValue(NodeMonitorModel::proxyObservable)
 
-    private val identityCache = CacheBuilder.newBuilder()
-            .build<PublicKey, ObservableValue<NodeInfo?>>(CacheLoader.from { publicKey ->
+    private val identityCache = Caffeine.newBuilder()
+            .build<PublicKey, ObservableValue<NodeInfo?>>({ publicKey ->
                 publicKey?.let { rpcProxy.map { it?.nodeInfoFromParty(AnonymousParty(publicKey)) } }
             })
     val notaries = ChosenList(rpcProxy.map { FXCollections.observableList(it?.notaryIdentities() ?: emptyList()) })
@@ -42,5 +41,5 @@ class NetworkIdentityModel {
             .filtered { it.legalIdentities.all { it !in notaries } }
     val myIdentity = rpcProxy.map { it?.nodeInfo()?.legalIdentitiesAndCerts?.first()?.party }
 
-    fun partyFromPublicKey(publicKey: PublicKey): ObservableValue<NodeInfo?> = identityCache[publicKey]
+    fun partyFromPublicKey(publicKey: PublicKey): ObservableValue<NodeInfo?> = identityCache[publicKey]!!
 }
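The new `!!` is not incidental: Caffeine's `LoadingCache.get` is annotated `@Nullable` (a loader may legitimately compute null), so Kotlin sees the result as nullable where Guava's `get` guaranteed non-null or threw. A minimal illustration, not taken from the commit:

    import com.github.benmanes.caffeine.cache.Caffeine

    fun main() {
        val lengths = Caffeine.newBuilder().build<String, Int> { key -> key.length }
        // Indexed access maps to get(key), whose result is nullable in Kotlin,
        // hence the !! at loading-cache call sites throughout this commit.
        val n: Int = lengths["corda"]!!
        println(n)  // prints: 5
    }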
3 changes: 3 additions & 0 deletions client/rpc/build.gradle
@@ -68,6 +68,9 @@ dependencies {
     compile project(':core')
     compile project(':node-api')
 
+    // For caches rather than guava
+    compile "com.github.ben-manes.caffeine:caffeine:$caffeine_version"
+
     // Unit testing helpers.
     testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version"
     testCompile "junit:junit:$junit_version"
@@ -44,12 +44,6 @@ data class RPCClientConfiguration(
         val reapInterval: Duration,
         /** The number of threads to use for observations (for executing [Observable.onNext]) */
         val observationExecutorPoolSize: Int,
-        /**
-         * Determines the concurrency level of the Observable Cache. This is exposed because it implicitly determines
-         * the limit on the number of leaked observables reaped because of garbage collection per reaping.
-         * See the implementation of [com.google.common.cache.LocalCache] for details.
-         */
-        val cacheConcurrencyLevel: Int,
         /** The retry interval of artemis connections in milliseconds */
         val connectionRetryInterval: Duration,
         /** The retry interval multiplier for exponential backoff */
@@ -71,7 +65,6 @@
                 trackRpcCallSites = false,
                 reapInterval = 1.seconds,
                 observationExecutorPoolSize = 4,
-                cacheConcurrencyLevel = 8,
                 connectionRetryInterval = 5.seconds,
                 connectionRetryIntervalMultiplier = 1.5,
                 connectionMaxRetryInterval = 3.minutes,
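The deleted `cacheConcurrencyLevel` knob has no Caffeine counterpart: Guava's `concurrencyLevel(n)` sized its internal segmented hash table, whereas Caffeine is built on ConcurrentHashMap and tunes itself. A hedged before/after sketch, assuming both libraries on the classpath:

    import com.github.benmanes.caffeine.cache.Caffeine
    import com.google.common.cache.CacheBuilder

    fun main() {
        // Guava: concurrency had to be tuned by hand.
        val guavaCache = CacheBuilder.newBuilder()
                .concurrencyLevel(8)  // Guava-only knob, no Caffeine equivalent
                .weakValues()
                .build<Long, Any>()
        // Caffeine: the same cache, minus the knob.
        val caffeineCache = Caffeine.newBuilder()
                .weakValues()
                .build<Long, Any>()
        guavaCache.put(1L, "a")
        caffeineCache.put(1L, "a")
    }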
@@ -1,13 +1,14 @@
 package net.corda.client.rpc.internal
 
+import co.paralleluniverse.common.util.SameThreadExecutor
 import com.esotericsoftware.kryo.Kryo
 import com.esotericsoftware.kryo.Serializer
 import com.esotericsoftware.kryo.io.Input
 import com.esotericsoftware.kryo.io.Output
-import com.google.common.cache.Cache
-import com.google.common.cache.CacheBuilder
-import com.google.common.cache.RemovalCause
-import com.google.common.cache.RemovalListener
+import com.github.benmanes.caffeine.cache.Cache
+import com.github.benmanes.caffeine.cache.Caffeine
+import com.github.benmanes.caffeine.cache.RemovalCause
+import com.github.benmanes.caffeine.cache.RemovalListener
 import com.google.common.util.concurrent.SettableFuture
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 import net.corda.client.rpc.RPCException
@@ -142,10 +143,10 @@ class RPCClientProxyHandler(
     private val serializationContextWithObservableContext = RpcClientObservableSerializer.createContext(serializationContext, observableContext)
 
     private fun createRpcObservableMap(): RpcObservableMap {
-        val onObservableRemove = RemovalListener<InvocationId, UnicastSubject<Notification<*>>> {
-            val observableId = it.key!!
+        val onObservableRemove = RemovalListener<InvocationId, UnicastSubject<Notification<*>>> { key, value, cause ->
+            val observableId = key!!
             val rpcCallSite = callSiteMap?.remove(observableId)
-            if (it.cause == RemovalCause.COLLECTED) {
+            if (cause == RemovalCause.COLLECTED) {
                 log.warn(listOf(
                         "A hot observable returned from an RPC was never subscribed to.",
                         "This wastes server-side resources because it was queueing observations for retrieval.",
@@ -156,10 +157,9 @@
             }
             observablesToReap.locked { observables.add(observableId) }
         }
-        return CacheBuilder.newBuilder().
+        return Caffeine.newBuilder().
                 weakValues().
-                removalListener(onObservableRemove).
-                concurrencyLevel(rpcConfiguration.cacheConcurrencyLevel).
+                removalListener(onObservableRemove).executor(SameThreadExecutor.getExecutor()).
                 build()
     }
 
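Two behavioural differences show up in this hunk. Caffeine's `RemovalListener` is a `(key, value, cause)` functional interface rather than Guava's single `RemovalNotification` argument, and its notifications run on an `Executor` (defaulting to `ForkJoinPool.commonPool()`) instead of synchronously on the caller's thread, which is why the commit pins `executor(SameThreadExecutor.getExecutor())`. A minimal sketch of the same idea, substituting a plain same-thread executor for Quasar's:

    import com.github.benmanes.caffeine.cache.Caffeine
    import com.github.benmanes.caffeine.cache.RemovalCause
    import java.util.concurrent.Executor

    fun main() {
        val sameThread = Executor { it.run() }  // stand-in for SameThreadExecutor
        val cache = Caffeine.newBuilder()
                .executor(sameThread)  // keep Guava's synchronous listener semantics
                .removalListener<String, String> { key, value, cause: RemovalCause ->
                    println("removed $key=$value because $cause")
                }
                .build<String, String>()
        cache.put("a", "1")
        cache.invalidate("a")  // listener fires on this thread, cause = EXPLICIT
    }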
@@ -91,8 +91,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
         return testProxy<TestOps>(
                 TestOpsImpl(pool),
                 clientConfiguration = RPCClientConfiguration.default.copy(
-                        reapInterval = 100.millis,
-                        cacheConcurrencyLevel = 16
+                        reapInterval = 100.millis
                 ),
                 serverConfiguration = RPCServerConfiguration.default.copy(
                         rpcThreadPoolSize = 4
@@ -87,7 +87,6 @@ class RPCPerformanceTests : AbstractRPCTest() {
         rpcDriver {
             val proxy = testProxy(
                     RPCClientConfiguration.default.copy(
-                            cacheConcurrencyLevel = 16,
                             observationExecutorPoolSize = 2
                     ),
                     RPCServerConfiguration.default.copy(
@@ -127,8 +126,7 @@
             val metricRegistry = startReporter(shutdownManager)
             val proxy = testProxy(
                     RPCClientConfiguration.default.copy(
-                            reapInterval = 1.seconds,
-                            cacheConcurrencyLevel = 16
+                            reapInterval = 1.seconds
                     ),
                     RPCServerConfiguration.default.copy(
                             rpcThreadPoolSize = 8
1 change: 1 addition & 0 deletions constants.properties
@@ -7,3 +7,4 @@ typesafeConfigVersion=1.3.1
 jsr305Version=3.0.2
 artifactoryPluginVersion=4.4.18
 snakeYamlVersion=1.19
+caffeineVersion=2.6.2
3 changes: 3 additions & 0 deletions node-api/build.gradle
@@ -40,6 +40,9 @@ dependencies {
     // Pure-Java Snappy compression
     compile 'org.iq80.snappy:snappy:0.4'
 
+    // For caches rather than guava
+    compile "com.github.ben-manes.caffeine:caffeine:$caffeine_version"
+
     // Unit testing helpers.
     testCompile "junit:junit:$junit_version"
     testCompile "org.assertj:assertj-core:$assertj_version"
@@ -1,7 +1,7 @@
 package net.corda.nodeapi.internal
 
-import com.google.common.cache.CacheBuilder
-import com.google.common.cache.CacheLoader
+import com.github.benmanes.caffeine.cache.CacheLoader
+import com.github.benmanes.caffeine.cache.Caffeine
 import java.time.Duration
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicLong
@@ -11,11 +11,11 @@ import java.util.concurrent.atomic.AtomicLong
  */
 class DeduplicationChecker(cacheExpiry: Duration) {
     // dedupe identity -> watermark cache
-    private val watermarkCache = CacheBuilder.newBuilder()
+    private val watermarkCache = Caffeine.newBuilder()
             .expireAfterAccess(cacheExpiry.toNanos(), TimeUnit.NANOSECONDS)
             .build(WatermarkCacheLoader)
 
-    private object WatermarkCacheLoader : CacheLoader<Any, AtomicLong>() {
+    private object WatermarkCacheLoader : CacheLoader<Any, AtomicLong> {
        override fun load(key: Any) = AtomicLong(-1)
     }
 
@@ -25,6 +25,7 @@ class DeduplicationChecker(cacheExpiry: Duration) {
      * @return true if the message is unique, false if it's a duplicate.
      */
     fun checkDuplicateMessageId(identity: Any, sequenceNumber: Long): Boolean {
-        return watermarkCache[identity].getAndUpdate { maxOf(sequenceNumber, it) } >= sequenceNumber
+        return watermarkCache[identity]!!.getAndUpdate { maxOf(sequenceNumber, it) } >= sequenceNumber
     }
 }

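The dropped `()` after `CacheLoader<Any, AtomicLong>` is the point of this hunk: Guava's `CacheLoader` is an abstract class whose constructor must be invoked in the supertype list, while Caffeine's is an interface. A side-by-side sketch with hypothetical loader names, fully qualified to keep the two types apart:

    // Guava: abstract class, so the supertype list calls a constructor.
    private object GuavaWatermarkLoader : com.google.common.cache.CacheLoader<Any, Long>() {
        override fun load(key: Any) = -1L
    }

    // Caffeine: plain interface, no constructor call.
    private object CaffeineWatermarkLoader : com.github.benmanes.caffeine.cache.CacheLoader<Any, Long> {
        override fun load(key: Any) = -1L
    }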
@@ -1,7 +1,7 @@
 package net.corda.nodeapi.internal.serialization
 
-import com.google.common.cache.Cache
-import com.google.common.cache.CacheBuilder
+import com.github.benmanes.caffeine.cache.Cache
+import com.github.benmanes.caffeine.cache.Caffeine
 import net.corda.core.contracts.Attachment
 import net.corda.core.crypto.SecureHash
 import net.corda.core.internal.copyBytes
@@ -30,7 +30,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe
                                                               override val useCase: SerializationContext.UseCase,
                                                               override val encoding: SerializationEncoding?,
                                                               override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : SerializationContext {
-    private val cache: Cache<List<SecureHash>, AttachmentsClassLoader> = CacheBuilder.newBuilder().weakValues().maximumSize(1024).build()
+    private val cache: Cache<List<SecureHash>, AttachmentsClassLoader> = Caffeine.newBuilder().weakValues().maximumSize(1024).build()
 
     /**
      * {@inheritDoc}
@@ -49,7 +49,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe
             }
             missing.isNotEmpty() && throw MissingAttachmentsException(missing)
             AttachmentsClassLoader(attachments, parent = deserializationClassLoader)
-        })
+        }!!)
     } catch (e: ExecutionException) {
         // Caught from within the cache get, so unwrap.
         throw e.cause!!
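Same nullability story as before: Caffeine's two-argument `get(key) { ... }` computes and caches the value atomically on a miss, but its result is nullable to Kotlin because the mapping function may produce null, hence the `}!!)`. A small sketch, not from the commit:

    import com.github.benmanes.caffeine.cache.Caffeine

    fun main() {
        val cache = Caffeine.newBuilder().weakValues().maximumSize(1024).build<String, List<Int>>()
        // Computed at most once per key while the value stays cached; the
        // nullable result is why the diff above appends !! to the get(...) call.
        val codePoints = cache.get("corda") { key -> key.map { it.toInt() } }!!
        println(codePoints)
    }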
3 changes: 3 additions & 0 deletions node/build.gradle
@@ -79,6 +79,9 @@ dependencies {
 
     compile "com.google.guava:guava:$guava_version"
 
+    // For caches rather than guava
+    compile "com.github.ben-manes.caffeine:caffeine:$caffeine_version"
+
     // JOpt: for command line flags.
     compile "net.sf.jopt-simple:jopt-simple:$jopt_simple_version"
 
@@ -1,7 +1,8 @@
 package net.corda.node.internal.security
 
-import com.google.common.cache.CacheBuilder
-import com.google.common.cache.Cache
+
+import com.github.benmanes.caffeine.cache.Cache
+import com.github.benmanes.caffeine.cache.Caffeine
 import com.google.common.primitives.Ints
 import net.corda.core.context.AuthServiceId
 import net.corda.core.utilities.loggerFor
@@ -94,7 +95,7 @@ class RPCSecurityManagerImpl(config: AuthServiceConfig) : RPCSecurityManager {
         return DefaultSecurityManager(realm).also {
             // Setup optional cache layer if configured
             it.cacheManager = config.options?.cache?.let {
-                GuavaCacheManager(
+                CaffeineCacheManager(
                         timeToLiveSeconds = it.expireAfterSecs,
                         maxSize = it.maxEntries)
             }
@@ -257,9 +258,9 @@ private class NodeJdbcRealm(config: SecurityConfiguration.AuthService.DataSource
 private typealias ShiroCache<K, V> = org.apache.shiro.cache.Cache<K, V>
 
 /*
- * Adapts a [com.google.common.cache.Cache] to a [org.apache.shiro.cache.Cache] implementation.
+ * Adapts a [com.github.benmanes.caffeine.cache.Cache] to a [org.apache.shiro.cache.Cache] implementation.
  */
-private fun <K, V> Cache<K, V>.toShiroCache(name: String) = object : ShiroCache<K, V> {
+private fun <K : Any, V> Cache<K, V>.toShiroCache(name: String) = object : ShiroCache<K, V> {
 
     val name = name
     private val impl = this@toShiroCache
@@ -282,37 +283,37 @@ private fun <K, V> Cache<K, V>.toShiroCache(name: String) = object : ShiroCache<
         impl.invalidateAll()
     }
 
-    override fun size() = Ints.checkedCast(impl.size())
+    override fun size() = Ints.checkedCast(impl.estimatedSize())
     override fun keys() = impl.asMap().keys
     override fun values() = impl.asMap().values
     override fun toString() = "Guava cache adapter [$impl]"
 }
 
 /*
  * Implementation of [org.apache.shiro.cache.CacheManager] based on
- * cache implementation in [com.google.common.cache]
+ * cache implementation in [com.github.benmanes.caffeine.cache.Cache]
  */
-private class GuavaCacheManager(val maxSize: Long,
-                                val timeToLiveSeconds: Long) : CacheManager {
+private class CaffeineCacheManager(val maxSize: Long,
+                                   val timeToLiveSeconds: Long) : CacheManager {
 
     private val instances = ConcurrentHashMap<String, ShiroCache<*, *>>()
 
-    override fun <K, V> getCache(name: String): ShiroCache<K, V> {
+    override fun <K : Any, V> getCache(name: String): ShiroCache<K, V> {
         val result = instances[name] ?: buildCache<K, V>(name)
         instances.putIfAbsent(name, result)
         return result as ShiroCache<K, V>
     }
 
-    private fun <K, V> buildCache(name: String) : ShiroCache<K, V> {
+    private fun <K : Any, V> buildCache(name: String): ShiroCache<K, V> {
         logger.info("Constructing cache '$name' with maximumSize=$maxSize, TTL=${timeToLiveSeconds}s")
-        return CacheBuilder.newBuilder()
+        return Caffeine.newBuilder()
                 .expireAfterWrite(timeToLiveSeconds, TimeUnit.SECONDS)
                 .maximumSize(maxSize)
                 .build<K, V>()
                 .toShiroCache(name)
     }
 
     companion object {
-        private val logger = loggerFor<GuavaCacheManager>()
+        private val logger = loggerFor<CaffeineCacheManager>()
     }
 }
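The `size()` to `estimatedSize()` rename reflects a real semantic difference: Caffeine defers eviction bookkeeping, so the entry count is only an estimate until pending maintenance runs, which `cleanUp()` can force. An illustrative sketch, not from the commit:

    import com.github.benmanes.caffeine.cache.Caffeine

    fun main() {
        val cache = Caffeine.newBuilder().maximumSize(5).build<Int, Int>()
        (1..10).forEach { cache.put(it, it) }
        println(cache.estimatedSize())  // may still report more than 5 here
        cache.cleanUp()                 // run pending maintenance (evictions)
        println(cache.estimatedSize())  // now at most 5
    }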
@@ -1,12 +1,13 @@
 package net.corda.node.services.messaging
 
+import co.paralleluniverse.common.util.SameThreadExecutor
 import com.esotericsoftware.kryo.Kryo
 import com.esotericsoftware.kryo.Serializer
 import com.esotericsoftware.kryo.io.Input
 import com.esotericsoftware.kryo.io.Output
-import com.google.common.cache.Cache
-import com.google.common.cache.CacheBuilder
-import com.google.common.cache.RemovalListener
+import com.github.benmanes.caffeine.cache.Cache
+import com.github.benmanes.caffeine.cache.Caffeine
+import com.github.benmanes.caffeine.cache.RemovalListener
 import com.google.common.collect.HashMultimap
 import com.google.common.collect.Multimaps
 import com.google.common.collect.SetMultimap
@@ -145,11 +146,11 @@ class RPCServer(
     }
 
     private fun createObservableSubscriptionMap(): ObservableSubscriptionMap {
-        val onObservableRemove = RemovalListener<InvocationId, ObservableSubscription> {
-            log.debug { "Unsubscribing from Observable with id ${it.key} because of ${it.cause}" }
-            it.value.subscription.unsubscribe()
+        val onObservableRemove = RemovalListener<InvocationId, ObservableSubscription> { key, value, cause ->
+            log.debug { "Unsubscribing from Observable with id ${key} because of ${cause}" }
+            value!!.subscription.unsubscribe()
         }
-        return CacheBuilder.newBuilder().removalListener(onObservableRemove).build()
+        return Caffeine.newBuilder().removalListener(onObservableRemove).executor(SameThreadExecutor.getExecutor()).build()
     }
 
     fun start(activeMqServerControl: ActiveMQServerControl) {
@@ -164,9 +164,9 @@ open class PersistentNetworkMapCache(
 
     override fun getNodesByLegalName(name: CordaX500Name): List<NodeInfo> = database.transaction { queryByLegalName(session, name) }
 
-    override fun getNodesByLegalIdentityKey(identityKey: PublicKey): List<NodeInfo> = nodesByKeyCache[identityKey]
+    override fun getNodesByLegalIdentityKey(identityKey: PublicKey): List<NodeInfo> = nodesByKeyCache[identityKey]!!
 
-    private val nodesByKeyCache = NonInvalidatingCache<PublicKey, List<NodeInfo>>(1024, 8, { key -> database.transaction { queryByIdentityKey(session, key) } })
+    private val nodesByKeyCache = NonInvalidatingCache<PublicKey, List<NodeInfo>>(1024, { key -> database.transaction { queryByIdentityKey(session, key) } })
 
     override fun getNodesByOwningKeyIndex(identityKeyIndex: String): List<NodeInfo> {
         return database.transaction {
@@ -176,9 +176,9 @@
 
     override fun getNodeByAddress(address: NetworkHostAndPort): NodeInfo? = database.transaction { queryByAddress(session, address) }
 
-    override fun getPeerCertificateByLegalName(name: CordaX500Name): PartyAndCertificate? = identityByLegalNameCache.get(name).orElse(null)
+    override fun getPeerCertificateByLegalName(name: CordaX500Name): PartyAndCertificate? = identityByLegalNameCache.get(name)!!.orElse(null)
 
-    private val identityByLegalNameCache = NonInvalidatingCache<CordaX500Name, Optional<PartyAndCertificate>>(1024, 8, { name -> Optional.ofNullable(database.transaction { queryIdentityByLegalName(session, name) }) })
+    private val identityByLegalNameCache = NonInvalidatingCache<CordaX500Name, Optional<PartyAndCertificate>>(1024, { name -> Optional.ofNullable(database.transaction { queryIdentityByLegalName(session, name) }) })
 
     override fun track(): DataFeed<List<NodeInfo>, MapChange> {
         synchronized(_changed) {