Skip to content

Commit

Permalink
Tweak SslApplicationProtocol (dotnet/corefx#36021)
Browse files Browse the repository at this point in the history
* Tweak SslApplicationProtocol

- It currently stores a `ReadOnlyMemory<byte>`.  That just adds unnecessary expense: we can instead just store the provided `byte[]`.
- The most common values are those exposed statically: Http2 and Http11, but ToString on those results in creating a new string each time.  Special-case them.
- Constructing an SslApplicationProtocol with a null string results in an ArgumentNullException being thrown with the wrong parameter name.  Fix it.
- Miscellaneous cleanup on the file.

* Address PR feedback


Commit migrated from dotnet/corefx@af6e226
  • Loading branch information
stephentoub authored Mar 14, 2019
1 parent c31730d commit 22435bc
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ public static unsafe byte[] ToByteArray(List<SslApplicationProtocol> application
long protocolListSize = 0;
for (int i = 0; i < applicationProtocols.Count; i++)
{
if (applicationProtocols[i].Protocol.Length == 0 || applicationProtocols[i].Protocol.Length > byte.MaxValue)
int protocolLength = applicationProtocols[i].Protocol.Length;

if (protocolLength == 0 || protocolLength > byte.MaxValue)
{
throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols));
}

protocolListSize += applicationProtocols[i].Protocol.Length + 1;
protocolListSize += protocolLength + 1;

if (protocolListSize > short.MaxValue)
{
Expand All @@ -51,9 +53,10 @@ public static unsafe byte[] ToByteArray(List<SslApplicationProtocol> application

for (int i = 0; i < applicationProtocols.Count; i++)
{
buffer[index++] = (byte)applicationProtocols[i].Protocol.Length;
applicationProtocols[i].Protocol.Span.CopyTo(buffer.AsSpan(index));
index += applicationProtocols[i].Protocol.Length;
ReadOnlySpan<byte> protocol = applicationProtocols[i].Protocol.Span;
buffer[index++] = (byte)protocol.Length;
protocol.CopyTo(buffer.AsSpan(index));
index += protocol.Length;
}

return buffer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,139 +2,111 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Text;

namespace System.Net.Security
{
public readonly struct SslApplicationProtocol : IEquatable<SslApplicationProtocol>
{
private readonly ReadOnlyMemory<byte> _readOnlyProtocol;
private static readonly Encoding s_utf8 = Encoding.GetEncoding(Encoding.UTF8.CodePage, EncoderFallback.ExceptionFallback, DecoderFallback.ExceptionFallback);
private static readonly byte[] s_http2Utf8 = new byte[] { 0x68, 0x32 }; // "h2"
private static readonly byte[] s_http11Utf8 = new byte[] { 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31 }; // "http/1.1"

// Refer IANA on ApplicationProtocols: https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids
// Refer to IANA on ApplicationProtocols: https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids
// h2
public static readonly SslApplicationProtocol Http2 = new SslApplicationProtocol(new byte[] { 0x68, 0x32 }, false);
public static readonly SslApplicationProtocol Http2 = new SslApplicationProtocol(s_http2Utf8, copy: false);
// http/1.1
public static readonly SslApplicationProtocol Http11 = new SslApplicationProtocol(new byte[] { 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31 }, false);
public static readonly SslApplicationProtocol Http11 = new SslApplicationProtocol(s_http11Utf8, copy: false);

private readonly byte[] _readOnlyProtocol;

internal SslApplicationProtocol(byte[] protocol, bool copy)
{
if (protocol == null)
{
throw new ArgumentNullException(nameof(protocol));
}
Debug.Assert(protocol != null);

// RFC 7301 states protocol size <= 255 bytes.
if (protocol.Length == 0 || protocol.Length > 255)
{
throw new ArgumentException(SR.net_ssl_app_protocol_invalid, nameof(protocol));
}

if (copy)
{
byte[] temp = new byte[protocol.Length];
Array.Copy(protocol, 0, temp, 0, protocol.Length);
_readOnlyProtocol = new ReadOnlyMemory<byte>(temp);
}
else
{
_readOnlyProtocol = new ReadOnlyMemory<byte>(protocol);
}
_readOnlyProtocol = copy ?
protocol.AsSpan().ToArray() :
protocol;
}

public SslApplicationProtocol(byte[] protocol) : this(protocol, true) { }

public SslApplicationProtocol(string protocol) : this(s_utf8.GetBytes(protocol), copy: false) { }

public ReadOnlyMemory<byte> Protocol
public SslApplicationProtocol(byte[] protocol) :
this(protocol ?? throw new ArgumentNullException(nameof(protocol)), copy: true)
{
get => _readOnlyProtocol;
}

public bool Equals(SslApplicationProtocol other)
public SslApplicationProtocol(string protocol) :
this(s_utf8.GetBytes(protocol ?? throw new ArgumentNullException(nameof(protocol))), copy: false)
{
if (_readOnlyProtocol.Length != other._readOnlyProtocol.Length)
return false;

return (_readOnlyProtocol.IsEmpty && other._readOnlyProtocol.IsEmpty) ||
_readOnlyProtocol.Span.SequenceEqual(other._readOnlyProtocol.Span);
}

public override bool Equals(object obj)
{
if (obj is SslApplicationProtocol protocol)
{
return Equals(protocol);
}
public ReadOnlyMemory<byte> Protocol => _readOnlyProtocol;

return false;
}
public bool Equals(SslApplicationProtocol other) =>
((ReadOnlySpan<byte>)_readOnlyProtocol).SequenceEqual(other._readOnlyProtocol);

public override bool Equals(object obj) => obj is SslApplicationProtocol protocol && Equals(protocol);

public override int GetHashCode()
{
if (_readOnlyProtocol.Length == 0)
byte[] arr = _readOnlyProtocol;
if (arr == null)
{
return 0;
}

int hash1 = 0;
ReadOnlySpan<byte> pSpan = _readOnlyProtocol.Span;
for (int i = 0; i < _readOnlyProtocol.Length; i++)
int hash = 0;
for (int i = 0; i < arr.Length; i++)
{
hash1 = ((hash1 << 5) + hash1) ^ pSpan[i];
hash = ((hash << 5) + hash) ^ arr[i];
}

return hash1;
return hash;
}

public override string ToString()
{
byte[] arr = _readOnlyProtocol;
try
{
if (_readOnlyProtocol.Length == 0)
{
return null;
}

return s_utf8.GetString(_readOnlyProtocol.Span);
return
arr is null ? string.Empty :
ReferenceEquals(arr, s_http2Utf8) ? "h2" :
ReferenceEquals(arr, s_http11Utf8) ? "http/1.1" :
s_utf8.GetString(arr);
}
catch
{
// In case of decoding errors, return the byte values as hex string.
int byteCharsLength = _readOnlyProtocol.Length * 5;
char[] byteChars = new char[byteCharsLength];
char[] byteChars = new char[arr.Length * 5];
int index = 0;

ReadOnlySpan<byte> pSpan = _readOnlyProtocol.Span;
for (int i = 0; i < byteCharsLength; i += 5)

for (int i = 0; i < byteChars.Length; i += 5)
{
byte b = pSpan[index++];
byte b = arr[index++];
byteChars[i] = '0';
byteChars[i + 1] = 'x';
byteChars[i + 2] = GetHexValue(Math.DivRem(b, 16, out int rem));
byteChars[i + 3] = GetHexValue(rem);
byteChars[i + 4] = ' ';
}

return new string(byteChars, 0, byteCharsLength - 1);
}
}
return new string(byteChars, 0, byteChars.Length - 1);

static char GetHexValue(int i)
{
if (i < 10)
return (char)(i + '0');

return (char)(i - 10 + 'a');
static char GetHexValue(int i) => (char)(i < 10 ? i + '0' : i - 10 + 'a');
}
}

public static bool operator ==(SslApplicationProtocol left, SslApplicationProtocol right)
{
return left.Equals(right);
}
public static bool operator ==(SslApplicationProtocol left, SslApplicationProtocol right) =>
left.Equals(right);

public static bool operator !=(SslApplicationProtocol left, SslApplicationProtocol right)
{
return !(left == right);
}
public static bool operator !=(SslApplicationProtocol left, SslApplicationProtocol right) =>
!(left == right);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ public void Constructor_Overloads_Succeeds()
SslApplicationProtocol defaultProtocol = default;
Assert.True(defaultProtocol.Protocol.IsEmpty);

Assert.Throws<ArgumentNullException>(() => { new SslApplicationProtocol((byte[])null); });
Assert.Throws<ArgumentNullException>(() => { new SslApplicationProtocol((string)null); });
Assert.Throws<ArgumentException>(() => { new SslApplicationProtocol(new byte[] { }); });
Assert.Throws<ArgumentException>(() => { new SslApplicationProtocol(string.Empty); });
Assert.Throws<ArgumentException>(() => { new SslApplicationProtocol(Encoding.UTF8.GetBytes(new string('a', 256))); });
Assert.Throws<ArgumentException>(() => { new SslApplicationProtocol(new string('a', 256)); });
AssertExtensions.Throws<ArgumentNullException>("protocol", () => { new SslApplicationProtocol((byte[])null); });
AssertExtensions.Throws<ArgumentNullException>("protocol", () => { new SslApplicationProtocol((string)null); });
AssertExtensions.Throws<ArgumentException>("protocol", () => { new SslApplicationProtocol(new byte[] { }); });
AssertExtensions.Throws<ArgumentException>("protocol", () => { new SslApplicationProtocol(string.Empty); });
AssertExtensions.Throws<ArgumentException>("protocol", () => { new SslApplicationProtocol(Encoding.UTF8.GetBytes(new string('a', 256))); });
AssertExtensions.Throws<ArgumentException>("protocol", () => { new SslApplicationProtocol(new string('a', 256)); });
Assert.Throws<EncoderFallbackException>(() => { new SslApplicationProtocol("\uDC00"); });
}

Expand Down Expand Up @@ -74,16 +74,11 @@ public void InEquality_Tests_Succeeds(SslApplicationProtocol left, SslApplicatio
[Fact]
public void ToString_Rendering_Succeeds()
{
const string expected = "hello";
SslApplicationProtocol protocol = new SslApplicationProtocol(expected);
Assert.Equal(expected, protocol.ToString());

byte[] bytes = new byte[] { 0x0B, 0xEE };
protocol = new SslApplicationProtocol(bytes);
Assert.Equal("0x0b 0xee", protocol.ToString());

protocol = default;
Assert.Null(protocol.ToString());
Assert.Equal("http/1.1", SslApplicationProtocol.Http11.ToString());
Assert.Equal("h2", SslApplicationProtocol.Http2.ToString());
Assert.Equal("hello", new SslApplicationProtocol("hello").ToString());
Assert.Equal("0x0b 0xee", new SslApplicationProtocol(new byte[] { 0x0B, 0xEE }).ToString());
Assert.Equal(string.Empty, default(SslApplicationProtocol).ToString());
}

public static IEnumerable<object[]> Protocol_Equality_TestData()
Expand Down

0 comments on commit 22435bc

Please sign in to comment.