Skip to content

Commit

Permalink
Remove a couple allocations from Socket.Connect/Bind/etc. (dotnet#32271)
Browse files Browse the repository at this point in the history
The Socket implementation calls the internal SnapshotAndSerialize method on a bunch of code paths, like Connect, Bind, etc.  The "snapshot" part of the name comes from the time of CAS, and the implementation needed to clone the instance to make security decisions.  We no longer make such decisions, but we're still cloning the objects.  We can stop doing that.
  • Loading branch information
stephentoub authored Feb 14, 2020
1 parent 0d08b02 commit 8f546c9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ public static EndPoint Create(this EndPoint thisObj, Internals.SocketAddress soc
return thisObj.Create(address);
}

internal static IPEndPoint Snapshot(this IPEndPoint thisObj)
{
return new IPEndPoint(thisObj.Address.Snapshot(), thisObj.Port);
}

private static Internals.SocketAddress GetInternalSocketAddress(System.Net.SocketAddress address)
{
var result = new Internals.SocketAddress(address.Family, address.Size);
Expand Down
74 changes: 29 additions & 45 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -668,11 +668,8 @@ public void Bind(EndPoint localEP)

if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"localEP:{localEP}");

// Ask the EndPoint to generate a SocketAddress that we can pass down to native code.
EndPoint endPointSnapshot = localEP;
Internals.SocketAddress socketAddress = SnapshotAndSerialize(ref endPointSnapshot);

DoBind(endPointSnapshot, socketAddress);
Internals.SocketAddress socketAddress = Serialize(ref localEP);
DoBind(localEP, socketAddress);

if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
}
Expand Down Expand Up @@ -766,16 +763,15 @@ public void Connect(EndPoint remoteEP)

ValidateForMultiConnect(isMultiEndpoint: false);

EndPoint endPointSnapshot = remoteEP;
Internals.SocketAddress socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
Internals.SocketAddress socketAddress = Serialize(ref remoteEP);

if (!Blocking)
{
_nonBlockingConnectRightEndPoint = endPointSnapshot;
_nonBlockingConnectRightEndPoint = remoteEP;
_nonBlockingConnectInProgress = true;
}

DoConnect(endPointSnapshot, socketAddress);
DoConnect(remoteEP, socketAddress);
}

public void Connect(IPAddress address, int port)
Expand Down Expand Up @@ -1274,8 +1270,7 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags,
ValidateBlockingMode();
if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"SRC:{LocalEndPoint} size:{size} remoteEP:{remoteEP}");

EndPoint endPointSnapshot = remoteEP;
Internals.SocketAddress socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
Internals.SocketAddress socketAddress = Serialize(ref remoteEP);

int bytesTransferred;
SocketError errorCode = SocketPal.SendTo(_handle, buffer, offset, size, socketFlags, socketAddress.Buffer, socketAddress.Size, out bytesTransferred);
Expand All @@ -1291,7 +1286,7 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags,
if (_rightEndPoint == null)
{
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint = endPointSnapshot;
_rightEndPoint = remoteEP;
}

if (NetEventSource.IsEnabled)
Expand Down Expand Up @@ -1551,7 +1546,7 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla
// WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress
// with the right address family.
EndPoint endPointSnapshot = remoteEP;
Internals.SocketAddress socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot);

// Save a copy of the original EndPoint.
Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot);
Expand Down Expand Up @@ -1633,7 +1628,7 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl
// WSARecvFrom; all that matters is that we generate a unique-to-this-call SocketAddress
// with the right address family.
EndPoint endPointSnapshot = remoteEP;
Internals.SocketAddress socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot);
Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot);

int bytesTransferred;
Expand Down Expand Up @@ -2669,15 +2664,14 @@ public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags
throw new ArgumentOutOfRangeException(nameof(size));
}

EndPoint endPointSnapshot = remoteEP;
Internals.SocketAddress socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
Internals.SocketAddress socketAddress = Serialize(ref remoteEP);

// Set up the async result and indicate to flow the context.
OverlappedAsyncResult asyncResult = new OverlappedAsyncResult(this, state, callback);
asyncResult.StartPostingAsyncOp(false);

// Post the send.
DoBeginSendTo(buffer, offset, size, socketFlags, endPointSnapshot, socketAddress, asyncResult);
DoBeginSendTo(buffer, offset, size, socketFlags, remoteEP, socketAddress, asyncResult);

// Finish, possibly posting the callback. The callback won't be posted before this point is reached.
asyncResult.FinishPostingAsyncOp(ref Caches.SendClosureCache);
Expand Down Expand Up @@ -3081,7 +3075,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size,
// We don't do a CAS demand here because the contents of remoteEP aren't used by
// WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress
// with the right address family
Internals.SocketAddress socketAddress = SnapshotAndSerialize(ref remoteEP);
Internals.SocketAddress socketAddress = Serialize(ref remoteEP);

// Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to
// avoid a Socket leak in case of error.
Expand Down Expand Up @@ -3181,7 +3175,7 @@ public int EndReceiveMessageFrom(IAsyncResult asyncResult, ref SocketFlags socke
throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndReceiveMessageFrom"));
}

Internals.SocketAddress socketAddressOriginal = SnapshotAndSerialize(ref endPoint);
Internals.SocketAddress socketAddressOriginal = Serialize(ref endPoint);

int bytesTransferred = castedAsyncResult.InternalWaitForCompletionInt32Result();
castedAsyncResult.EndCalled = true;
Expand Down Expand Up @@ -3278,7 +3272,7 @@ public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, Socket
// We don't do a CAS demand here because the contents of remoteEP aren't used by
// WSARecvFrom; all that matters is that we generate a unique-to-this-call SocketAddress
// with the right address family
Internals.SocketAddress socketAddress = SnapshotAndSerialize(ref remoteEP);
Internals.SocketAddress socketAddress = Serialize(ref remoteEP);

// Set up the result and set it to collect the context.
var asyncResult = new OriginalAddressOverlappedAsyncResult(this, state, callback);
Expand Down Expand Up @@ -3390,7 +3384,7 @@ public int EndReceiveFrom(IAsyncResult asyncResult, ref EndPoint endPoint)
throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndReceiveFrom"));
}

Internals.SocketAddress socketAddressOriginal = SnapshotAndSerialize(ref endPoint);
Internals.SocketAddress socketAddressOriginal = Serialize(ref endPoint);

int bytesTransferred = castedAsyncResult.InternalWaitForCompletionInt32Result();
castedAsyncResult.EndCalled = true;
Expand Down Expand Up @@ -3742,7 +3736,7 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket)
throw new NotSupportedException(SR.net_invalidversion);
}

e._socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
e._socketAddress = Serialize(ref endPointSnapshot);

WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily);

Expand Down Expand Up @@ -3935,7 +3929,7 @@ public bool ReceiveFromAsync(SocketAsyncEventArgs e)
// WSARecvFrom; all that matters is that we generate a unique-to-this-call SocketAddress
// with the right address family.
EndPoint endPointSnapshot = e.RemoteEndPoint;
e._socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
e._socketAddress = Serialize(ref endPointSnapshot);

// DualMode sockets may have updated the endPointSnapshot, and it has to have the same AddressFamily as
// e.m_SocketAddres for Create to work later.
Expand Down Expand Up @@ -3985,7 +3979,7 @@ public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e)
// WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress
// with the right address family.
EndPoint endPointSnapshot = e.RemoteEndPoint;
e._socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
e._socketAddress = Serialize(ref endPointSnapshot);

// DualMode may have updated the endPointSnapshot, and it has to have the same AddressFamily as
// e.m_SocketAddres for Create to work later.
Expand Down Expand Up @@ -4099,7 +4093,7 @@ public bool SendToAsync(SocketAsyncEventArgs e)

// Prepare SocketAddress
EndPoint endPointSnapshot = e.RemoteEndPoint;
e._socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
e._socketAddress = Serialize(ref endPointSnapshot);

// Prepare for and make the native call.
e.StartOperationCommon(this, SocketAsyncOperation.SendTo);
Expand Down Expand Up @@ -4182,15 +4176,16 @@ internal static int GetAddressSize(EndPoint endPoint)
endPoint.Serialize().Size;
}

private Internals.SocketAddress SnapshotAndSerialize(ref EndPoint remoteEP)
private Internals.SocketAddress Serialize(ref EndPoint remoteEP)
{
if (remoteEP is IPEndPoint ipSnapshot)
if (remoteEP is IPEndPoint ip)
{
// Snapshot to avoid external tampering and malicious derivations if IPEndPoint.
ipSnapshot = ipSnapshot.Snapshot();

// DualMode: return an IPEndPoint mapped to an IPv6 address.
remoteEP = RemapIPEndPoint(ipSnapshot);
IPAddress addr = ip.Address;
if (addr.AddressFamily == AddressFamily.InterNetwork && IsDualMode)
{
addr = addr.MapToIPv6(); // For DualMode, use an IPv6 address.
remoteEP = new IPEndPoint(addr, ip.Port);
}
}
else if (remoteEP is DnsEndPoint)
{
Expand All @@ -4200,17 +4195,6 @@ private Internals.SocketAddress SnapshotAndSerialize(ref EndPoint remoteEP)
return IPEndPointExtensions.Serialize(remoteEP);
}


// DualMode: automatically re-map IPv4 addresses to IPv6 addresses.
private IPEndPoint RemapIPEndPoint(IPEndPoint input)
{
if (input.AddressFamily == AddressFamily.InterNetwork && IsDualMode)
{
return new IPEndPoint(input.Address.MapToIPv6(), input.Port);
}
return input;
}

internal static void InitializeSockets()
{
if (!s_initialized)
Expand Down Expand Up @@ -4638,7 +4622,7 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa
if (NetEventSource.IsEnabled) NetEventSource.Enter(this);

EndPoint endPointSnapshot = remoteEP;
Internals.SocketAddress socketAddress = SnapshotAndSerialize(ref endPointSnapshot);
Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot);

WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily);

Expand Down Expand Up @@ -4807,7 +4791,7 @@ private static object PostOneBeginConnect(MultipleAddressConnectAsyncResult cont
{
EndPoint endPoint = new IPEndPoint(currentAddressSnapshot, context._port);

context._socket.SnapshotAndSerialize(ref endPoint);
context._socket.Serialize(ref endPoint);

IAsyncResult connectResult = context._socket.UnsafeBeginConnect(endPoint, CachedMultipleAddressConnectCallback, context);
if (connectResult.CompletedSynchronously)
Expand Down

0 comments on commit 8f546c9

Please sign in to comment.