Merge pull request kubevirt#11663 from iholder101/generic_cache/follow-up-better-locking

[time-defined cache] [follow-up]: Use a regular Mutex over RWMutex to lock the entire Get() function
kubevirt-bot authored Apr 10, 2024
2 parents 3f46669 + 17eb529 commit a455114
Showing 3 changed files with 39 additions and 46 deletions.
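
In short: the previous Get() read savedValueSet and lastRefresh before taking any lock and took only a read lock on the fast path, so the freshness check and the recalculation could interleave across goroutines. Holding one plain sync.Mutex for the entire Get() turns the freshness check, the recalculation, and the store into a single critical section, which is what this follow-up implements. A minimal standalone sketch of the resulting pattern (not the KubeVirt code itself, simplified to always lock):

package cache

import (
    "sync"
    "time"
)

type cachedValue[T any] struct {
    mu          sync.Mutex
    value       T
    lastRefresh *time.Time
    minRefresh  time.Duration
    recalc      func() (T, error)
}

func (c *cachedValue[T]) Get() (T, error) {
    c.mu.Lock()
    defer c.mu.Unlock()

    // Freshness is checked under the same lock that guards recalculation,
    // so only the first of N concurrent callers runs recalc().
    if c.lastRefresh != nil && time.Since(*c.lastRefresh) <= c.minRefresh {
        return c.value, nil
    }

    v, err := c.recalc()
    if err != nil {
        return c.value, err
    }
    now := time.Now()
    c.value, c.lastRefresh = v, &now
    return c.value, nil
}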
1 change: 1 addition & 0 deletions tools/cache/BUILD.bazel
@@ -5,6 +5,7 @@ go_library(
     srcs = ["time-defined-cache.go"],
     importpath = "kubevirt.io/kubevirt/tools/cache",
     visibility = ["//visibility:public"],
+    deps = ["//pkg/pointer:go_default_library"],
 )
 
 go_test(
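
The new Bazel dependency exists for the k6tpointer.P(time.Now()) call in setWithoutLock below: Go forbids taking the address of a function call like &time.Now(), so a generic pointer helper is used instead. Presumably kubevirt.io/kubevirt/pkg/pointer amounts to something like this sketch:

package pointer

// P returns a pointer to its argument; assumed shape of the helper
// imported below under the alias k6tpointer.
func P[T any](t T) *T {
    return &t
}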
35 changes: 17 additions & 18 deletions tools/cache/time-defined-cache.go
@@ -23,15 +23,16 @@ import (
 	"fmt"
 	"sync"
 	"time"
+
+	k6tpointer "kubevirt.io/kubevirt/pkg/pointer"
 )
 
 type TimeDefinedCache[T any] struct {
 	minRefreshDuration time.Duration
-	lastRefresh        time.Time
-	savedValueSet      bool
-	savedValue         T
+	lastRefresh        *time.Time
+	value              T
 	reCalcFunc         func() (T, error)
-	valueLock          *sync.RWMutex
+	valueLock          *sync.Mutex
 }
 
 // NewTimeDefinedCache creates a new cache that will refresh the value every minRefreshDuration. If the value is requested
@@ -51,34 +52,30 @@ func NewTimeDefinedCache[T any](minRefreshDuration time.Duration, useValueLock b
 	}
 
 	if useValueLock {
-		t.valueLock = &sync.RWMutex{}
+		t.valueLock = &sync.Mutex{}
 	}
 
 	return t, nil
 }
 
 func (t *TimeDefinedCache[T]) Get() (T, error) {
-	if t.savedValueSet && t.minRefreshDuration.Nanoseconds() != 0 && time.Since(t.lastRefresh) <= t.minRefreshDuration {
-		if t.valueLock != nil {
-			t.valueLock.RLock()
-			defer t.valueLock.RUnlock()
-		}
-		return t.savedValue, nil
-	}
-
 	if t.valueLock != nil {
 		t.valueLock.Lock()
 		defer t.valueLock.Unlock()
 	}
 
+	if t.lastRefresh != nil && t.minRefreshDuration.Nanoseconds() != 0 && time.Since(*t.lastRefresh) <= t.minRefreshDuration {
+		return t.value, nil
+	}
+
 	value, err := t.reCalcFunc()
 	if err != nil {
-		return t.savedValue, err
+		return t.value, err
 	}
 
 	t.setWithoutLock(value)
 
-	return t.savedValue, nil
+	return t.value, nil
 }
 
 func (t *TimeDefinedCache[T]) Set(value T) {
@@ -91,7 +88,9 @@ func (t *TimeDefinedCache[T]) Set(value T) {
 }
 
 func (t *TimeDefinedCache[T]) setWithoutLock(value T) {
-	t.savedValue = value
-	t.savedValueSet = true
-	t.lastRefresh = time.Now()
+	t.value = value
+
+	if t.lastRefresh == nil || t.minRefreshDuration.Nanoseconds() != 0 {
+		t.lastRefresh = k6tpointer.P(time.Now())
+	}
 }
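
For context, a hedged usage sketch of the API after this change, based only on the signatures visible in the diff (the virtcache alias matches the test file's import; the one-minute duration and the lookup function are made up for illustration):

func exampleCachedLookup() (int, error) {
    // Hypothetical expensive operation; stands in for whatever reCalcFunc does.
    expensiveLookup := func() (int, error) { return 42, nil }

    // useValueLock=true installs the sync.Mutex that now guards all of Get().
    cache, err := virtcache.NewTimeDefinedCache(time.Minute, true, expensiveLookup)
    if err != nil {
        return 0, err
    }

    v, err := cache.Get() // first call runs expensiveLookup and caches the result
    if err != nil {
        return 0, err
    }
    _, _ = cache.Get() // within the minute: served from the cache, no recalculation
    return v, nil
}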
49 changes: 21 additions & 28 deletions tools/cache/time-defined-cache_test.go
@@ -20,6 +20,7 @@
 package cache_test
 
 import (
+	"fmt"
 	"sync/atomic"
 	"time"
@@ -71,49 +72,41 @@ var _ = Describe("time defined cache", func() {
 		Expect(err).To(HaveOccurred())
 	})
 
-	It("should not allow two threads to set value in parallel", func() {
-		stopChannel := make(chan struct{})
-		defer close(stopChannel)
-		firstCallMadeChannel := make(chan struct{})
-
-		recalcFunctionCalls := int64(0)
+	It("when multiple go routines get a value only one is recalculating and others get the same cached value", func() {
+		firstCallBarrier := make(chan struct{})
+		recalcFunctionCallsCount := int64(0)
 
 		recalcFunc := func() (int, error) {
-			firstCallMadeChannel <- struct{}{}
-			atomic.AddInt64(&recalcFunctionCalls, 1)
-
-			ticker := time.NewTicker(1 * time.Second)
-			defer ticker.Stop()
-
-			select {
-			case <-ticker.C:
-			case <-stopChannel:
-				break
-			}
-
-			return 1, nil
+			time.Sleep(100 * time.Millisecond)
+			atomic.AddInt64(&recalcFunctionCallsCount, 1)
+			firstCallBarrier <- struct{}{}
+			return int(recalcFunctionCallsCount), nil
 		}
 
-		cache, err := virtcache.NewTimeDefinedCache(0, true, recalcFunc)
+		cache, err := virtcache.NewTimeDefinedCache(time.Hour, true, recalcFunc)
 		Expect(err).ToNot(HaveOccurred())
 
+		const goroutineCount = 20
+		getReturnValues := make(chan int, goroutineCount)
 		getValueFromCache := func() {
 			defer GinkgoRecover()
-			_, err = cache.Get()
+			ret, err := cache.Get()
 			Expect(err).ShouldNot(HaveOccurred())
+			getReturnValues <- ret
 		}
 
-		for i := 0; i < 5; i++ {
+		for i := 0; i < goroutineCount; i++ {
 			go getValueFromCache()
 		}
 
-		// To ensure the first call is already made
-		<-firstCallMadeChannel
-
-		Consistently(func() {
-			Expect(recalcFunctionCalls).To(Equal(int64(1)), "value is being re-calculated, only one caller is expected")
-		}).WithPolling(250 * time.Millisecond).WithTimeout(1 * time.Second)
+		Consistently(getReturnValues).Should(BeEmpty(), "all go routines should wait for the first one to finish")
+		Eventually(firstCallBarrier).Should(Receive(), "first go routine should start re-calculating")
+		Eventually(getReturnValues).Should(HaveLen(goroutineCount), fmt.Sprintf("expected all go routines to finish calling Get(). %d/%d finished", len(getReturnValues), goroutineCount))
+
+		close(getReturnValues)
+		for getValue := range getReturnValues {
+			Expect(getValue).To(Equal(1), "Get() calls are expected to return the cached value")
+		}
+		Expect(firstCallBarrier).To(BeEmpty(), "ensure no other go routine called the re-calculation function")
 	})
 
 })
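
The rewritten test leans on a blocking-barrier idiom: the single goroutine that wins the mutex parks inside recalcFunc on an unbuffered channel, which lets the test prove every other Get() is still waiting before the first one is allowed to finish. A condensed, Ginkgo-free sketch of the same idiom (all names here are illustrative, not from the commit):

package main

import (
    "fmt"
    "sync"
)

func main() {
    var (
        mu      sync.Mutex
        cached  *int
        barrier = make(chan struct{})
        results = make(chan int, 20)
    )
    get := func() int {
        mu.Lock()
        defer mu.Unlock()
        if cached == nil {
            barrier <- struct{}{} // parks here until main receives
            v := 1
            cached = &v
        }
        return *cached
    }
    for i := 0; i < 20; i++ {
        go func() { results <- get() }() // all 20 call get() concurrently
    }
    <-barrier // at this point every other goroutine is still blocked on mu
    for i := 0; i < 20; i++ {
        fmt.Println(<-results) // prints 1 twenty times: exactly one recalculation
    }
}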
