Skip to content

Commit

Permalink
Fix compression (dotnet#79412)
Browse files Browse the repository at this point in the history
* Fix compression

* Apply suggestions from code review

Co-authored-by: Miha Zupan <[email protected]>

* Adding SendAsync to ref

* fix ws deflate tests

* Check bytes on server side

Co-authored-by: Miha Zupan <[email protected]>
Co-authored-by: Natalia Kondratyeva <[email protected]>
  • Loading branch information
3 people authored Dec 12, 2022
1 parent ddb91f5 commit cab72d3
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public override void Dispose() { }
public override System.Threading.Tasks.ValueTask<System.Net.WebSockets.ValueWebSocketReceiveResult> ReceiveAsync(System.Memory<byte> buffer, System.Threading.CancellationToken cancellationToken) { throw null; }
public override System.Threading.Tasks.Task SendAsync(System.ArraySegment<byte> buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken) { throw null; }
public override System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory<byte> buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken) { throw null; }
public override System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory<byte> buffer, System.Net.WebSockets.WebSocketMessageType messageType, System.Net.WebSockets.WebSocketMessageFlags messageFlags, System.Threading.CancellationToken cancellationToken) { throw null; }
}
public sealed partial class ClientWebSocketOptions
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType m
public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) =>
ConnectedWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);

public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) =>
ConnectedWebSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken);

public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken) =>
ConnectedWebSocket.ReceiveAsync(buffer, cancellationToken);

Expand Down
85 changes: 85 additions & 0 deletions src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,91 @@ await LoopbackServer.CreateClientAndServerAsync(async uri =>
}), new LoopbackServer.Options { WebSocketEndpoint = true });
}

[ConditionalFact(nameof(WebSocketsSupported))]
public async Task ThrowsWhenContinuationHasDifferentCompressionFlags()
{
var deflateOpt = new WebSocketDeflateOptions
{
ClientMaxWindowBits = 14,
ClientContextTakeover = true,
ServerMaxWindowBits = 14,
ServerContextTakeover = true
};
await LoopbackServer.CreateClientAndServerAsync(async uri =>
{
using var cws = new ClientWebSocket();
using var cts = new CancellationTokenSource(TimeOutMilliseconds);

cws.Options.DangerousDeflateOptions = deflateOpt;
await ConnectAsync(cws, uri, cts.Token);


await cws.SendAsync(Memory<byte>.Empty, WebSocketMessageType.Text, WebSocketMessageFlags.DisableCompression, default);
Assert.Throws<ArgumentException>("messageFlags", () =>
cws.SendAsync(Memory<byte>.Empty, WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, default));
}, server => server.AcceptConnectionAsync(async connection =>
{
var extensionsReply = CreateDeflateOptionsHeader(deflateOpt);
await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply);
}), new LoopbackServer.Options { WebSocketEndpoint = true });
}

[ConditionalFact(nameof(WebSocketsSupported))]
public async Task SendHelloWithDisableCompression()
{
byte[] bytes = "Hello"u8.ToArray();

int prefixLength = 2;
byte[] rawPrefix = new byte[] { 0x81, 0x85 }; // fin=1, rsv=0, opcode=text; mask=1, len=5
int rawRemainingBytes = 9; // mask bytes (4) + payload bytes (5)
byte[] compressedPrefix = new byte[] { 0xc1, 0x87 }; // fin=1, rsv=compressed, opcode=text; mask=1, len=7
int compressedRemainingBytes = 11; // mask bytes (4) + payload bytes (7)

var deflateOpt = new WebSocketDeflateOptions
{
ClientMaxWindowBits = 14,
ClientContextTakeover = true,
ServerMaxWindowBits = 14,
ServerContextTakeover = true
};

await LoopbackServer.CreateClientAndServerAsync(async uri =>
{
using var cws = new ClientWebSocket();
using var cts = new CancellationTokenSource(TimeOutMilliseconds);

cws.Options.DangerousDeflateOptions = deflateOpt;
await ConnectAsync(cws, uri, cts.Token);

await cws.SendAsync(bytes, WebSocketMessageType.Text, true, cts.Token);

WebSocketMessageFlags flags = WebSocketMessageFlags.DisableCompression | WebSocketMessageFlags.EndOfMessage;
await cws.SendAsync(bytes, WebSocketMessageType.Text, flags, cts.Token);
}, server => server.AcceptConnectionAsync(async connection =>
{
var buffer = new byte[compressedRemainingBytes];
var extensionsReply = CreateDeflateOptionsHeader(deflateOpt);
await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply);

// first message is compressed
await ReadExactAsync(buffer, prefixLength);
Assert.Equal(compressedPrefix, buffer[..prefixLength]);
// read rest of the frame
await ReadExactAsync(buffer, compressedRemainingBytes);

// second message is not compressed
await ReadExactAsync(buffer, prefixLength);
Assert.Equal(rawPrefix, buffer[..prefixLength]);
// read rest of the frame
await ReadExactAsync(buffer, rawRemainingBytes);

async Task ReadExactAsync(byte[] buf, int n)
{
await connection.Stream.ReadAtLeastAsync(buf.AsMemory(0, n), n);
}
}), new LoopbackServer.Options { WebSocketEndpoint = true });
}

private static string CreateDeflateOptionsHeader(WebSocketDeflateOptions options)
{
var builder = new StringBuilder();
Expand Down

0 comments on commit cab72d3

Please sign in to comment.