Skip to content

Commit

Permalink
Performance! No double lookups in GetOrAdd, faster Enumerator etc.
Browse files Browse the repository at this point in the history
Also more unit tests
  • Loading branch information
alex-jitbit committed Nov 4, 2024
1 parent a4b8631 commit d5cae59
Show file tree
Hide file tree
Showing 3 changed files with 348 additions and 44 deletions.
91 changes: 68 additions & 23 deletions FastCache/FastCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public void EvictExpired()

foreach (var p in _dict)
{
if (currTime > p.Value.TickCountWhenToKill) //instead of calling "p.Value.IsExpired" we're essentially doing the same thing manually
if (p.Value.IsExpired(currTime)) //call IsExpired with "currTime" to avoid calling Environment.TickCount64 multiple times
_dict.TryRemove(p);
}
}
Expand Down Expand Up @@ -181,14 +181,23 @@ public bool TryAdd(TKey key, TValue value, TimeSpan ttl)
/// <param name="ttl">TTL of the item</param>
public TValue GetOrAdd(TKey key, Func<TKey, TValue> valueFactory, TimeSpan ttl)
{
if (TryGet(key, out var value))
return value;

return _dict.GetOrAdd(
bool wasAdded = false; //flag to indicate "add vs get". TODO: wrap in ref type some day to avoid captures/closures
var ttlValue = _dict.GetOrAdd(
key,
(k, arg) => new TtlValue(arg.valueFactory(k), arg.ttl),
(ttl, valueFactory)
).Value;
(k) =>
{
wasAdded = true;
return new TtlValue(valueFactory(k), ttl);
});

//if the item is expired, update value and TTL
//since TtlValue is a reference type we can update its properties in-place, instead of removing and re-adding to the dictionary (extra lookups)
if (!wasAdded) //performance hack: skip expiration check if a brand item was just added
{
ttlValue.ModifyIfExpired(() => valueFactory(key), ttl);
}

return ttlValue.Value;
}

/// <summary>
Expand All @@ -200,14 +209,23 @@ public TValue GetOrAdd(TKey key, Func<TKey, TValue> valueFactory, TimeSpan ttl)
/// <param name="factoryArgument">Argument value to pass into valueFactory</param>
public TValue GetOrAdd<TArg>(TKey key, Func<TKey, TArg, TValue> valueFactory, TimeSpan ttl, TArg factoryArgument)
{
if (TryGet(key, out var value))
return value;

return _dict.GetOrAdd(
bool wasAdded = false; //flag to indicate "add vs get"
var ttlValue = _dict.GetOrAdd(
key,
(k, arg) => new TtlValue(arg.valueFactory(k, arg.factoryArgument), arg.ttl),
(ttl, valueFactory, factoryArgument)
).Value;
(k) =>
{
wasAdded = true;
return new TtlValue(valueFactory(k, factoryArgument), ttl);
});

//if the item is expired, update value and TTL
//since TtlValue is a reference type we can update its properties in-place, instead of removing and re-adding to the dictionary (extra lookups)
if (!wasAdded) //performance hack: skip expiration check if a brand item was just added
{
ttlValue.ModifyIfExpired(() => valueFactory(key, factoryArgument), ttl);
}

return ttlValue.Value;
}

/// <summary>
Expand All @@ -218,10 +236,22 @@ public TValue GetOrAdd<TArg>(TKey key, Func<TKey, TArg, TValue> valueFactory, Ti
/// <param name="ttl">TTL of the item</param>
public TValue GetOrAdd(TKey key, TValue value, TimeSpan ttl)
{
if (TryGet(key, out var existingValue))
return existingValue;
bool wasAdded = false; //flag to indicate "add vs get"
var ttlValue = _dict.GetOrAdd(key,
(k) =>
{
wasAdded = true;
return new TtlValue(value, ttl);
});

//if the item is expired, update value and TTL
//since TtlValue is a reference type we can update its properties in-place, instead of removing and re-adding to the dictionary (extra lookups)
if (!wasAdded) //performance hack: skip expiration check if a brand item was just added
{
ttlValue.ModifyIfExpired(() => value, ttl);
}

return _dict.GetOrAdd(key, new TtlValue(value, ttl)).Value;
return ttlValue.Value;
}

/// <summary>
Expand All @@ -245,11 +275,13 @@ public bool TryRemove(TKey key, out TValue value)
return res;
}

/// <inheritdoc/>
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
{
var currTime = Environment.TickCount64; //save to a var to prevent multiple calls to Environment.TickCount64
foreach (var kvp in _dict)
{
if (!kvp.Value.IsExpired())
if (!kvp.Value.IsExpired(currTime))
yield return new KeyValuePair<TKey, TValue>(kvp.Key, kvp.Value.Value);
}
}
Expand All @@ -261,18 +293,31 @@ IEnumerator IEnumerable.GetEnumerator()

private class TtlValue
{
public readonly TValue Value;
public readonly long TickCountWhenToKill;
public TValue Value { get; private set; }
private long TickCountWhenToKill;

public TtlValue(TValue value, TimeSpan ttl)
{
Value = value;
TickCountWhenToKill = Environment.TickCount64 + (long)ttl.TotalMilliseconds;
}

public bool IsExpired()
public bool IsExpired() => IsExpired(Environment.TickCount64);

//use an overload instead of optional param to avoid extra IF's
public bool IsExpired(long currTime) => currTime > TickCountWhenToKill;

/// <summary>
/// Updates the value and TTL only if the item is expired
/// </summary>
public void ModifyIfExpired(Func<TValue> newValueFactory, TimeSpan newTtl)
{
return Environment.TickCount64 > TickCountWhenToKill;
var ticks = Environment.TickCount64; //save to a var to prevent multiple calls to Environment.TickCount64
if (IsExpired(ticks)) //if expired - update the value and TTL
{
TickCountWhenToKill = ticks + (long)newTtl.TotalMilliseconds; //update the expiration time first for better concurrency
Value = newValueFactory();
}
}
}

Expand Down
72 changes: 51 additions & 21 deletions UnitTests/UnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class UnitTests
[TestMethod]
public async Task TestGetSetCleanup()
{
var _cache = new FastCache<int, int>(cleanupJobInterval: 200);
using var _cache = new FastCache<int, int>(cleanupJobInterval: 200); //add "using" to stop cleanup timer, to prevent cleanup job from clashing with other tests
_cache.AddOrUpdate(42, 42, TimeSpan.FromMilliseconds(100));
Assert.IsTrue(_cache.TryGet(42, out int v));
Assert.IsTrue(v == 42);
Expand All @@ -22,24 +22,24 @@ public async Task TestGetSetCleanup()
public async Task TestEviction()
{
var list = new List<FastCache<int, int>>();
for (int i = 0; i < 20; i++)
{
var cache = new FastCache<int, int>(cleanupJobInterval: 200);
cache.AddOrUpdate(42, 42, TimeSpan.FromMilliseconds(100));
list.Add(cache);
for (int i = 0; i < 20; i++)
{
var cache = new FastCache<int, int>(cleanupJobInterval: 200);
cache.AddOrUpdate(42, 42, TimeSpan.FromMilliseconds(100));
list.Add(cache);
}
await Task.Delay(300);

for (int i = 0; i < 20; i++)
{
Assert.IsTrue(list[i].Count == 0); //cleanup job has run?
}

//cleanup
for (int i = 0; i < 20; i++)
{
list[i].Dispose();
}
for (int i = 0; i < 20; i++)
{
Assert.IsTrue(list[i].Count == 0); //cleanup job has run?
}

//cleanup
for (int i = 0; i < 20; i++)
{
list[i].Dispose();
}
}

[TestMethod]
Expand Down Expand Up @@ -80,9 +80,9 @@ public void TestTryRemove()
cache.AddOrUpdate("42", 42, TimeSpan.FromMilliseconds(100));
var res = cache.TryRemove("42", out int value);
Assert.IsTrue(res && value == 42);
Assert.IsFalse(cache.TryGet("42", out _));

//now try remove non-existing item
Assert.IsFalse(cache.TryGet("42", out _));

//now try remove non-existing item
res = cache.TryRemove("blabblah", out value);
Assert.IsFalse(res);
Assert.IsTrue(value == 0);
Expand Down Expand Up @@ -117,10 +117,33 @@ public async Task TestGetOrAdd()
{
var cache = new FastCache<string, int>();
cache.GetOrAdd("key", k => 1024, TimeSpan.FromMilliseconds(100));
Assert.IsTrue(cache.TryGet("key", out int res) && res == 1024);
Assert.AreEqual(cache.GetOrAdd("key", k => 1025, TimeSpan.FromMilliseconds(100)), 1024); //old value
Assert.IsTrue(cache.TryGet("key", out int res) && res == 1024); //another way to retrieve
await Task.Delay(110);

Assert.IsFalse(cache.TryGet("key", out _));
Assert.IsFalse(cache.TryGet("key", out _)); //expired

//now try non-factory overloads
Assert.IsTrue(cache.GetOrAdd("key123", 123321, TimeSpan.FromMilliseconds(100)) == 123321);
Assert.IsTrue(cache.GetOrAdd("key123", -1, TimeSpan.FromMilliseconds(100)) == 123321); //still old value
await Task.Delay(110);
Assert.IsTrue(cache.GetOrAdd("key123", -1, TimeSpan.FromMilliseconds(100)) == -1); //new value
}


[TestMethod]
public async Task TestGetOrAddExpiration()
{
var cache = new FastCache<string, int>();
cache.GetOrAdd("key", k => 1024, TimeSpan.FromMilliseconds(100));

Assert.AreEqual(cache.GetOrAdd("key", k => 1025, TimeSpan.FromMilliseconds(100)), 1024); //old value
Assert.IsTrue(cache.TryGet("key", out int res) && res == 1024); //another way to retrieve

await Task.Delay(110); //let the item expire

Assert.AreEqual(cache.GetOrAdd("key", k => 1025, TimeSpan.FromMilliseconds(100)), 1025); //new value
Assert.IsTrue(cache.TryGet("key", out res) && res == 1025); //another way to retrieve
}

[TestMethod]
Expand All @@ -133,6 +156,12 @@ public async Task TestGetOrAddWithArg()
//eviction
await Task.Delay(110);
Assert.IsFalse(cache.TryGet("key", out _));

//now try without "TryGet"
Assert.IsTrue(cache.GetOrAdd("key2", (k, arg) => 21 + arg.Length, TimeSpan.FromMilliseconds(100), "123") == 24);
Assert.IsTrue(cache.GetOrAdd("key2", (k, arg) => 2211 + arg.Length, TimeSpan.FromMilliseconds(100), "123") == 24);
await Task.Delay(110);
Assert.IsTrue(cache.GetOrAdd("key2", (k, arg) => 2211 + arg.Length, TimeSpan.FromMilliseconds(100), "123") == 2214);
}

[TestMethod]
Expand Down Expand Up @@ -164,6 +193,7 @@ await TestHelper.RunConcurrently(20, () => {
Assert.IsTrue(i == 1, i.ToString());
}

//this text can occasionally fail becasue factory is not guaranteed to be called only once. only panic if it fails ALL THE TIME
[TestMethod]
public async Task TestGetOrAddAtomicNess()
{
Expand Down
Loading

0 comments on commit d5cae59

Please sign in to comment.