Skip to content

Commit

Permalink
propagate TLS alerts from OS layers (dotnet/corefx#41967)
Browse files Browse the repository at this point in the history
* initial alerts with openssl

* get alerts from schannel

* update tests to work with openssl 1.1.x

* fix ClientAsyncAuthenticate_ServerNoEncryption_NoConnect to work properly with Tls13

* remove extra comment

* feedback from review

* feedback from review

* remove unused variable


Commit migrated from dotnet/corefx@784cb6b
  • Loading branch information
wfurt authored Nov 13, 2019
1 parent ca73ff2 commit 35c27fb
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int r
{
sendBuf = null;
sendCount = 0;
Exception handshakeException = null;

if ((recvBuf != null) && (recvCount > 0))
{
Expand All @@ -275,7 +276,10 @@ internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int r

if ((retVal != -1) || (error != Ssl.SslErrorCode.SSL_ERROR_WANT_READ))
{
throw new SslException(SR.Format(SR.net_ssl_handshake_failed_error, error), innerError);
// Handshake failed, but even if the handshake does not need to read, there may be an Alert going out.
// To handle that we will fall-through the block below to pull it out, and we will fail after.
handshakeException = new SslException(SR.Format(SR.net_ssl_handshake_failed_error, error), innerError);
Crypto.ErrClearError();
}
}

Expand All @@ -288,6 +292,10 @@ internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int r
{
sendCount = BioRead(context.OutputBio, sendBuf, sendCount);
}
catch (Exception) when (handshakeException != null)
{
// If we already have handshake exception, ignore any exception from BioRead().
}
finally
{
if (sendCount <= 0)
Expand All @@ -300,6 +308,11 @@ internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int r
}
}

if (handshakeException != null)
{
throw handshakeException;
}

bool stateOk = Ssl.IsSslStateOK(context);
if (stateOk)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ internal static unsafe int AcceptSecurityContext(
}

Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Length);
Interop.SspiCli.SecBufferDesc outSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(1);
Interop.SspiCli.SecBufferDesc outSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(count: 2);

// Actually, this is returned in outFlags.
bool isSspiAllocated = (inFlags & Interop.SspiCli.ContextFlags.AllocateMemory) != 0 ? true : false;
Expand All @@ -659,12 +659,15 @@ internal static unsafe int AcceptSecurityContext(

// Optional output buffer that may need to be freed.
SafeFreeContextBuffer outFreeContextBuffer = null;
Span<Interop.SspiCli.SecBuffer> outUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[2];
outUnmanagedBuffer[1].pvBuffer = IntPtr.Zero;
try
{
Span<Interop.SspiCli.SecBuffer> inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[inSecurityBufferDescriptor.cBuffers];
inUnmanagedBuffer.Clear();

fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer)
fixed (void* outUnmanagedBufferPtr = outUnmanagedBuffer)
fixed (void* pinnedToken0 = inSecBuffers.Length > 0 ? inSecBuffers[0].token : null)
fixed (void* pinnedToken1 = inSecBuffers.Length > 1 ? inSecBuffers[1].token : null)
fixed (void* pinnedToken2 = inSecBuffers.Length > 2 ? inSecBuffers[2].token : null) // pin all buffers, even if null or not used, to avoid needing to allocate GCHandles
Expand Down Expand Up @@ -694,16 +697,18 @@ internal static unsafe int AcceptSecurityContext(
fixed (byte* pinnedOutBytes = outSecBuffer.token)
{
// Fix Descriptor pointer that points to unmanaged SecurityBuffers.
Interop.SspiCli.SecBuffer outUnmanagedBuffer = default;
outSecurityBufferDescriptor.pBuffers = &outUnmanagedBuffer;
outSecurityBufferDescriptor.pBuffers = outUnmanagedBufferPtr;

// Copy the SecurityBuffer content into unmanaged place holder.
outUnmanagedBuffer.cbBuffer = outSecBuffer.size;
outUnmanagedBuffer.BufferType = outSecBuffer.type;
outUnmanagedBuffer.pvBuffer = outSecBuffer.token == null || outSecBuffer.token.Length == 0 ?
outUnmanagedBuffer[0].cbBuffer = outSecBuffer.size;
outUnmanagedBuffer[0].BufferType = outSecBuffer.type;
outUnmanagedBuffer[0].pvBuffer = outSecBuffer.token == null || outSecBuffer.token.Length == 0 ?
IntPtr.Zero :
(IntPtr)(pinnedOutBytes + outSecBuffer.offset);

outUnmanagedBuffer[1].cbBuffer = 0;
outUnmanagedBuffer[1].BufferType = SecurityBufferType.SECBUFFER_ALERT;

if (isSspiAllocated)
{
outFreeContextBuffer = SafeFreeContextBuffer.CreateEmptyHandle();
Expand Down Expand Up @@ -731,18 +736,31 @@ internal static unsafe int AcceptSecurityContext(

if (NetEventSource.IsEnabled) NetEventSource.Info(null, "Marshaling OUT buffer");

// Get unmanaged buffer with index 0 as the only one passed into PInvoke.
outSecBuffer.size = outUnmanagedBuffer.cbBuffer;
outSecBuffer.type = outUnmanagedBuffer.BufferType;
outSecBuffer.token = outUnmanagedBuffer.cbBuffer > 0 ?
new Span<byte>((byte*)outUnmanagedBuffer.pvBuffer, outUnmanagedBuffer.cbBuffer).ToArray() :
null;
// No data written out but there is Alert
if (outUnmanagedBuffer[0].cbBuffer == 0 && outUnmanagedBuffer[1].cbBuffer > 0)
{
outSecBuffer.size = outUnmanagedBuffer[1].cbBuffer;
outSecBuffer.type = outUnmanagedBuffer[1].BufferType;
outSecBuffer.token = new Span<byte>((byte*)outUnmanagedBuffer[1].pvBuffer, outUnmanagedBuffer[1].cbBuffer).ToArray();
}
else
{
outSecBuffer.size = outUnmanagedBuffer[0].cbBuffer;
outSecBuffer.type = outUnmanagedBuffer[0].BufferType;
outSecBuffer.token = outUnmanagedBuffer[0].cbBuffer > 0 ?
new Span<byte>((byte*)outUnmanagedBuffer[0].pvBuffer, outUnmanagedBuffer[0].cbBuffer).ToArray() :
null;
}
}
}
}
finally
{
outFreeContextBuffer?.Dispose();
if (outUnmanagedBuffer[1].pvBuffer != IntPtr.Zero)
{
Interop.SspiCli.FreeContextBuffer(outUnmanagedBuffer[1].pvBuffer);
}
}

if (NetEventSource.IsEnabled) NetEventSource.Exit(null, $"errorCode:0x{errorCode:x8}, refContext:{refContext}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,16 @@ private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credentia
{
Debug.Assert(!credential.IsInvalid);

byte[] output = null;
int outputSize = 0;

try
{
if ((null == context) || context.IsInvalid)
{
context = new SafeDeleteSslContext(credential as SafeFreeSslCredentials, sslAuthenticationOptions);
}

byte[] output = null;
int outputSize;
bool done;

if (inputBuffer.Array == null)
Expand Down Expand Up @@ -143,6 +144,12 @@ private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credentia
}
catch (Exception exc)
{
// Even if handshake failed we may have Alert to sent.
if (outputSize > 0)
{
outputBuffer = outputSize == output.Length ? output : new Span<byte>(output, 0, outputSize).ToArray();
}

return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, exc);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ internal static class SslStreamPal
Interop.SspiCli.ContextFlags.AllocateMemory;

private const Interop.SspiCli.ContextFlags ServerRequiredFlags =
RequiredFlags | Interop.SspiCli.ContextFlags.AcceptStream;
RequiredFlags | Interop.SspiCli.ContextFlags.AcceptStream | Interop.SspiCli.ContextFlags.AcceptExtendedError;

public static Exception GetException(SecurityStatusPal status)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,15 @@ public async Task ClientAsyncAuthenticate_ServerRequireEncryption_ConnectWithEnc
[Fact]
public async Task ClientAsyncAuthenticate_ServerNoEncryption_NoConnect()
{
await Assert.ThrowsAsync<IOException>(() => ClientAsyncSslHelper(EncryptionPolicy.NoEncryption));
// Don't use Tls13 since we are trying to use NullEncryption
Type expectedExceptionType = TestConfiguration.SupportsHandshakeAlerts && TestConfiguration.SupportsNullEncryption ?
typeof(AuthenticationException) :
typeof(IOException);

await Assert.ThrowsAsync(expectedExceptionType,
() => ClientAsyncSslHelper(
EncryptionPolicy.NoEncryption,
SslProtocolSupport.DefaultSslProtocols, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 ));
}

[Theory]
Expand Down Expand Up @@ -112,12 +120,12 @@ public static IEnumerable<object[]> ProtocolMismatchData()
yield return new object[] { SslProtocols.Ssl2, SslProtocols.Tls12, typeof(Exception) };
yield return new object[] { SslProtocols.Ssl3, SslProtocols.Tls12, typeof(Exception) };
#pragma warning restore 0618
yield return new object[] { SslProtocols.Tls, SslProtocols.Tls11, typeof(IOException) };
yield return new object[] { SslProtocols.Tls, SslProtocols.Tls12, typeof(IOException) };
yield return new object[] { SslProtocols.Tls, SslProtocols.Tls11, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(IOException) };
yield return new object[] { SslProtocols.Tls, SslProtocols.Tls12, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(IOException) };
yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls, typeof(AuthenticationException) };
yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls, typeof(AuthenticationException) };
yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls11, typeof(AuthenticationException) };
yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls12, typeof(IOException) };
yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls12, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(IOException) };
}

#region Helpers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public async Task ClientDefaultEncryption_ServerNoEncryption_NoConnect()

using (var sslStream = new SslStream(client.GetStream(), false, AllowAnyServerCertificate, null))
{
await Assert.ThrowsAsync<IOException>(() =>
await Assert.ThrowsAsync(TestConfiguration.SupportsHandshakeAlerts ? typeof(AuthenticationException) : typeof(IOException), () =>
sslStream.AuthenticateAsClientAsync("localhost", null, SslProtocolSupport.DefaultSslProtocols, false));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ public static IEnumerable<object[]> ProtocolMismatchData()
#pragma warning restore 0618
yield return new object[] { SslProtocols.Tls, SslProtocols.Tls11, typeof(AuthenticationException) };
yield return new object[] { SslProtocols.Tls, SslProtocols.Tls12, typeof(AuthenticationException) };
yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls, typeof(TimeoutException) };
yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(TimeoutException) };
yield return new object[] { SslProtocols.Tls11, SslProtocols.Tls12, typeof(AuthenticationException) };
yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls, typeof(TimeoutException) };
yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls11, typeof(TimeoutException) };
yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(TimeoutException) };
yield return new object[] { SslProtocols.Tls12, SslProtocols.Tls11, TestConfiguration.SupportsVersionAlerts ? typeof(AuthenticationException) : typeof(TimeoutException) };
}

#region Helpers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task ServerNoEncryption_ClientRequireEncryption_NoConnect()

using (var sslStream = new SslStream(client.GetStream(), false, AllowAnyServerCertificate, null, EncryptionPolicy.RequireEncryption))
{
await Assert.ThrowsAsync<IOException>(() =>
await Assert.ThrowsAsync(TestConfiguration.SupportsHandshakeAlerts ? typeof(AuthenticationException) : typeof(IOException), () =>
sslStream.AuthenticateAsClientAsync("localhost", null, SslProtocolSupport.DefaultSslProtocols, false));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public async Task ServerRequireEncryption_ClientNoEncryption_NoConnect()
await client.ConnectAsync(serverRequireEncryption.RemoteEndPoint.Address, serverRequireEncryption.RemoteEndPoint.Port);
using (var sslStream = new SslStream(client.GetStream(), false, AllowAnyServerCertificate, null, EncryptionPolicy.NoEncryption))
{
await Assert.ThrowsAsync<IOException>(() =>
await Assert.ThrowsAsync(TestConfiguration.SupportsHandshakeAlerts ? typeof(AuthenticationException) : typeof(IOException), () =>
sslStream.AuthenticateAsClientAsync("localhost", null, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ public async Task SslStream_StreamToStream_Alpn_NonMatchingProtocols_Fail()
// Test alpn failure only on platforms that supports ALPN.
if (BackendSupportsAlpn)
{
Task t1 = Assert.ThrowsAsync<IOException>(() => clientStream.AuthenticateAsClientAsync(clientOptions, CancellationToken.None));
// schannel sends alert on ALPN failure, openssl does not.
Task t1 = Assert.ThrowsAsync(TestConfiguration.SupportsAlpnAlerts ? typeof(AuthenticationException) : typeof(IOException), () =>
clientStream.AuthenticateAsClientAsync(clientOptions, CancellationToken.None));

try
{
await serverStream.AuthenticateAsServerAsync(serverOptions, CancellationToken.None);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ internal static class TestConfiguration

public static bool SupportsNullEncryption { get { return s_supportsNullEncryption.Value; } }

public static bool SupportsHandshakeAlerts { get { return RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || RuntimeInformation.IsOSPlatform(OSPlatform.Windows); } }

public static bool SupportsAlpnAlerts { get { return RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || (RuntimeInformation.IsOSPlatform(OSPlatform.Linux) && PlatformDetection.OpenSslVersion.CompareTo(new Version(1,1,0)) >= 0); } }

public static bool SupportsVersionAlerts { get { return RuntimeInformation.IsOSPlatform(OSPlatform.Linux) && PlatformDetection.OpenSslVersion.CompareTo(new Version(1,1,0)) >= 0; } }

public static Task WhenAllOrAnyFailedWithTimeout(params Task[] tasks)
=> tasks.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);

Expand Down

0 comments on commit 35c27fb

Please sign in to comment.