Skip to content

Commit

Permalink
no way to cancel receive operation
Browse files Browse the repository at this point in the history
Solution: add cancellation token to all receive operation (thread safe only, for now)
  • Loading branch information
somdoron committed May 20, 2020
1 parent 4cd27a1 commit 2658130
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 95 deletions.
14 changes: 14 additions & 0 deletions src/NetMQ.Tests/ClientServer.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using NetMQ;
using NetMQ.Sockets;
using Xunit;
Expand Down Expand Up @@ -59,5 +62,16 @@ public async void Async()
var serverMsg = await client.ReceiveStringAsync();
Assert.Equal("World", serverMsg);
}

[Fact]
public async void AsyncWithCancellationToken()
{
using CancellationTokenSource source = new CancellationTokenSource();
using var server = new ServerSocket();

source.CancelAfter(100);

await Assert.ThrowsAsync<OperationCanceledException>(async () => await server.ReceiveStringAsync(source.Token));
}
}
}
7 changes: 6 additions & 1 deletion src/NetMQ/Core/CommandType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ internal enum CommandType
/// <summary>
/// Send to reaper to stop the reaper immediatly
/// </summary>
ForceStop
ForceStop,

/// <summary>
/// Send a cancellation request to the socket from a cancellation token
/// </summary>
CancellationRequested
}
}
21 changes: 16 additions & 5 deletions src/NetMQ/Core/SocketBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ You should have received a copy of the GNU Lesser General Public License
using System.Diagnostics;
using System.Linq;
using System.Net.Sockets;
using System.Threading;
using AsyncIO;
using JetBrains.Annotations;
using NetMQ.Core.Patterns;
Expand Down Expand Up @@ -982,6 +983,7 @@ public bool TrySend(ref Msg msg, TimeSpan timeout, bool more)
/// </summary>
/// <param name="msg">the <c>Msg</c> to read the received message into</param>
/// <param name="timeout">this controls whether the call blocks, and for how long.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <returns><c>true</c> if successful, <c>false</c> if it timed out</returns>
/// <remarks>
/// For <paramref name="timeout"/>, there are three categories of value:
Expand All @@ -993,7 +995,7 @@ public bool TrySend(ref Msg msg, TimeSpan timeout, bool more)
/// </remarks>
/// <exception cref="FaultException">the Msg must already have been uninitialised</exception>
/// <exception cref="TerminatingException">The socket must not already be stopped.</exception>
public bool TryRecv(ref Msg msg, TimeSpan timeout)
public bool TryRecv(ref Msg msg, TimeSpan timeout, CancellationToken cancellationToken = default)
{
Lock();
try
Expand Down Expand Up @@ -1056,7 +1058,10 @@ public bool TryRecv(ref Msg msg, TimeSpan timeout)
bool block = m_ticks != 0;
while (true)
{
ProcessCommands(block ? timeoutMillis : 0, false);
if (cancellationToken.IsCancellationRequested)
return false;

ProcessCommands(block ? timeoutMillis : 0, false, cancellationToken);

isMessageAvailable = XRecv(ref msg);
if (isMessageAvailable)
Expand Down Expand Up @@ -1168,15 +1173,21 @@ internal void StartReaping([NotNull] Poller poller)
/// </summary>
/// <param name="timeout">how much time to allow to wait for a command, before returning (in milliseconds)</param>
/// <param name="throttle">if true - throttle the rate of command-execution by doing only one per call</param>
/// <param name="cancellationToken">allows the caller to cancel the process commands operation</param>
/// <exception cref="TerminatingException">The Ctx context must not already be terminating.</exception>
private void ProcessCommands(int timeout, bool throttle)
private void ProcessCommands(int timeout, bool throttle, CancellationToken cancellationToken = default)
{
bool found;
Command command;
if (timeout != 0)
{
// If we are asked to wait, simply ask mailbox to wait.
found = m_mailbox.TryRecv(timeout, out command);
if (cancellationToken.CanBeCanceled)
{
using var registration = cancellationToken.Register(SendCancellationRequested);
found = m_mailbox.TryRecv(timeout, out command);
}
else
found = m_mailbox.TryRecv(timeout, out command);
}
else
{
Expand Down
19 changes: 19 additions & 0 deletions src/NetMQ/Core/ZObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ protected void SendDone()
m_ctx.SendCommand(Ctx.TermTid, new Command(null, CommandType.Done));
}

/// <summary>
/// Socket sends a CancellationRequested command to itself when a CancellationToken has been cancelled
/// </summary>
protected void SendCancellationRequested()
{
SendCommand(new Command(this, CommandType.CancellationRequested, null));
}

/// <summary>
/// Send the given Command, on that commands Destination thread.
/// </summary>
Expand Down Expand Up @@ -324,6 +332,10 @@ public void ProcessCommand([NotNull] Command cmd)
case CommandType.ForceStop:
ProcessForceStop();
break;

case CommandType.CancellationRequested:
ProcessCancellationRequested();
break;

default:
throw new ArgumentException();
Expand Down Expand Up @@ -468,6 +480,13 @@ protected virtual void ProcessSeqnum()
throw new NotSupportedException();
}

/// <summary>
/// Handler for cancellation requested
/// </summary>
protected virtual void ProcessCancellationRequested()
{
}

#endregion
}
}
77 changes: 53 additions & 24 deletions src/NetMQ/ReceiveThreadSafeSocketExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;

Expand All @@ -18,19 +19,26 @@ public static class ReceiveThreadSafeSocketExtensions
/// Receive a bytes from <paramref name="socket"/>, blocking until one arrives.
/// </summary>
/// <param name="socket">The socket to receive from.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <returns>The content of the received message.</returns>
/// <exception cref="System.OperationCanceledException">The token has had cancellation requested.</exception>
[NotNull]
public static byte[] ReceiveBytes([NotNull] this IThreadSafeInSocket socket)
public static byte[] ReceiveBytes([NotNull] this IThreadSafeInSocket socket,
CancellationToken cancellationToken = default)
{
var msg = new Msg();
msg.InitEmpty();

socket.Receive(ref msg);

var data = msg.CloneData();

msg.Close();
return data;
try
{
socket.Receive(ref msg, cancellationToken);
var data = msg.CloneData();
return data;
}
finally
{
msg.Close();
}
}

#endregion
Expand Down Expand Up @@ -60,14 +68,16 @@ public static bool TryReceiveBytes([NotNull] this IThreadSafeInSocket socket, ou
/// <param name="socket">The socket to receive from.</param>
/// <param name="timeout">The maximum period of time to wait for a message to become available.</param>
/// <param name="bytes">The content of the received message, or <c>null</c> if no message was available.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <returns><c>true</c> if a message was available, otherwise <c>false</c>.</returns>
/// <remarks>The method would return false if cancellation has had requested.</remarks>
public static bool TryReceiveBytes([NotNull] this IThreadSafeInSocket socket, TimeSpan timeout,
out byte[] bytes)
out byte[] bytes, CancellationToken cancellationToken = default)
{
var msg = new Msg();
msg.InitEmpty();

if (!socket.TryReceive(ref msg, timeout))
if (!socket.TryReceive(ref msg, timeout, cancellationToken))
{
msg.Close();
bytes = null;
Expand All @@ -88,15 +98,20 @@ public static bool TryReceiveBytes([NotNull] this IThreadSafeInSocket socket, Ti
/// Receive a bytes from <paramref name="socket"/> asynchronously.
/// </summary>
/// <param name="socket">The socket to receive from.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <returns>The content of the received message.</returns>
public static ValueTask<byte[]> ReceiveBytesAsync([NotNull] this IThreadSafeInSocket socket)
/// <exception cref="System.OperationCanceledException">The token has had cancellation requested.</exception>
public static ValueTask<byte[]> ReceiveBytesAsync([NotNull] this IThreadSafeInSocket socket,
CancellationToken cancellationToken = default)
{
if (TryReceiveBytes(socket, out var bytes))
return new ValueTask<byte[]>(bytes);

// TODO: this is a hack, eventually we need kind of IO ThreadPool for thread-safe socket to wait on asynchronously
// and probably implement IValueTaskSource
return new ValueTask<byte[]>(Task.Factory.StartNew(socket.ReceiveBytes, TaskCreationOptions.LongRunning));
// TODO: should we avoid lambda here as it cause heap allocation for the environment?
return new ValueTask<byte[]>(Task.Factory.StartNew(() => socket.ReceiveBytes(cancellationToken),
cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default));
}

#endregion
Expand All @@ -111,27 +126,32 @@ public static ValueTask<byte[]> ReceiveBytesAsync([NotNull] this IThreadSafeInSo
/// Receive a string from <paramref name="socket"/>, blocking until one arrives, and decode using <see cref="SendReceiveConstants.DefaultEncoding"/>.
/// </summary>
/// <param name="socket">The socket to receive from.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <returns>The content of the received message.</returns>
public static string ReceiveString([NotNull] this IThreadSafeInSocket socket)
/// <exception cref="System.OperationCanceledException">The token has had cancellation requested.</exception>
public static string ReceiveString([NotNull] this IThreadSafeInSocket socket,
CancellationToken cancellationToken = default)
{
return socket.ReceiveString(SendReceiveConstants.DefaultEncoding);
return socket.ReceiveString(SendReceiveConstants.DefaultEncoding, cancellationToken);
}

/// <summary>
/// Receive a string from <paramref name="socket"/>, blocking until one arrives, and decode using <paramref name="encoding"/>.
/// </summary>
/// <param name="socket">The socket to receive from.</param>
/// <param name="encoding">The encoding used to convert the data to a string.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <returns>The content of the received message.</returns>
public static string ReceiveString([NotNull] this IThreadSafeInSocket socket, [NotNull] Encoding encoding)
/// <exception cref="System.OperationCanceledException">The token has had cancellation requested.</exception>
public static string ReceiveString([NotNull] this IThreadSafeInSocket socket, [NotNull] Encoding encoding,
CancellationToken cancellationToken = default)
{
var msg = new Msg();
msg.InitEmpty();

socket.Receive(ref msg);

try
{
socket.Receive(ref msg, cancellationToken);
return msg.Size > 0
? msg.GetString(encoding)
: string.Empty;
Expand Down Expand Up @@ -182,11 +202,14 @@ public static bool TryReceiveString([NotNull] this IThreadSafeInSocket socket, [
/// </summary>
/// <param name="socket">The socket to receive from.</param>
/// <param name="timeout">The maximum period of time to wait for a message to become available.</param>
/// <param name="str">The content of the received message, or <c>null</c> if no message was available.</param>
/// <param name="str">The conent of the received message, or <c>null</c> if no message was available.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <returns><c>true</c> if a message was available, otherwise <c>false</c>.</returns>
public static bool TryReceiveString([NotNull] this IThreadSafeInSocket socket, TimeSpan timeout, out string str)
/// <remarks>The method would return false if cancellation has had requested.</remarks>
public static bool TryReceiveString([NotNull] this IThreadSafeInSocket socket, TimeSpan timeout, out string str,
CancellationToken cancellationToken = default)
{
return socket.TryReceiveString(timeout, SendReceiveConstants.DefaultEncoding, out str);
return socket.TryReceiveString(timeout, SendReceiveConstants.DefaultEncoding, out str, cancellationToken);
}

/// <summary>
Expand All @@ -197,14 +220,16 @@ public static bool TryReceiveString([NotNull] this IThreadSafeInSocket socket, T
/// <param name="timeout">The maximum period of time to wait for a message to become available.</param>
/// <param name="encoding">The encoding used to convert the data to a string.</param>
/// <param name="str">The content of the received message, or <c>null</c> if no message was available.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <returns><c>true</c> if a message was available, otherwise <c>false</c>.</returns>
/// <remarks>The method would return false if cancellation has had requested.</remarks>
public static bool TryReceiveString([NotNull] this IThreadSafeInSocket socket, TimeSpan timeout,
[NotNull] Encoding encoding, out string str)
[NotNull] Encoding encoding, out string str, CancellationToken cancellationToken = default)
{
var msg = new Msg();
msg.InitEmpty();

if (socket.TryReceive(ref msg, timeout))
if (socket.TryReceive(ref msg, timeout, cancellationToken))
{
try
{
Expand All @@ -215,7 +240,7 @@ public static bool TryReceiveString([NotNull] this IThreadSafeInSocket socket, T
}
finally
{
msg.Close();
msg.Close();
}
}

Expand All @@ -232,15 +257,19 @@ public static bool TryReceiveString([NotNull] this IThreadSafeInSocket socket, T
/// Receive a string from <paramref name="socket"/> asynchronously.
/// </summary>
/// <param name="socket">The socket to receive from.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <returns>The content of the received message.</returns>
public static ValueTask<string> ReceiveStringAsync([NotNull] this IThreadSafeInSocket socket)
/// <exception cref="System.OperationCanceledException">The token has had cancellation requested.</exception>
public static ValueTask<string> ReceiveStringAsync([NotNull] this IThreadSafeInSocket socket,
CancellationToken cancellationToken = default)
{
if (TryReceiveString(socket, out var msg))
return new ValueTask<string>(msg);

// TODO: this is a hack, eventually we need kind of IO ThreadPool for thread-safe socket to wait on asynchronously
// and probably implement IValueTaskSource
return new ValueTask<string>(Task.Factory.StartNew(socket.ReceiveString, TaskCreationOptions.LongRunning));
return new ValueTask<string>(Task.Factory.StartNew(() => socket.ReceiveString(cancellationToken),
cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default));
}

#endregion
Expand Down
Loading

0 comments on commit 2658130

Please sign in to comment.