Skip to content

Commit

Permalink
use SimpleMemCache instead UdpNat. Fix UDP leak!
Browse files Browse the repository at this point in the history
  • Loading branch information
trudyhood committed Dec 22, 2021
1 parent 736b973 commit 422760b
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* Fix: Limit thread count that prevent server from responding in high load
* Fix: Memory leak! Some dead sessions remain in memory
* Fix: Memory leak! TcpProxy remains in memory when just one peer has gone
* Fix: Memory leak! UdpProxy remains in memory
* Fix: Unusual Thread creating
* Fix: UDP Packet loss

Expand Down
43 changes: 14 additions & 29 deletions VpnHood.Tunneling/ProxyManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ public abstract class ProxyManager : IDisposable
private bool _disposed;
private readonly HashSet<IChannel> _channels = new();
private readonly PingProxyPool _pingProxyPool = new();
private readonly Nat _udpNat;
private readonly SimpleMemCache<string, UdpProxy> _udpProxies = new(true);

public int MaxUdpPortCount { get; set; } = 0;

// override Handle UdpProxy.OnPacketReceived
Expand All @@ -44,18 +45,16 @@ public override Task OnPacketReceived(IPPacket ipPacket)

protected ProxyManager()
{
_udpNat = new Nat(false);
_udpNat.OnNatItemRemoved += Nat_OnNatItemRemoved;
}

public void Cleanup()
{
_udpNat.Cleanup();
_udpProxies.Cleanup();
}

public TimeSpan UdpTimeout { get => _udpNat.UdpTimeout; set => _udpNat.UdpTimeout = value; }
public TimeSpan? UdpTimeout { get => _udpProxies.Timeout; set => _udpProxies.Timeout = value; }

public int UdpConnectionCount => _udpNat.ItemCount;
public int UdpConnectionCount => _udpProxies.Count;

// ReSharper disable once UnusedMember.Global
public int TcpConnectionCount
Expand All @@ -71,14 +70,6 @@ public int TcpConnectionCount
protected abstract Task OnPacketReceived(IPPacket ipPacket);
protected abstract bool IsPingSupported { get; }

private void Nat_OnNatItemRemoved(object sender, NatEventArgs e)
{
if (e.NatItem.Tag is UdpProxy udpProxy)
udpProxy.Dispose();
else
VhLogger.Instance.LogWarning($"@Error: oops! no udpProxy on tag"); //todo
}

public virtual void SendPacket(IPPacket[] ipPackets)
{
foreach (var ipPacket in ipPackets)
Expand Down Expand Up @@ -146,24 +137,19 @@ private void SendUdpPacket(IPPacket ipPacket)
return;

// send packet via proxy
var natItem = _udpNat.Get(ipPacket);
if (natItem?.Tag is not UdpProxy udpProxy || udpProxy.IsDisposed)
var udpPacket = PacketUtil.ExtractUdp(ipPacket);
var udpKey = $"{ipPacket.SourceAddress}:{udpPacket.SourcePort}";
if (!_udpProxies.TryGetValue(udpKey, out var udpProxy) || udpProxy.IsDisposed)
{
var udpCount = _udpNat.ItemCount;
if (MaxUdpPortCount != 0 && udpCount > MaxUdpPortCount)
if (MaxUdpPortCount != 0 && _udpProxies.Count > MaxUdpPortCount)
{
VhLogger.Instance.LogWarning(GeneralEventId.Udp, $"Too many UDP port! Killing the oldest UdpProxy. {nameof(MaxUdpPortCount)}: {MaxUdpPortCount}");
_udpNat.RemoveOldest(ProtocolType.Udp);
VhLogger.Instance.LogWarning(GeneralEventId.Udp, $"Too many UDP ports! Killing the oldest UdpProxy. {nameof(MaxUdpPortCount)}: {MaxUdpPortCount}");
_udpProxies.RemoveOldest();
}

var udpPacket = PacketUtil.ExtractUdp(ipPacket);
udpProxy = new MyUdpProxy(this, CreateUdpClient(ipPacket.SourceAddress.AddressFamily), new IPEndPoint(ipPacket.SourceAddress, udpPacket.SourcePort));
try
{
natItem = _udpNat.Add(ipPacket, (ushort)udpProxy.LocalPort, true);
natItem.Tag = udpProxy;
}
catch { udpProxy.Dispose(); throw; }
if (!_udpProxies.TryAdd(udpKey, udpProxy, true))
udpProxy.Dispose();
}

udpProxy.Send(ipPacket);
Expand Down Expand Up @@ -191,8 +177,7 @@ public void Dispose()
if (_disposed) return;
_disposed = true;

_udpNat.Dispose();
_udpNat.OnNatItemRemoved -= Nat_OnNatItemRemoved; //must be after Nat.dispose
_udpProxies.Dispose();

// dispose channels
IChannel[] channels;
Expand Down
154 changes: 154 additions & 0 deletions VpnHood.Tunneling/SimpleMemCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
using System;
using System.Collections.Concurrent;
using System.Linq;

namespace VpnHood.Tunneling
{
public class SimpleMemCache<TKey, TValue> : IDisposable
{
private readonly ConcurrentDictionary<TKey, SimpleItem<TValue>> _items = new();
private readonly bool _autoDisposeItem;
private DateTime _lastCleanup = DateTime.MinValue;
private bool _disposed;

public TimeSpan? Timeout { get; set; }

public SimpleMemCache(bool autoDisposeItem)
{
_autoDisposeItem = autoDisposeItem;
}

public int Count
{
get
{
Cleanup();
return _items.Count;
}
}

public bool TryGetValue(TKey key, out TValue value)
{
Cleanup();

// return false if not exists
if (!_items.TryGetValue(key, out var itemValue))
{
value = default!;
return false;
}

// return fakse if expired
if (IsExpired(itemValue))
{
value = default!;
TryRemove(key, out _);
return false;
}

// return item
itemValue.AccessedTime = DateTime.Now;
value = itemValue.Value;
return false;
}

public bool TryAdd(TKey key, TValue value, bool overwride)
{
Cleanup();

// return true if added
if (_items.TryAdd(key, new SimpleItem<TValue>(value)))
return true;

// remove and rety if overwrite is on
if (overwride)
{
TryRemove(key, out _);
return _items.TryAdd(key, new SimpleItem<TValue>(value));
}

// remove & retry of item has been expired
if (_items.TryGetValue(key, out var itemValue) && IsExpired(itemValue))
{
TryRemove(key, out _);
return _items.TryAdd(key, new SimpleItem<TValue>(value));
}

// couldn't add
return false;
}

public bool TryRemove(TKey key, out TValue value)
{
// try add
var ret = _items.TryRemove(key, out var itemValue);
if (ret && _autoDisposeItem)
((IDisposable)itemValue).Dispose();

value = itemValue.Value;

return ret;
}

private bool IsExpired(SimpleItem<TValue> item)
{
return Timeout != null && DateTime.Now - item.AccessedTime > Timeout;
}

private class SimpleItem<T>
{
public DateTime AccessedTime { get; set; }
public T Value { get; set; }
public SimpleItem(T value)
{
Value = value;
}
}

public void Cleanup(bool force = false)
{
// do nothing if there is not timeout
if (Timeout == null)
return;

// return if already checked
if (!force && DateTime.Now - _lastCleanup > Timeout / 3)
return;
_lastCleanup = DateTime.Now;

// remove timeout items
foreach (var item in _items.Where(x => IsExpired(x.Value)))
TryRemove(item.Key, out _);
}

public void Dispose()
{
if (_disposed) return;
_disposed = true;

if (_autoDisposeItem)
{
foreach (var itemValue in _items.Values)
((IDisposable)itemValue.Value!)?.Dispose();
}
_items.Clear();
}

public void RemoveOldest()
{
var oldestAccessedTime = DateTime.MaxValue;
var oldestKey = default(TKey?);
foreach (var item in _items)
{
if (oldestAccessedTime < item.Value.AccessedTime)
{
oldestAccessedTime = item.Value.AccessedTime;
oldestKey = item.Key;
}
}

if (oldestKey != null)
TryRemove(oldestKey, out _);
}
}
}
9 changes: 7 additions & 2 deletions VpnHood.ZTest/Tests/NatTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,18 @@ public void Nat_NatItem_Test()
[TestMethod]
public void Nat_NatItemEx_Test()
{
var nat = new Nat(true);

var ipPacket = PacketUtil.CreateIpPacket(IPAddress.Parse("10.1.1.1"), IPAddress.Parse("10.1.1.2"));
var tcpPacket = new TcpPacket(100, 100);
ipPacket.PayloadPacket = tcpPacket;

var nat = new Nat(true);
var id = nat.Add(ipPacket).NatId;

var ipPacket2 = PacketUtil.CreateIpPacket(IPAddress.Parse("10.1.1.1"), IPAddress.Parse("10.1.1.2"));
var tcpPacket2 = new TcpPacket(101, 100);
ipPacket2.PayloadPacket = tcpPacket2;
nat.Add(ipPacket2);

// un-map
var natItem = (NatItemEx?) nat.Resolve(ipPacket.Version, ProtocolType.Tcp, id);
Assert.IsNotNull(natItem);
Expand Down

0 comments on commit 422760b

Please sign in to comment.