Skip to content

Commit

Permalink
SmtpClient: fix Timeout for low values (dotnet/corefx#37462)
Browse files Browse the repository at this point in the history
* SmtpClient: fix Timeout for low values

The timeout is triggered from a Timer.
There were two race conditions:
- The timer event could occur before there was a connection to abort.
- The connection abort performs a TcpClient.Dispose, which may not
work when there is a TcpClient.Connect occurring simultaneously.

* TCPClient: make Dispose thread-safe

* Fix NullReferenceException from _networkstream.Close

* SmtpTransport: fix racy aborted detection

* TestZeroTimeout: skip on .NET Framework

* PR feedback

* Don't null out socket on Dispose to let ODE propagate from Socket

* Fix merge

* Run TestZeroTimeout test on Linux too

* Add back ODE to TcpClient.Connect(IPEndPoint)

* Add comment about why we're disposing the socket

* PR feedback


Commit migrated from dotnet/corefx@807a18d
  • Loading branch information
tmds authored and stephentoub committed Aug 29, 2019
1 parent c5d6605 commit 1bcdf38
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class SmtpClient : IDisposable
{
private string _host;
private int _port;
private int _timeout = 100000;
private bool _inCall;
private bool _cancelled;
private bool _timedOut;
Expand Down Expand Up @@ -262,7 +263,7 @@ public int Timeout
{
get
{
return _transport.Timeout;
return _timeout;
}
set
{
Expand All @@ -276,7 +277,7 @@ public int Timeout
throw new ArgumentOutOfRangeException(nameof(value));
}

_transport.Timeout = value;
_timeout = value;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ internal partial class SmtpConnection
private SmtpReplyReaderFactory _responseReader;

private readonly ICredentialsByHost _credentials;
private int _timeout = 100000;
private string[] _extensions;
private readonly ChannelBinding _channelBindingToken = null;
private bool _enableSsl;
Expand Down Expand Up @@ -72,19 +71,6 @@ internal bool EnableSsl
}
}

internal int Timeout
{
get
{
return _timeout;
}
set
{
_timeout = value;
}
}


internal X509CertificateCollection ClientCertificates
{
get
Expand Down
44 changes: 21 additions & 23 deletions src/libraries/System.Net.Mail/src/System/Net/Mail/SmtpTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Net.Mime;
using System.Runtime.ExceptionServices;
using System.Security.Cryptography.X509Certificates;
using System.Threading;

namespace System.Net.Mail
{
Expand All @@ -18,9 +19,9 @@ internal class SmtpTransport
private SmtpConnection _connection;
private readonly SmtpClient _client;
private ICredentialsByHost _credentials;
private int _timeout = 100000; // seconds
private readonly List<SmtpFailedRecipientException> _failedRecipientExceptions = new List<SmtpFailedRecipientException>();
private bool _identityRequired;
private bool _shouldAbort;

private bool _enableSsl = false;
private X509CertificateCollection _clientCertificates = null;
Expand Down Expand Up @@ -74,23 +75,6 @@ internal bool IsConnected
}
}

internal int Timeout
{
get
{
return _timeout;
}
set
{
if (value < 0)
{
throw new ArgumentOutOfRangeException(nameof(value));
}

_timeout = value;
}
}

internal bool EnableSsl
{
get
Expand Down Expand Up @@ -124,8 +108,16 @@ internal void GetConnection(string host, int port)
{
try
{
_connection = new SmtpConnection(this, _client, _credentials, _authenticationModules);
_connection.Timeout = _timeout;
lock (this)
{
_connection = new SmtpConnection(this, _client, _credentials, _authenticationModules);
if (_shouldAbort)
{
_connection.Abort();
}
_shouldAbort = false;
}

if (NetEventSource.IsEnabled) NetEventSource.Associate(this, _connection);

if (EnableSsl)
Expand All @@ -146,7 +138,6 @@ internal IAsyncResult BeginGetConnection(ContextAwareResult outerResult, AsyncCa
try
{
_connection = new SmtpConnection(this, _client, _credentials, _authenticationModules);
_connection.Timeout = _timeout;
if (NetEventSource.IsEnabled) NetEventSource.Associate(this, _connection);
if (EnableSsl)
{
Expand Down Expand Up @@ -212,9 +203,16 @@ internal void ReleaseConnection()

internal void Abort()
{
if (_connection != null)
lock (this)
{
_connection.Abort();
if (_connection != null)
{
_connection.Abort();
}
else
{
_shouldAbort = true;
}
}
}

Expand Down
20 changes: 20 additions & 0 deletions src/libraries/System.Net.Mail/tests/Functional/SmtpClientTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using System.Collections.Generic;
using System.IO;
using System.Net.NetworkInformation;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
Expand Down Expand Up @@ -315,6 +316,25 @@ public void TestMailDelivery()
}
}

[Fact]
[SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, ".NET Framework has a bug and may not time out for low values")]
[PlatformSpecific(~TestPlatforms.OSX)] // on OSX, not all synchronous operations (e.g. connect) can be aborted by closing the socket.
public void TestZeroTimeout()
{
using (Socket serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
serverSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0));
serverSocket.Listen(1);

SmtpClient smtpClient = new SmtpClient("localhost", (serverSocket.LocalEndPoint as IPEndPoint).Port);
smtpClient.Timeout = 0;

MailMessage msg = new MailMessage("[email protected]", "[email protected]", "hello", "test");
Assert.Throws<SmtpException>(() => smtpClient.Send(msg));
}
}

[SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, ".NET Framework has a bug and could hang in case of null or empty body")]
[Theory]
[InlineData("howdydoo")]
[InlineData("")]
Expand Down
111 changes: 55 additions & 56 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System.Diagnostics;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.Sockets
Expand All @@ -16,9 +17,11 @@ public class TcpClient : IDisposable
private AddressFamily _family;
private Socket _clientSocket;
private NetworkStream _dataStream;
private bool _cleanedUp;
private volatile int _cleanedUp;
private bool _active;

private bool CleanedUp => _cleanedUp == 1;

// Initializes a new instance of the System.Net.Sockets.TcpClient class.
public TcpClient() : this(AddressFamily.Unknown)
{
Expand Down Expand Up @@ -106,7 +109,7 @@ protected bool Active
set { _active = value; }
}

public int Available => _clientSocket?.Available ?? 0;
public int Available => CleanedUp ? 0 : _clientSocket.Available;

// Used by the class to provide the underlying network socket.
public Socket Client
Expand All @@ -115,15 +118,15 @@ public Socket Client
set
{
_clientSocket = value;
_family = _clientSocket?.AddressFamily ?? AddressFamily.Unknown;
_family = CleanedUp ? AddressFamily.Unknown : _clientSocket.AddressFamily;
}
}

public bool Connected => _clientSocket?.Connected ?? false;
public bool Connected => CleanedUp ? false : _clientSocket.Connected;

public bool ExclusiveAddressUse
{
get { return _clientSocket?.ExclusiveAddressUse ?? false; }
get { return CleanedUp ? false : _clientSocket.ExclusiveAddressUse; }
set
{
if (_clientSocket != null)
Expand All @@ -138,7 +141,7 @@ public void Connect(string hostname, int port)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, hostname);

if (_cleanedUp)
if (CleanedUp)
{
throw new ObjectDisposedException(GetType().FullName);
}
Expand Down Expand Up @@ -171,7 +174,6 @@ public void Connect(string hostname, int port)
{
foreach (IPAddress address in addresses)
{
Socket tmpSocket = null;
try
{
if (_clientSocket == null)
Expand All @@ -181,10 +183,24 @@ public void Connect(string hostname, int port)
Debug.Assert(address.AddressFamily == AddressFamily.InterNetwork || address.AddressFamily == AddressFamily.InterNetworkV6);
if ((address.AddressFamily == AddressFamily.InterNetwork && Socket.OSSupportsIPv4) || Socket.OSSupportsIPv6)
{
tmpSocket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
tmpSocket.Connect(address, port);
_clientSocket = tmpSocket;
tmpSocket = null;
var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
// Use of Interlocked.Exchanged ensures _clientSocket is written before CleanedUp is read.
Interlocked.Exchange(ref _clientSocket, socket);
if (CleanedUp)
{
// Dispose the socket so it throws ObjectDisposedException when we Connect.
socket.Dispose();
}

try
{
socket.Connect(address, port);
}
catch
{
_clientSocket = null;
throw;
}
}

_family = address.AddressFamily;
Expand All @@ -201,11 +217,6 @@ public void Connect(string hostname, int port)
}
catch (Exception ex) when (!(ex is OutOfMemoryException))
{
if (tmpSocket != null)
{
tmpSocket.Dispose();
tmpSocket = null;
}
lastex = ExceptionDispatchInfo.Capture(ex);
}
}
Expand All @@ -228,7 +239,7 @@ public void Connect(IPAddress address, int port)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, address);

if (_cleanedUp)
if (CleanedUp)
{
throw new ObjectDisposedException(GetType().FullName);
}
Expand All @@ -252,7 +263,7 @@ public void Connect(IPEndPoint remoteEP)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, remoteEP);

if (_cleanedUp)
if (CleanedUp)
{
throw new ObjectDisposedException(GetType().FullName);
}
Expand Down Expand Up @@ -331,14 +342,7 @@ public void EndConnect(IAsyncResult asyncResult)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, asyncResult);

Socket s = Client;
if (s == null)
{
// Dispose nulls out the client socket field.
throw new ObjectDisposedException(GetType().Name);
}

s.EndConnect(asyncResult);
Client.EndConnect(asyncResult);
_active = true;

if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
Expand All @@ -349,7 +353,7 @@ public NetworkStream GetStream()
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this);

if (_cleanedUp)
if (CleanedUp)
{
throw new ObjectDisposedException(GetType().FullName);
}
Expand All @@ -374,44 +378,39 @@ protected virtual void Dispose(bool disposing)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this);

if (_cleanedUp)
if (Interlocked.CompareExchange(ref _cleanedUp, 1, 0) == 0)
{
if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
return;
}

if (disposing)
{
IDisposable dataStream = _dataStream;
if (dataStream != null)
{
dataStream.Dispose();
}
else
if (disposing)
{
// If the NetworkStream wasn't created, the Socket might
// still be there and needs to be closed. In the case in which
// we are bound to a local IPEndPoint this will remove the
// binding and free up the IPEndPoint for later uses.
Socket chkClientSocket = _clientSocket;
if (chkClientSocket != null)
IDisposable dataStream = _dataStream;
if (dataStream != null)
{
try
{
chkClientSocket.InternalShutdown(SocketShutdown.Both);
}
finally
dataStream.Dispose();
}
else
{
// If the NetworkStream wasn't created, the Socket might
// still be there and needs to be closed. In the case in which
// we are bound to a local IPEndPoint this will remove the
// binding and free up the IPEndPoint for later uses.
Socket chkClientSocket = Volatile.Read(ref _clientSocket);
if (chkClientSocket != null)
{
chkClientSocket.Close();
_clientSocket = null;
try
{
chkClientSocket.InternalShutdown(SocketShutdown.Both);
}
finally
{
chkClientSocket.Close();
}
}
}
}

GC.SuppressFinalize(this);
GC.SuppressFinalize(this);
}
}

_cleanedUp = true;
if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
}

Expand Down

0 comments on commit 1bcdf38

Please sign in to comment.