Skip to content

Commit

Permalink
Socket.Select: increase ref count while the handle is in use (dotnet/…
Browse files Browse the repository at this point in the history
…corefx#41763)

* Socket.Select: increase ref count while the handle is in use

* PR feedback

* Use Socket.InternalSafeHandle for ref/release.

* Socket.Windows: SelectFileDescriptor: add missing unrefs


Commit migrated from dotnet/corefx@6c62c7f
  • Loading branch information
tmds authored and stephentoub committed Oct 23, 2019
1 parent 2aaf172 commit 5773d1c
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ internal unsafe bool TransmitPackets(SafeSocketHandle socketHandle, IntPtr packe
return transmitPackets(socketHandle, packetArray, elementCount, sendSize, overlapped, flags);
}

internal static void SocketListToFileDescriptorSet(IList socketList, Span<IntPtr> fileDescriptorSet)
internal static void SocketListToFileDescriptorSet(IList socketList, Span<IntPtr> fileDescriptorSet, ref int refsAdded)
{
int count;
if (socketList == null || (count = socketList.Count) == 0)
Expand All @@ -166,18 +166,21 @@ internal static void SocketListToFileDescriptorSet(IList socketList, Span<IntPtr
fileDescriptorSet[0] = (IntPtr)count;
for (int current = 0; current < count; current++)
{
if (!(socketList[current] is Socket))
if (!(socketList[current] is Socket socket))
{
throw new ArgumentException(SR.Format(SR.net_sockets_select, socketList[current].GetType().FullName, typeof(System.Net.Sockets.Socket).FullName), nameof(socketList));
}

fileDescriptorSet[current + 1] = ((Socket)socketList[current])._handle.DangerousGetHandle();
bool success = false;
socket.InternalSafeHandle.DangerousAddRef(ref success);
fileDescriptorSet[current + 1] = socket.InternalSafeHandle.DangerousGetHandle();
refsAdded++;
}
}

// Transform the list socketList such that the only sockets left are those
// with a file descriptor contained in the array "fileDescriptorArray".
internal static void SelectFileDescriptor(IList socketList, Span<IntPtr> fileDescriptorSet)
internal static void SelectFileDescriptor(IList socketList, Span<IntPtr> fileDescriptorSet, ref int refsAdded)
{
// Walk the list in order.
//
Expand All @@ -195,6 +198,9 @@ internal static void SelectFileDescriptor(IList socketList, Span<IntPtr> fileDes
int returnedCount = (int)fileDescriptorSet[0];
if (returnedCount == 0)
{
// Unref safehandles.
SocketListDangerousReleaseRefs(socketList, ref refsAdded);

// No socket present, will never find any socket, remove them all.
socketList.Clear();
return;
Expand All @@ -219,6 +225,8 @@ internal static void SelectFileDescriptor(IList socketList, Span<IntPtr> fileDes
if (currentFileDescriptor == returnedCount)
{
// Descriptor not found: remove the current socket and start again.
socket.InternalSafeHandle.DangerousRelease();
refsAdded--;
socketList.RemoveAt(currentSocket--);
count--;
}
Expand Down
15 changes: 15 additions & 0 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5070,6 +5070,21 @@ private void ThrowIfDisposed()

private bool IsConnectionOriented => _socketType == SocketType.Stream;

internal static void SocketListDangerousReleaseRefs(IList socketList, ref int refsAdded)
{
if (socketList == null)
{
return;
}

for (int i = 0; (i < socketList.Count) && (refsAdded > 0); i++)
{
Socket socket = (Socket)socketList[i];
socket.InternalSafeHandle.DangerousRelease();
refsAdded--;
}
}

#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1444,37 +1444,58 @@ private static unsafe SocketError SelectViaPoll(
// Add each of the list's contents to the events array
Debug.Assert(eventsLength == checkReadInitialCount + checkWriteInitialCount + checkErrorInitialCount, "Invalid eventsLength");
int offset = 0;
AddToPollArray(events, eventsLength, checkRead, ref offset, Interop.Sys.PollEvents.POLLIN | Interop.Sys.PollEvents.POLLHUP);
AddToPollArray(events, eventsLength, checkWrite, ref offset, Interop.Sys.PollEvents.POLLOUT);
AddToPollArray(events, eventsLength, checkError, ref offset, Interop.Sys.PollEvents.POLLPRI);
Debug.Assert(offset == eventsLength, $"Invalid adds. offset={offset}, eventsLength={eventsLength}.");

// Do the poll
uint triggered = 0;
int milliseconds = microseconds == -1 ? -1 : microseconds / 1000;
Interop.Error err = Interop.Sys.Poll(events, (uint)eventsLength, milliseconds, &triggered);
if (err != Interop.Error.SUCCESS)
int refsAdded = 0;
try
{
return GetSocketErrorForErrorCode(err);
}
// In case we can't increase the reference count for each Socket,
// we'll unref refAdded Sockets in the finally block ordered: [checkRead, checkWrite, checkError].
AddToPollArray(events, eventsLength, checkRead, ref offset, Interop.Sys.PollEvents.POLLIN | Interop.Sys.PollEvents.POLLHUP, ref refsAdded);
AddToPollArray(events, eventsLength, checkWrite, ref offset, Interop.Sys.PollEvents.POLLOUT, ref refsAdded);
AddToPollArray(events, eventsLength, checkError, ref offset, Interop.Sys.PollEvents.POLLPRI, ref refsAdded);
Debug.Assert(offset == eventsLength, $"Invalid adds. offset={offset}, eventsLength={eventsLength}.");
Debug.Assert(refsAdded == eventsLength, $"Invalid ref adds. refsAdded={refsAdded}, eventsLength={eventsLength}.");

// Do the poll
uint triggered = 0;
int milliseconds = microseconds == -1 ? -1 : microseconds / 1000;
Interop.Error err = Interop.Sys.Poll(events, (uint)eventsLength, milliseconds, &triggered);
if (err != Interop.Error.SUCCESS)
{
return GetSocketErrorForErrorCode(err);
}

// Remove from the lists any entries which weren't set
if (triggered == 0)
{
checkRead?.Clear();
checkWrite?.Clear();
checkError?.Clear();
// Remove from the lists any entries which weren't set
if (triggered == 0)
{
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);

checkRead?.Clear();
checkWrite?.Clear();
checkError?.Clear();
}
else
{
FilterPollList(checkRead, events, checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLIN | Interop.Sys.PollEvents.POLLHUP, ref refsAdded);
FilterPollList(checkWrite, events, checkWriteInitialCount + checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLOUT, ref refsAdded);
FilterPollList(checkError, events, checkErrorInitialCount + checkWriteInitialCount + checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLERR | Interop.Sys.PollEvents.POLLPRI, ref refsAdded);
}

return SocketError.Success;
}
else
finally
{
FilterPollList(checkRead, events, checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLIN | Interop.Sys.PollEvents.POLLHUP);
FilterPollList(checkWrite, events, checkWriteInitialCount + checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLOUT);
FilterPollList(checkError, events, checkErrorInitialCount + checkWriteInitialCount + checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLERR | Interop.Sys.PollEvents.POLLPRI);
// This order matches with the AddToPollArray calls
// to release only the handles that were ref'd.
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);
Debug.Assert(refsAdded == 0);
}
return SocketError.Success;
}

private static unsafe void AddToPollArray(Interop.Sys.PollEvent* arr, int arrLength, IList socketList, ref int arrOffset, Interop.Sys.PollEvents events)
private static unsafe void AddToPollArray(Interop.Sys.PollEvent* arr, int arrLength, IList socketList, ref int arrOffset, Interop.Sys.PollEvents events, ref int refsAdded)
{
if (socketList == null)
return;
Expand All @@ -1494,12 +1515,15 @@ private static unsafe void AddToPollArray(Interop.Sys.PollEvent* arr, int arrLen
throw new ArgumentException(SR.Format(SR.net_sockets_select, socket?.GetType().FullName ?? "null", typeof(Socket).FullName), nameof(socketList));
}

bool success = false;
socket.InternalSafeHandle.DangerousAddRef(ref success);
int fd = (int)socket.InternalSafeHandle.DangerousGetHandle();
arr[arrOffset++] = new Interop.Sys.PollEvent { Events = events, FileDescriptor = fd };
refsAdded++;
}
}

private static unsafe void FilterPollList(IList socketList, Interop.Sys.PollEvent* arr, int arrEndOffset, Interop.Sys.PollEvents desiredEvents)
private static unsafe void FilterPollList(IList socketList, Interop.Sys.PollEvent* arr, int arrEndOffset, Interop.Sys.PollEvents desiredEvents, ref int refsAdded)
{
if (socketList == null)
return;
Expand All @@ -1525,6 +1549,9 @@ private static unsafe void FilterPollList(IList socketList, Interop.Sys.PollEven

if ((arr[arrEndOffset].TriggeredEvents & desiredEvents) == 0)
{
Socket socket = (Socket)socketList[i];
socket.InternalSafeHandle.DangerousRelease();
refsAdded--;
socketList.RemoveAt(i);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -814,14 +814,17 @@ bool ShouldStackAlloc(IList list, ref IntPtr[] lease, out Span<IntPtr> span)
}

IntPtr[] leaseRead = null, leaseWrite = null, leaseError = null;
int refsAdded = 0;
try
{
// In case we can't increase the reference count for each Socket,
// we'll unref refAdded Sockets in the finally block ordered: [checkRead, checkWrite, checkError].
Span<IntPtr> readfileDescriptorSet = ShouldStackAlloc(checkRead, ref leaseRead, out var tmp) ? stackalloc IntPtr[StackThreshold] : tmp;
Socket.SocketListToFileDescriptorSet(checkRead, readfileDescriptorSet);
Socket.SocketListToFileDescriptorSet(checkRead, readfileDescriptorSet, ref refsAdded);
Span<IntPtr> writefileDescriptorSet = ShouldStackAlloc(checkWrite, ref leaseWrite, out tmp) ? stackalloc IntPtr[StackThreshold] : tmp;
Socket.SocketListToFileDescriptorSet(checkWrite, writefileDescriptorSet);
Socket.SocketListToFileDescriptorSet(checkWrite, writefileDescriptorSet, ref refsAdded);
Span<IntPtr> errfileDescriptorSet = ShouldStackAlloc(checkError, ref leaseError, out tmp) ? stackalloc IntPtr[StackThreshold] : tmp;
Socket.SocketListToFileDescriptorSet(checkError, errfileDescriptorSet);
Socket.SocketListToFileDescriptorSet(checkError, errfileDescriptorSet, ref refsAdded);

// This code used to erroneously pass a non-null timeval structure containing zeroes
// to select() when the caller specified (-1) for the microseconds parameter. That
Expand Down Expand Up @@ -872,9 +875,10 @@ bool ShouldStackAlloc(IList list, ref IntPtr[] lease, out Span<IntPtr> span)
return GetLastSocketError();
}

Socket.SelectFileDescriptor(checkRead, readfileDescriptorSet);
Socket.SelectFileDescriptor(checkWrite, writefileDescriptorSet);
Socket.SelectFileDescriptor(checkError, errfileDescriptorSet);
// Remove from the lists any entries which weren't set
Socket.SelectFileDescriptor(checkRead, readfileDescriptorSet, ref refsAdded);
Socket.SelectFileDescriptor(checkWrite, writefileDescriptorSet, ref refsAdded);
Socket.SelectFileDescriptor(checkError, errfileDescriptorSet, ref refsAdded);

return SocketError.Success;
}
Expand All @@ -883,6 +887,13 @@ bool ShouldStackAlloc(IList list, ref IntPtr[] lease, out Span<IntPtr> span)
if (leaseRead != null) ArrayPool<IntPtr>.Shared.Return(leaseRead);
if (leaseWrite != null) ArrayPool<IntPtr>.Shared.Return(leaseWrite);
if (leaseError != null) ArrayPool<IntPtr>.Shared.Return(leaseError);

// This order matches with the AddToPollArray calls
// to release only the handles that were ref'd.
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);
Debug.Assert(refsAdded == 0);
}
}

Expand Down

0 comments on commit 5773d1c

Please sign in to comment.