Skip to content

Commit

Permalink
Socket: don't disconnect Socket for unknown/unsupported socket option…
Browse files Browse the repository at this point in the history
…s. (dotnet#59925)

Fixes dotnet#59055.
  • Loading branch information
tmds authored Oct 8, 2021
1 parent 971479e commit edc5a41
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 19 deletions.
46 changes: 27 additions & 19 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1933,7 +1933,7 @@ public void SetSocketOption(SocketOptionLevel optionLevel, SocketOptionName opti
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}
}

Expand Down Expand Up @@ -2015,7 +2015,7 @@ public void SetRawSocketOption(int optionLevel, int optionName, ReadOnlySpan<byt

if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}
}

Expand Down Expand Up @@ -2051,7 +2051,7 @@ public void SetRawSocketOption(int optionLevel, int optionName, ReadOnlySpan<byt
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}

return optionValue;
Expand All @@ -2076,7 +2076,7 @@ public void GetSocketOption(SocketOptionLevel optionLevel, SocketOptionName opti
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}
}

Expand All @@ -2100,7 +2100,7 @@ public byte[] GetSocketOption(SocketOptionLevel optionLevel, SocketOptionName op
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}

if (optionLength != realOptionLength)
Expand Down Expand Up @@ -2136,7 +2136,7 @@ public int GetRawSocketOption(int optionLevel, int optionName, Span<byte> option

if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}

return realOptionLength;
Expand Down Expand Up @@ -3432,7 +3432,7 @@ internal unsafe void SetSocketOption(SocketOptionLevel optionLevel, SocketOption
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}
}

Expand All @@ -3445,7 +3445,7 @@ private void SetMulticastOption(SocketOptionName optionName, MulticastOption MR)
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}
}

Expand All @@ -3472,7 +3472,7 @@ private void SetLingerOption(LingerOption lref)
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}
}

Expand All @@ -3486,7 +3486,7 @@ private void SetLingerOption(LingerOption lref)
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}

return lingerOption;
Expand All @@ -3502,7 +3502,7 @@ private void SetLingerOption(LingerOption lref)
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}

return multicastOption;
Expand All @@ -3519,7 +3519,7 @@ private void SetLingerOption(LingerOption lref)
// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
UpdateStatusAfterSocketOptionErrorAndThrowException(errorCode);
}

return multicastOption;
Expand Down Expand Up @@ -3690,30 +3690,38 @@ internal void SetToDisconnected()
}
}

private void UpdateStatusAfterSocketErrorAndThrowException(SocketError error, [CallerMemberName] string? callerName = null)
private void UpdateStatusAfterSocketOptionErrorAndThrowException(SocketError error, [CallerMemberName] string? callerName = null)
{
// Don't disconnect socket for unknown options.
bool disconnectOnFailure = error != SocketError.ProtocolOption &&
error != SocketError.OperationNotSupported;
UpdateStatusAfterSocketErrorAndThrowException(error, disconnectOnFailure, callerName);
}

private void UpdateStatusAfterSocketErrorAndThrowException(SocketError error, bool disconnectOnFailure = true, [CallerMemberName] string? callerName = null)
{
// Update the internal state of this socket according to the error before throwing.
var socketException = new SocketException((int)error);
UpdateStatusAfterSocketError(socketException);
UpdateStatusAfterSocketError(socketException, disconnectOnFailure);
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, socketException, memberName: callerName);
throw socketException;
}

// UpdateStatusAfterSocketError(socketException) - updates the status of a connected socket
// on which a failure occurred. it'll go to winsock and check if the connection
// is still open and if it needs to update our internal state.
internal void UpdateStatusAfterSocketError(SocketException socketException)
internal void UpdateStatusAfterSocketError(SocketException socketException, bool disconnectOnFailure = true)
{
UpdateStatusAfterSocketError(socketException.SocketErrorCode);
UpdateStatusAfterSocketError(socketException.SocketErrorCode, disconnectOnFailure);
}

internal void UpdateStatusAfterSocketError(SocketError errorCode)
internal void UpdateStatusAfterSocketError(SocketError errorCode, bool disconnectOnFailure = true)
{
// If we already know the socket is disconnected
// we don't need to do anything else.
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, $"errorCode:{errorCode}");
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, $"errorCode:{errorCode}, disconnectOnFailure:{disconnectOnFailure}");

if (_isConnected && (_handle.IsInvalid || (errorCode != SocketError.WouldBlock &&
if (disconnectOnFailure && _isConnected && (_handle.IsInvalid || (errorCode != SocketError.WouldBlock &&
errorCode != SocketError.IOPending && errorCode != SocketError.NoBufferSpaceAvailable &&
errorCode != SocketError.TimedOut)))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,132 @@ public void Get_AcceptConnection_Succeeds()
}
}

[Fact]
public void GetUnsupportedSocketOption_DoesNotDisconnectSocket()
{
(Socket socket1, Socket socket2) = SocketTestExtensions.CreateConnectedSocketPair();
using (socket1)
using (socket2)
{
SocketException se = Assert.Throws<SocketException>(() => socket1.GetSocketOption(SocketOptionLevel.Socket, (SocketOptionName)(-1)));
Assert.True(se.SocketErrorCode == SocketError.ProtocolOption ||
se.SocketErrorCode == SocketError.OperationNotSupported, $"SocketError: {se.SocketErrorCode}");

Assert.True(socket1.Connected, "Connected");
}
}

[Fact]
public void GetUnsupportedSocketOptionBytesArg_DoesNotDisconnectSocket()
{
(Socket socket1, Socket socket2) = SocketTestExtensions.CreateConnectedSocketPair();
using (socket1)
using (socket2)
{
var optionValue = new byte[4];
SocketException se = Assert.Throws<SocketException>(() => socket1.GetSocketOption(SocketOptionLevel.Socket, (SocketOptionName)(-1), optionValue));
Assert.True(se.SocketErrorCode == SocketError.ProtocolOption ||
se.SocketErrorCode == SocketError.OperationNotSupported, $"SocketError: {se.SocketErrorCode}");

Assert.True(socket1.Connected, "Connected");
}
}

[Fact]
public void GetUnsupportedSocketOptionLengthArg_DoesNotDisconnectSocket()
{
(Socket socket1, Socket socket2) = SocketTestExtensions.CreateConnectedSocketPair();
using (socket1)
using (socket2)
{
SocketException se = Assert.Throws<SocketException>(() => socket1.GetSocketOption(SocketOptionLevel.Socket, (SocketOptionName)(-1), optionLength: 4));
Assert.True(se.SocketErrorCode == SocketError.ProtocolOption ||
se.SocketErrorCode == SocketError.OperationNotSupported, $"SocketError: {se.SocketErrorCode}");

Assert.True(socket1.Connected, "Connected");
}
}

[Fact]
public void SetUnsupportedSocketOptionIntArg_DoesNotDisconnectSocket()
{
(Socket socket1, Socket socket2) = SocketTestExtensions.CreateConnectedSocketPair();
using (socket1)
using (socket2)
{
SocketException se = Assert.Throws<SocketException>(() => socket1.SetSocketOption(SocketOptionLevel.Socket, (SocketOptionName)(-1), optionValue: 1));
Assert.True(se.SocketErrorCode == SocketError.ProtocolOption ||
se.SocketErrorCode == SocketError.OperationNotSupported, $"SocketError: {se.SocketErrorCode}");

Assert.True(socket1.Connected, "Connected");
}
}

[Fact]
public void SetUnsupportedSocketOptionBytesArg_DoesNotDisconnectSocket()
{
(Socket socket1, Socket socket2) = SocketTestExtensions.CreateConnectedSocketPair();
using (socket1)
using (socket2)
{
var optionValue = new byte[4];
SocketException se = Assert.Throws<SocketException>(() => socket1.SetSocketOption(SocketOptionLevel.Socket, (SocketOptionName)(-1), optionValue));
Assert.True(se.SocketErrorCode == SocketError.ProtocolOption ||
se.SocketErrorCode == SocketError.OperationNotSupported, $"SocketError: {se.SocketErrorCode}");

Assert.True(socket1.Connected, "Connected");
}
}

[Fact]
public void SetUnsupportedSocketOptionBoolArg_DoesNotDisconnectSocket()
{
(Socket socket1, Socket socket2) = SocketTestExtensions.CreateConnectedSocketPair();
using (socket1)
using (socket2)
{
bool optionValue = true;
SocketException se = Assert.Throws<SocketException>(() => socket1.SetSocketOption(SocketOptionLevel.Socket, (SocketOptionName)(-1), optionValue));
Assert.True(se.SocketErrorCode == SocketError.ProtocolOption ||
se.SocketErrorCode == SocketError.OperationNotSupported, $"SocketError: {se.SocketErrorCode}");

Assert.True(socket1.Connected, "Connected");
}
}

[Fact]
public void GetUnsupportedRawSocketOption_DoesNotDisconnectSocket()
{
(Socket socket1, Socket socket2) = SocketTestExtensions.CreateConnectedSocketPair();
using (socket1)
using (socket2)
{
var optionValue = new byte[4];
SocketException se = Assert.Throws<SocketException>(() => socket1.GetRawSocketOption(SOL_SOCKET, -1, optionValue));
Assert.True(se.SocketErrorCode == SocketError.ProtocolOption ||
se.SocketErrorCode == SocketError.OperationNotSupported, $"SocketError: {se.SocketErrorCode}");

Assert.True(socket1.Connected, "Connected");
}
}

[Fact]
public void SetUnsupportedRawSocketOption_DoesNotDisconnectSocket()
{
(Socket socket1, Socket socket2) = SocketTestExtensions.CreateConnectedSocketPair();
using (socket1)
using (socket2)
{
var optionValue = new byte[4];
SocketException se = Assert.Throws<SocketException>(() => socket1.SetRawSocketOption(SOL_SOCKET, -1, optionValue));
Assert.True(se.SocketErrorCode == SocketError.ProtocolOption ||
se.SocketErrorCode == SocketError.OperationNotSupported, $"SocketError: {se.SocketErrorCode}");

Assert.True(socket1.Connected, "Connected");
}
}

private static int SOL_SOCKET = OperatingSystem.IsLinux() ? 1 : (int)SocketOptionLevel.Socket;
}

[Collection("NoParallelTests")]
Expand Down

0 comments on commit edc5a41

Please sign in to comment.