Skip to content

Commit

Permalink
Fix handling of new connection in MsQuicListener (dotnet#57319)
Browse files Browse the repository at this point in the history
* update

* feedback from review
  • Loading branch information
wfurt authored Aug 17, 2021
1 parent 24e9212 commit 09ba220
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ internal sealed class State

// These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown).
public MsQuicConnection? Connection;
public MsQuicListener.State? ListenerState;

public TaskCompletionSource<uint>? ConnectTcs;
// TODO: only allocate these when there is an outstanding shutdown.
Expand Down Expand Up @@ -135,11 +136,10 @@ public void SetClosing()
internal string TraceId() => _state.TraceId;

// constructor for inbound connections
public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, MsQuicListener.State listenerState, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
{
_state.Handle = handle;
_state.StateGCHandle = GCHandle.Alloc(_state);
_state.Connected = true;
_state.RemoteCertificateRequired = remoteCertificateRequired;
_state.RevocationMode = revocationMode;
_state.RemoteCertificateValidationCallback = remoteCertificateValidationCallback;
Expand All @@ -161,6 +161,7 @@ public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, Saf
throw;
}

_state.ListenerState = listenerState;
_state.TraceId = MsQuicTraceHelper.GetTraceId(_state.Handle);
if (NetEventSource.Log.IsEnabled())
{
Expand Down Expand Up @@ -223,7 +224,34 @@ public MsQuicConnection(QuicClientConnectionOptions options)

private static uint HandleEventConnected(State state, ref ConnectionEvent connectionEvent)
{
if (!state.Connected)
if (state.Connected)
{
return MsQuicStatusCodes.Success;
}

if (state.IsServer)
{
state.Connected = true;
MsQuicListener.State? listenerState = state.ListenerState;
state.ListenerState = null;

if (listenerState != null)
{
if (listenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection))
{
// Move connection from pending to Accept queue and hand it out.
if (listenerState.AcceptConnectionQueue.Writer.TryWrite(connection))
{
return MsQuicStatusCodes.Success;
}
// Listener is closed
connection.Dispose();
}
}

return MsQuicStatusCodes.UserCanceled;
}
else
{
// Connected will already be true for connections accepted from a listener.
Debug.Assert(!Monitor.IsEntered(state));
Expand Down Expand Up @@ -271,6 +299,18 @@ private static uint HandleEventShutdownComplete(State state, ref ConnectionEvent
// This is the final event on the connection, so free the GCHandle used by the event callback.
state.StateGCHandle.Free();

if (state.ListenerState != null)
{
// This is inbound connection that never got connected - becasue of TLS validation or some other reason.
// Remove connection from pending queue and dispose it.
if (state.ListenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection))
{
connection.Dispose();
}

state.ListenerState = null;
}

state.Connection = null;

state.ShutdownTcs.SetResult(MsQuicStatusCodes.Success);
Expand All @@ -297,6 +337,7 @@ private static uint HandleEventShutdownComplete(State state, ref ConnectionEvent
{
bidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
}

return MsQuicStatusCodes.Success;
}

Expand Down Expand Up @@ -418,6 +459,11 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti

if (!success)
{
if (state.IsServer)
{
return MsQuicStatusCodes.UserCanceled;
}

throw new AuthenticationException(SR.net_quic_cert_custom_validation);
}

Expand All @@ -430,6 +476,11 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti

if (sslPolicyErrors != SslPolicyErrors.None)
{
if (state.IsServer)
{
return MsQuicStatusCodes.HandshakeFailure;
}

throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Buffers;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Net.Quic.Implementations.MsQuic.Internal;
using System.Net.Security;
Expand All @@ -25,14 +26,15 @@ internal sealed class MsQuicListener : QuicListenerProvider, IDisposable

private readonly IPEndPoint _listenEndPoint;

private sealed class State
internal sealed class State
{
// set immediately in ctor, but we need a GCHandle to State in order to create the handle.
public SafeMsQuicListenerHandle Handle = null!;
public string TraceId = null!; // set in ctor.

public readonly SafeMsQuicConfigurationHandle? ConnectionConfiguration;
public readonly Channel<MsQuicConnection> AcceptConnectionQueue;
public readonly ConcurrentDictionary<IntPtr, MsQuicConnection> PendingConnections;

public QuicOptions ConnectionOptions = new QuicOptions();
public SslServerAuthenticationOptions AuthenticationOptions = new SslServerAuthenticationOptions();
Expand Down Expand Up @@ -66,6 +68,7 @@ public State(QuicListenerOptions options)
ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options, options.ServerAuthenticationOptions);
}

PendingConnections = new ConcurrentDictionary<IntPtr, MsQuicConnection>();
AcceptConnectionQueue = Channel.CreateBounded<MsQuicConnection>(new BoundedChannelOptions(options.ListenBacklog)
{
SingleReader = true,
Expand Down Expand Up @@ -234,7 +237,6 @@ private static unsafe uint NativeCallbackHandler(

SafeMsQuicConnectionHandle? connectionHandle = null;
MsQuicConnection? msQuicConnection = null;

try
{
ref NewConnectionInfo connectionInfo = ref *evt.Data.NewConnection.Info;
Expand Down Expand Up @@ -278,13 +280,15 @@ private static unsafe uint NativeCallbackHandler(
uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration);
if (MsQuicStatusHelper.SuccessfulStatusCode(status))
{
msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, state, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);

if (state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
if (!state.PendingConnections.TryAdd(connectionHandle.DangerousGetHandle(), msQuicConnection))
{
return MsQuicStatusCodes.Success;
msQuicConnection.Dispose();
}

return MsQuicStatusCodes.Success;
}

// If we fall-through here something wrong happened.
Expand Down
70 changes: 61 additions & 9 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,48 @@ public async Task ConnectWithCertificateChain()
clientConnection.Dispose();
}

[Fact]
[PlatformSpecific(TestPlatforms.Windows)]
public async Task UntrustedClientCertificateFails()
{
var listenerOptions = new QuicListenerOptions();
listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
return false;
};

using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
clientOptions.RemoteEndPoint = listener.ListenEndPoint;
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
QuicConnection clientConnection = CreateQuicConnection(clientOptions);

using CancellationTokenSource cts = new CancellationTokenSource();
cts.CancelAfter(500); //Some delay to see if we would get failed connection.
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();

ValueTask t = clientConnection.ConnectAsync(cts.Token);

t.AsTask().Wait(PassingTestTimeout);
await Assert.ThrowsAsync<OperationCanceledException>(() => serverTask);
// The task will likely succed but we don't really care.
// It may fail if the server aborts quickly.
try
{
await t;
}
catch { };
}

[Fact]
public async Task CertificateCallbackThrowPropagates()
{
using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout);
X509Certificate? receivedCertificate = null;
bool validationResult = false;

var listenerOptions = new QuicListenerOptions();
listenerOptions.ListenEndPoint = new IPEndPoint(Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
Expand All @@ -118,18 +155,26 @@ public async Task CertificateCallbackThrowPropagates()
clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
receivedCertificate = cert;
if (validationResult)
{
return validationResult;
}

throw new ArithmeticException("foobar");
};

clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1";
QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);

Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
await Assert.ThrowsAsync<ArithmeticException>(() => clientConnection.ConnectAsync(cts.Token).AsTask());
QuicConnection serverConnection = await serverTask;

Assert.Equal(listenerOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate);
clientConnection.Dispose();

// Make sure the listner is still usable and there is no lingering bad conenction
validationResult = true;
(clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener);
await PingPong(clientConnection, serverConnection);
clientConnection.Dispose();
serverConnection.Dispose();
}
Expand Down Expand Up @@ -253,7 +298,6 @@ public async Task ConnectWithCertificateForDifferentName_Throws()
using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
ValueTask clientTask = clientConnection.ConnectAsync();

using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await Assert.ThrowsAsync<AuthenticationException>(async () => await clientTask);
}

Expand Down Expand Up @@ -284,9 +328,11 @@ public async Task ConnectWithCertificateForLoopbackIP_IndicatesExpectedError(str
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions);
}

[Fact]
[Theory]
[PlatformSpecific(TestPlatforms.Windows)]
public async Task ConnectWithClientCertificate()
[InlineData(true)]
// [InlineData(false)] [ActiveIssue("https://github.com/dotnet/runtime/issues/57308")]
public async Task ConnectWithClientCertificate(bool sendCerttificate)
{
bool clientCertificateOK = false;

Expand All @@ -296,17 +342,23 @@ public async Task ConnectWithClientCertificate()
listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
_output.WriteLine("client certificate {0}", cert);
Assert.NotNull(cert);
Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint);
if (sendCerttificate)
{
_output.WriteLine("client certificate {0}", cert);
Assert.NotNull(cert);
Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint);
}

clientCertificateOK = true;
return true;
};

using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
if (sendCerttificate)
{
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
}
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener);

// Verify functionality of the connections.
Expand Down

0 comments on commit 09ba220

Please sign in to comment.