Skip to content

Commit

Permalink
Propagate cancellation tokens to TrySetCanceled in Dataflow (dotnet#8…
Browse files Browse the repository at this point in the history
…0978)

* Propagate cancellation tokens to TrySetCanceled in Dataflow

When the System.Threading.Tasks.Dataflow library was originally written, CancellationTokenSource's TrySetCanceled didn't have an overload that allowed passing in the CancellationToken that was the cause of the cancellation. Now it does, and we no longer build for target platforms that lack the needed overload.  Thus we can update the library to propagate it everywhere that's relevant.  In some cases, to do this well we do need to rely on a newer CancellationToken.Register overload that accepts a delegate which accepts a token, so there's a little bit of ifdef'ing involved still.

While doing this, I also took the opportunity to sprinkle some `static`s onto lambdas, since I was already doing so for some lambdas as part of this fix.

* Fix pipelines handling of cancellation token

Flush operations were synchronously throwing an exception if cancellation was requested prior to the operation.  Cancellation exceptions should always be propagated out through the returned task.

* Fix pre-cancellation in QuicStream.WriteAsync

Cancellation exceptions should flow out through the returned task, not synchronously.
  • Loading branch information
stephentoub authored Jan 28, 2023
1 parent b6cd300 commit 73da129
Show file tree
Hide file tree
Showing 39 changed files with 302 additions and 231 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -495,30 +495,21 @@ e is NotSupportedException ||
}
}

protected async Task AssertCanceledAsync(CancellationToken cancellationToken, Func<Task> testCode)
{
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(testCode);
if (cancellationToken.CanBeCanceled)
{
Assert.Equal(cancellationToken, oce.CancellationToken);
}
}

protected async Task ValidatePrecanceledOperations_ThrowsCancellationException(Stream stream)
{
var cts = new CancellationTokenSource();
cts.Cancel();

if (stream.CanRead)
{
await AssertCanceledAsync(cts.Token, () => stream.ReadAsync(new byte[1], 0, 1, cts.Token));
await AssertCanceledAsync(cts.Token, async () => { await stream.ReadAsync(new Memory<byte>(new byte[1]), cts.Token); });
await AssertExtensions.CanceledAsync(cts.Token, stream.ReadAsync(new byte[1], 0, 1, cts.Token));
await AssertExtensions.CanceledAsync(cts.Token, async () => { await stream.ReadAsync(new Memory<byte>(new byte[1]), cts.Token); });
}

if (stream.CanWrite)
{
await AssertCanceledAsync(cts.Token, () => stream.WriteAsync(new byte[1], 0, 1, cts.Token));
await AssertCanceledAsync(cts.Token, async () => { await stream.WriteAsync(new ReadOnlyMemory<byte>(new byte[1]), cts.Token); });
await AssertExtensions.CanceledAsync(cts.Token, stream.WriteAsync(new byte[1], 0, 1, cts.Token));
await AssertExtensions.CanceledAsync(cts.Token, async () => { await stream.WriteAsync(new ReadOnlyMemory<byte>(new byte[1]), cts.Token); });
}

Exception e = await Record.ExceptionAsync(() => stream.FlushAsync(cts.Token));
Expand All @@ -540,7 +531,7 @@ protected async Task ValidateCancelableReadAsyncTask_AfterInvocation_ThrowsCance
Task<int> t = stream.ReadAsync(new byte[1], 0, 1, cts.Token);

cts.CancelAfter(cancellationDelay);
await AssertCanceledAsync(cts.Token, () => t);
await AssertExtensions.CanceledAsync(cts.Token, t);
}

protected async Task ValidateCancelableReadAsyncValueTask_AfterInvocation_ThrowsCancellationException(Stream stream, int cancellationDelay)
Expand All @@ -555,7 +546,7 @@ protected async Task ValidateCancelableReadAsyncValueTask_AfterInvocation_Throws
Task<int> t = stream.ReadAsync(new byte[1], cts.Token).AsTask();

cts.CancelAfter(cancellationDelay);
await AssertCanceledAsync(cts.Token, () => t);
await AssertExtensions.CanceledAsync(cts.Token, t);
}

protected async Task WhenAllOrAnyFailed(Task task1, Task task2)
Expand Down Expand Up @@ -2584,18 +2575,18 @@ public virtual async Task ReadAsync_CancelPendingRead_DoesntImpactSubsequentRead

cts = new CancellationTokenSource();
cts.Cancel();
await AssertCanceledAsync(cts.Token, () => readable.ReadAsync(new byte[1], 0, 1, cts.Token));
await AssertCanceledAsync(cts.Token, async () => { await readable.ReadAsync(new Memory<byte>(new byte[1]), cts.Token); });
await AssertExtensions.CanceledAsync(cts.Token, readable.ReadAsync(new byte[1], 0, 1, cts.Token));
await AssertExtensions.CanceledAsync(cts.Token, async () => { await readable.ReadAsync(new Memory<byte>(new byte[1]), cts.Token); });

cts = new CancellationTokenSource();
Task<int> t = readable.ReadAsync(new byte[1], 0, 1, cts.Token);
cts.Cancel();
await AssertCanceledAsync(cts.Token, () => t);
await AssertExtensions.CanceledAsync(cts.Token, t);

cts = new CancellationTokenSource();
ValueTask<int> vt = readable.ReadAsync(new Memory<byte>(new byte[1]), cts.Token);
cts.Cancel();
await AssertCanceledAsync(cts.Token, async () => await vt);
await AssertExtensions.CanceledAsync(cts.Token, vt.AsTask());

byte[] buffer = new byte[1];
vt = readable.ReadAsync(new Memory<byte>(buffer));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Sdk;
Expand Down Expand Up @@ -243,6 +244,30 @@ public static void ThrowsIf<T>(bool condition, Action action)
}
}

public static void Canceled(CancellationToken cancellationToken, Action testCode)
{
OperationCanceledException oce = Assert.ThrowsAny<OperationCanceledException>(testCode);
if (cancellationToken.CanBeCanceled)
{
Assert.Equal(cancellationToken, oce.CancellationToken);
}
}

public static Task CanceledAsync(CancellationToken cancellationToken, Task task)
{
Assert.NotNull(task);
return CanceledAsync(cancellationToken, () => task);
}

public static async Task CanceledAsync(CancellationToken cancellationToken, Func<Task> testCode)
{
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(testCode);
if (cancellationToken.CanBeCanceled)
{
Assert.Equal(cancellationToken, oce.CancellationToken);
}
}

private static string AddOptionalUserMessage(string message, string userMessage)
{
if (userMessage == null)
Expand Down
10 changes: 10 additions & 0 deletions src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/Pipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,11 @@ private void AdvanceCore(int bytesWritten)

internal ValueTask<FlushResult> FlushAsync(CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
return new ValueTask<FlushResult>(Task.FromCanceled<FlushResult>(cancellationToken));
}

CompletionData completionData;
ValueTask<FlushResult> result;
lock (SyncObj)
Expand Down Expand Up @@ -1058,6 +1063,11 @@ internal ValueTask<FlushResult> WriteAsync(ReadOnlyMemory<byte> source, Cancella
return new ValueTask<FlushResult>(new FlushResult(isCanceled: false, isCompleted: true));
}

if (cancellationToken.IsCancellationRequested)
{
return new ValueTask<FlushResult>(Task.FromCanceled<FlushResult>(cancellationToken));
}

CompletionData completionData;
ValueTask<FlushResult> result;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ public PipeAwaitable(bool completed, bool useSynchronizationContext)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void BeginOperation(CancellationToken cancellationToken, Action<object?> callback, object? state)
{
cancellationToken.ThrowIfCancellationRequested();

// Don't register if already completed, we would immediately unregistered in ObserveCancellation
if (cancellationToken.CanBeCanceled && !IsCompleted)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,14 @@ public void FlushAsyncReturnsIsCancelOnCancelPendingFlushBeforeGetResult()
}

[Fact]
public void FlushAsyncThrowsIfPassedCanceledCancellationToken()
public async Task FlushAsyncThrowsIfPassedCanceledCancellationToken()
{
var cancellationTokenSource = new CancellationTokenSource();
cancellationTokenSource.Cancel();

Assert.Throws<OperationCanceledException>(() => Pipe.Writer.FlushAsync(cancellationTokenSource.Token));
ValueTask<FlushResult> task = Pipe.Writer.FlushAsync(cancellationTokenSource.Token);
Assert.True(task.IsCanceled);
await AssertExtensions.CanceledAsync(cancellationTokenSource.Token, async () => await task);
}

[Fact]
Expand Down Expand Up @@ -318,7 +320,7 @@ public async Task FlushAsyncThrowsIfPassedCanceledCancellationTokenAndPipeIsAble
// and not only setting IsCompleted flag
var task = Pipe.Reader.ReadAsync().AsTask();

await Assert.ThrowsAsync<OperationCanceledException>(async () => await Pipe.Writer.FlushAsync(cancellationTokenSource.Token));
await AssertExtensions.CanceledAsync(cancellationTokenSource.Token, async () => await Pipe.Writer.FlushAsync(cancellationTokenSource.Token));

Pipe.Writer.Complete();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,11 @@ public ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool completeWrites, Ca
NetEventSource.Info(this, $"{this} Stream writing memory of '{buffer.Length}' bytes while {(completeWrites ? "completing" : "not completing")} writes.");
}

if (_sendTcs.IsCompleted)
if (_sendTcs.IsCompleted && cancellationToken.IsCancellationRequested)
{
// Special case exception type for pre-canceled token while we've already transitioned to a final state and don't need to abort write.
// It must happen before we try to get the value task, since the task source is versioned and each instance must be awaited.
cancellationToken.ThrowIfCancellationRequested();
return ValueTask.FromCanceled(cancellationToken);
}

// Concurrent call, this one lost the race.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,8 @@ public async Task CancelPendingWrite_Reading_DataTransferredFromCorrectWriter()

var cts = new CancellationTokenSource();

Task write1 = c.Writer.WriteAsync(43, cts.Token).AsTask();
Assert.Equal(TaskStatus.WaitingForActivation, write1.Status);
ValueTask write1 = c.Writer.WriteAsync(43, cts.Token);
Assert.False(write1.IsCompleted);

cts.Cancel();

Expand All @@ -490,7 +490,7 @@ public async Task CancelPendingWrite_Reading_DataTransferredFromCorrectWriter()
Assert.Equal(42, await c.Reader.ReadAsync());
Assert.Equal(44, await c.Reader.ReadAsync());

await AssertCanceled(write1, cts.Token);
await AssertExtensions.CanceledAsync(cts.Token, async () => await write1);
await write2;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ public async Task TryComplete_ErrorsPropage()
var cts = new CancellationTokenSource();
cts.Cancel();
Assert.True(c.Writer.TryComplete(new OperationCanceledException(cts.Token)));
await AssertCanceled(c.Reader.Completion, cts.Token);
await AssertExtensions.CanceledAsync(cts.Token, c.Reader.Completion);
}

[Fact]
Expand Down Expand Up @@ -450,7 +450,7 @@ public async Task Complete_WithCancellationException_PropagatesToCompletion()
catch (Exception e) { exc = e; }

c.Writer.Complete(exc);
await AssertCanceled(c.Reader.Completion, cts.Token);
await AssertExtensions.CanceledAsync(cts.Token, c.Reader.Completion);
}

[Fact]
Expand Down Expand Up @@ -653,7 +653,7 @@ public async Task ReadAsync_Canceled_CanceledAsynchronously()

cts.Cancel();

await AssertCanceled(r.AsTask(), cts.Token);
await AssertExtensions.CanceledAsync(cts.Token, async () => await r);

if (c.Writer.TryWrite(42))
{
Expand Down Expand Up @@ -760,7 +760,7 @@ public async Task ReadAsync_Canceled_WriteAsyncCompletesNextReader()
var cts = new CancellationTokenSource();
ValueTask<int> r = c.Reader.ReadAsync(cts.Token);
cts.Cancel();
await AssertCanceled(r.AsTask(), cts.Token);
await AssertExtensions.CanceledAsync(cts.Token, async () => await r);
}

for (int i = 0; i < 7; i++)
Expand Down
6 changes: 0 additions & 6 deletions src/libraries/System.Threading.Channels/tests/TestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ protected void AssertSynchronouslyCanceled(Task task, CancellationToken token)
}
}

protected async Task AssertCanceled(Task task, CancellationToken token)
{
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => task);
AssertSynchronouslyCanceled(task, token);
}

protected void AssertSynchronousSuccess<T>(ValueTask<T> task) => Assert.True(task.IsCompletedSuccessfully);
protected void AssertSynchronousSuccess(ValueTask task) => Assert.True(task.IsCompletedSuccessfully);
protected void AssertSynchronousSuccess(Task task) => Assert.Equal(TaskStatus.RanToCompletion, task.Status);
Expand Down
Loading

0 comments on commit 73da129

Please sign in to comment.