From 137a5c40cfecff099f3b5e97c425663ed2e8505d Mon Sep 17 00:00:00 2001 From: Larry Ewing Date: Fri, 14 Sep 2012 13:27:54 -0500 Subject: [PATCH] Rework exception handling in MemoryStream async methods Return faulted tasks for runtime errors when using the MemoryStream Async methods. --- mcs/class/corlib/System.IO/MemoryStream.cs | 59 +++-- .../corlib/System.Threading.Tasks/Task_T.cs | 7 + .../corlib/Test/System.IO/MemoryStreamTest.cs | 205 ++++++++++++++++++ 3 files changed, 257 insertions(+), 14 deletions(-) diff --git a/mcs/class/corlib/System.IO/MemoryStream.cs b/mcs/class/corlib/System.IO/MemoryStream.cs index fa874887d6ff..4ce86df30b61 100644 --- a/mcs/class/corlib/System.IO/MemoryStream.cs +++ b/mcs/class/corlib/System.IO/MemoryStream.cs @@ -376,9 +376,6 @@ public virtual byte [] ToArray () public override void Write (byte [] buffer, int offset, int count) { - if (!canWrite) - throw new NotSupportedException ("Cannot write to this stream."); - if (buffer == null) throw new ArgumentNullException ("buffer"); @@ -391,6 +388,9 @@ public override void Write (byte [] buffer, int offset, int count) CheckIfClosedThrowDisposed (); + if (!CanWrite) + throw new NotSupportedException ("Cannot write to this stream."); + // reordered to avoid possible integer overflow if (position > length - count) Expand (position + count); @@ -436,33 +436,64 @@ public override Task CopyToAsync (Stream destination, int bufferSize, Cancellati public override Task FlushAsync (CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) - return TaskConstants.Canceled; + return TaskConstants.Canceled; - Flush (); - return TaskConstants.Finished; + try { + Flush (); + return TaskConstants.Finished; + } catch (Exception ex) { + return Task.FromException (ex); + } } public override Task ReadAsync (byte[] buffer, int offset, int count, CancellationToken cancellationToken) { + if (buffer == null) + throw new ArgumentNullException ("buffer"); + + if (offset < 0 || count < 0) + throw new ArgumentOutOfRangeException ("offset or count less than zero."); + + if (buffer.Length - offset < count ) + throw new ArgumentException ("offset+count", + "The size of the buffer is less than offset + count."); if (cancellationToken.IsCancellationRequested) return TaskConstants.Canceled; - count = Read (buffer, offset, count); + try { + count = Read (buffer, offset, count); - // Try not to allocate a new task for every buffer read - if (read_task == null || read_task.Result != count) - read_task = Task.FromResult (count); + // Try not to allocate a new task for every buffer read + if (read_task == null || read_task.Result != count) + read_task = Task.FromResult (count); - return read_task; + return read_task; + } catch (Exception ex) { + return Task.FromException (ex); + } } public override Task WriteAsync (byte[] buffer, int offset, int count, CancellationToken cancellationToken) { + if (buffer == null) + throw new ArgumentNullException ("buffer"); + + if (offset < 0 || count < 0) + throw new ArgumentOutOfRangeException (); + + if (buffer.Length - offset < count) + throw new ArgumentException ("offset+count", + "The size of the buffer is less than offset + count."); + if (cancellationToken.IsCancellationRequested) - return TaskConstants.Canceled; + return TaskConstants.Canceled; - Write (buffer, offset, count); - return TaskConstants.Finished; + try { + Write (buffer, offset, count); + return TaskConstants.Finished; + } catch (Exception ex) { + return Task.FromException (ex); + } } #endif } diff --git a/mcs/class/corlib/System.Threading.Tasks/Task_T.cs b/mcs/class/corlib/System.Threading.Tasks/Task_T.cs index f02f17bf9d0f..ce7063309037 100644 --- a/mcs/class/corlib/System.Threading.Tasks/Task_T.cs +++ b/mcs/class/corlib/System.Threading.Tasks/Task_T.cs @@ -323,6 +323,13 @@ public Task ContinueWith (Func, object, TN { return new TaskAwaiter (this); } + + internal static Task FromException (Exception ex) + { + var tcs = new TaskCompletionSource(); + tcs.TrySetException (ex); + return tcs.Task; + } #endif } } diff --git a/mcs/class/corlib/Test/System.IO/MemoryStreamTest.cs b/mcs/class/corlib/Test/System.IO/MemoryStreamTest.cs index 938d578a85f8..2f8e2c05f5a7 100644 --- a/mcs/class/corlib/Test/System.IO/MemoryStreamTest.cs +++ b/mcs/class/corlib/Test/System.IO/MemoryStreamTest.cs @@ -17,6 +17,9 @@ using System.Runtime.Serialization.Formatters.Binary; using System.Text; using System.Threading; +#if NET_4_5 +using System.Threading.Tasks; +#endif using NUnit.Framework; @@ -45,6 +48,54 @@ public override int Read (byte[] buffer, int offset, int count) } } + class ExceptionalStream : MemoryStream + { + public static string Message = "ExceptionalMessage"; + public bool Throw = false; + + public ExceptionalStream () + { + AllowRead = true; + AllowWrite = true; + } + + public ExceptionalStream (byte [] buffer, bool writable) : base (buffer, writable) + { + AllowRead = true; + AllowWrite = true; // we are testing the inherited write property + } + + + public override int Read(byte[] buffer, int offset, int count) + { + if (Throw) + throw new Exception(Message); + + return base.Read(buffer, offset, count); + } + + public override void Write(byte[] buffer, int offset, int count) + { + if (Throw) + throw new Exception(Message); + + base.Write(buffer, offset, count); + } + + public bool AllowRead { get; set; } + public override bool CanRead { get { return AllowRead; } } + + public bool AllowWrite { get; set; } + public override bool CanWrite { get { return AllowWrite; } } + + public override void Flush() + { + if (Throw) + throw new Exception(Message); + + base.Flush(); + } + } MemoryStream testStream; byte [] testStreamData; @@ -1074,6 +1125,150 @@ public void ReadAsync () Assert.AreEqual (1, buffer[0], "#4"); } + [Test] + public void TestAsyncReadExceptions () + { + var buffer = new byte [3]; + using (var stream = new ExceptionalStream ()) { + stream.Write (buffer, 0, buffer.Length); + stream.Write (buffer, 0, buffer.Length); + stream.Position = 0; + var task = stream.ReadAsync (buffer, 0, buffer.Length); + Assert.AreEqual (TaskStatus.RanToCompletion, task.Status, "#1"); + + stream.Throw = true; + task = stream.ReadAsync (buffer, 0, buffer.Length); + Assert.IsTrue (task.IsFaulted, "#2"); + Assert.AreEqual (ExceptionalStream.Message, task.Exception.InnerException.Message, "#3"); + } + } + + [Test] + public void TestAsyncWriteExceptions () + { + var buffer = new byte [3]; + using (var stream = new ExceptionalStream ()) { + var task = stream.WriteAsync (buffer, 0, buffer.Length); + Assert.AreEqual(TaskStatus.RanToCompletion, task.Status, "#1"); + + stream.Throw = true; + task = stream.WriteAsync (buffer, 0, buffer.Length); + Assert.IsTrue (task.IsFaulted, "#2"); + Assert.AreEqual (ExceptionalStream.Message, task.Exception.InnerException.Message, "#3"); + } + } + + [Test] + public void TestAsyncArgumentExceptions () + { + var buffer = new byte [3]; + using (var stream = new ExceptionalStream ()) { + var task = stream.WriteAsync (buffer, 0, buffer.Length); + Assert.IsTrue (task.IsCompleted); + + Assert.IsTrue (Throws (() => { stream.WriteAsync (buffer, 0, 1000); }), "#2"); + Assert.IsTrue (Throws (() => { stream.ReadAsync (buffer, 0, 1000); }), "#3"); + Assert.IsTrue (Throws (() => { stream.WriteAsync (buffer, 0, 1000, new CancellationToken (true)); }), "#4"); + Assert.IsTrue (Throws (() => { stream.ReadAsync (buffer, 0, 1000, new CancellationToken (true)); }), "#5"); + Assert.IsTrue (Throws (() => { stream.WriteAsync (null, 0, buffer.Length, new CancellationToken (true)); }), "#6"); + Assert.IsTrue (Throws (() => { stream.ReadAsync (null, 0, buffer.Length, new CancellationToken (true)); }), "#7"); + Assert.IsTrue (Throws (() => { stream.WriteAsync (buffer, 1000, buffer.Length, new CancellationToken (true)); }), "#8"); + Assert.IsTrue (Throws (() => { stream.ReadAsync (buffer, 1000, buffer.Length, new CancellationToken (true)); }), "#9"); + + stream.AllowRead = false; + var read_task = stream.ReadAsync (buffer, 0, buffer.Length); + Assert.AreEqual (TaskStatus.RanToCompletion, read_task.Status, "#8"); + Assert.AreEqual (0, read_task.Result, "#9"); + + stream.Position = 0; + read_task = stream.ReadAsync (buffer, 0, buffer.Length); + Assert.AreEqual (TaskStatus.RanToCompletion, read_task.Status, "#9"); + Assert.AreEqual (3, read_task.Result, "#10"); + + var write_task = stream.WriteAsync (buffer, 0, buffer.Length); + Assert.AreEqual (TaskStatus.RanToCompletion, write_task.Status, "#10"); + + // test what happens when CanRead is overridden + using (var norm = new ExceptionalStream (buffer, false)) { + write_task = norm.WriteAsync (buffer, 0, buffer.Length); + Assert.AreEqual (TaskStatus.RanToCompletion, write_task.Status, "#11"); + } + + stream.AllowWrite = false; + Assert.IsTrue (Throws (() => { stream.Write (buffer, 0, buffer.Length); }), "#12"); + write_task = stream.WriteAsync (buffer, 0, buffer.Length); + Assert.AreEqual (TaskStatus.Faulted, write_task.Status, "#13"); + } + } + + [Test] + public void TestAsyncFlushExceptions () + { + using (var stream = new ExceptionalStream ()) { + var task = stream.FlushAsync (); + Assert.IsTrue (task.IsCompleted, "#1"); + + task = stream.FlushAsync (new CancellationToken(true)); + Assert.IsTrue (task.IsCanceled, "#2"); + + stream.Throw = true; + task = stream.FlushAsync (); + Assert.IsTrue (task.IsFaulted, "#3"); + Assert.AreEqual (ExceptionalStream.Message, task.Exception.InnerException.Message, "#4"); + + task = stream.FlushAsync (new CancellationToken (true)); + Assert.IsTrue (task.IsCanceled, "#5"); + } + } + + [Test] + public void TestCopyAsync () + { + using (var stream = new ExceptionalStream ()) { + using (var dest = new ExceptionalStream ()) { + byte [] buffer = new byte [] { 12, 13, 8 }; + + stream.Write (buffer, 0, buffer.Length); + stream.Position = 0; + var task = stream.CopyToAsync (dest, 1); + Assert.AreEqual (TaskStatus.RanToCompletion, task.Status); + Assert.AreEqual (3, stream.Length); + Assert.AreEqual (3, dest.Length); + + stream.Position = 0; + dest.Throw = true; + task = stream.CopyToAsync (dest, 1); + Assert.AreEqual (TaskStatus.Faulted, task.Status); + Assert.AreEqual (3, stream.Length); + Assert.AreEqual (3, dest.Length); + } + } + } + + [Test] + public void WritableOverride () + { + var buffer = new byte [3]; + var stream = new MemoryStream (buffer, false); + Assert.IsTrue (Throws (() => { stream.Write (buffer, 0, buffer.Length); }), "#1"); + Assert.IsTrue (Throws (() => { stream.Write (null, 0, buffer.Length); }), "#1.1"); + stream.Close (); + Assert.IsTrue (Throws (() => { stream.Write (buffer, 0, buffer.Length); }), "#2"); + stream = new MemoryStream (buffer, true); + stream.Close (); + Assert.IsFalse (stream.CanWrite, "#3"); + + var estream = new ExceptionalStream (buffer, false); + Assert.IsFalse (Throws (() => { estream.Write (buffer, 0, buffer.Length); }), "#4"); + estream.AllowWrite = false; + estream.Position = 0; + Assert.IsTrue (Throws (() => { estream.Write (buffer, 0, buffer.Length); }), "#5"); + estream.AllowWrite = true; + estream.Close (); + Assert.IsTrue (estream.CanWrite, "#6"); + Assert.IsTrue (Throws (() => { stream.Write (buffer, 0, buffer.Length); }), "#7"); + } + [Test] public void ReadAsync_Canceled () { @@ -1109,6 +1304,16 @@ public void WriteAsync_Canceled () t = testStream.WriteAsync (buffer, 0, buffer.Length); Assert.IsTrue (t.IsCompleted, "#1"); } + + bool Throws (Action a) where T : Exception + { + try { + a (); + return false; + } catch (T) { + return true; + } + } #endif } }