Skip to content

Commit

Permalink
Remove invalid Unsafe.As from array helpers (dotnet#99778)
Browse files Browse the repository at this point in the history
* Remove UB from helpers

* Cleanup more helpers

* Fix Test.CoreLib

* Cleanup

---------

Co-authored-by: Jan Kotas <[email protected]>
  • Loading branch information
MichalPetryka and jkotas authored Mar 21, 2024
1 parent 1c73fa7 commit 310b824
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;

namespace System.Runtime.CompilerServices
{
[StackTraceHidden]
[DebuggerStepThrough]
internal static unsafe class CastHelpers
{
// In coreclr the table is allocated and written to on the native side.
Expand All @@ -24,14 +23,12 @@ internal static unsafe class CastHelpers
private static extern ref byte Unbox_Helper(void* toTypeHnd, object obj);

[MethodImpl(MethodImplOptions.InternalCall)]
private static extern void WriteBarrier(ref object? dst, object obj);
private static extern void WriteBarrier(ref object? dst, object? obj);

// IsInstanceOf test used for unusual cases (naked type parameters, variant generic types)
// Unlike the IsInstanceOfInterface and IsInstanceOfClass functions,
// this test must deal with all kinds of type tests
[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static object? IsInstanceOfAny(void* toTypeHnd, object? obj)
{
if (obj != null)
Expand Down Expand Up @@ -63,8 +60,6 @@ internal static unsafe class CastHelpers
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static object? IsInstanceOfInterface(void* toTypeHnd, object? obj)
{
const int unrollSize = 4;
Expand Down Expand Up @@ -134,8 +129,6 @@ internal static unsafe class CastHelpers
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static object? IsInstanceOfClass(void* toTypeHnd, object? obj)
{
if (obj == null || RuntimeHelpers.GetMethodTable(obj) == toTypeHnd)
Expand Down Expand Up @@ -184,8 +177,6 @@ internal static unsafe class CastHelpers
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
[MethodImpl(MethodImplOptions.NoInlining)]
private static object? IsInstance_Helper(void* toTypeHnd, object obj)
{
Expand All @@ -207,8 +198,6 @@ internal static unsafe class CastHelpers
// Unlike the ChkCastInterface and ChkCastClass functions,
// this test must deal with all kinds of type tests
[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
internal static object? ChkCastAny(void* toTypeHnd, object? obj)
{
CastResult result;
Expand Down Expand Up @@ -237,8 +226,6 @@ internal static unsafe class CastHelpers
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
[MethodImpl(MethodImplOptions.NoInlining)]
private static object? ChkCast_Helper(void* toTypeHnd, object obj)
{
Expand All @@ -253,8 +240,6 @@ internal static unsafe class CastHelpers
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static object? ChkCastInterface(void* toTypeHnd, object? obj)
{
const int unrollSize = 4;
Expand Down Expand Up @@ -321,8 +306,6 @@ internal static unsafe class CastHelpers
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static object? ChkCastClass(void* toTypeHnd, object? obj)
{
if (obj == null || RuntimeHelpers.GetMethodTable(obj) == toTypeHnd)
Expand All @@ -336,8 +319,6 @@ internal static unsafe class CastHelpers
// Optimized helper for classes. Assumes that the trivial cases
// has been taken care of by the inlined check
[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static object? ChkCastClassSpecial(void* toTypeHnd, object obj)
{
MethodTable* mt = RuntimeHelpers.GetMethodTable(obj);
Expand Down Expand Up @@ -384,52 +365,53 @@ internal static unsafe class CastHelpers
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static ref byte Unbox(void* toTypeHnd, object obj)
{
// this will throw NullReferenceException if obj is null, attributed to the user code, as expected.
// This will throw NullReferenceException if obj is null.
if (RuntimeHelpers.GetMethodTable(obj) == toTypeHnd)
return ref obj.GetRawData();

return ref Unbox_Helper(toTypeHnd, obj);
}

internal struct ArrayElement
[DebuggerHidden]
private static void ThrowIndexOutOfRangeException()
{
public object? Value;
throw new IndexOutOfRangeException();
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static ref object? ThrowArrayMismatchException()
private static void ThrowArrayMismatchException()
{
throw new ArrayTypeMismatchException();
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static ref object? LdelemaRef(Array array, nint index, void* type)
private static ref object? LdelemaRef(object?[] array, nint index, void* type)
{
// this will throw appropriate exceptions if array is null or access is out of range.
ref object? element = ref Unsafe.As<ArrayElement[]>(array)[index].Value;
// This will throw NullReferenceException if array is null.
if ((nuint)index >= (uint)array.Length)
ThrowIndexOutOfRangeException();

Debug.Assert(index >= 0);
ref object? element = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(array), index);
void* elementType = RuntimeHelpers.GetMethodTable(array)->ElementType;

if (elementType == type)
return ref element;
if (elementType != type)
ThrowArrayMismatchException();

return ref ThrowArrayMismatchException();
return ref element;
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static void StelemRef(Array array, nint index, object? obj)
private static void StelemRef(object?[] array, nint index, object? obj)
{
// this will throw appropriate exceptions if array is null or access is out of range.
ref object? element = ref Unsafe.As<ArrayElement[]>(array)[index].Value;
// This will throw NullReferenceException if array is null.
if ((nuint)index >= (uint)array.Length)
ThrowIndexOutOfRangeException();

Debug.Assert(index >= 0);
ref object? element = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(array), index);
void* elementType = RuntimeHelpers.GetMethodTable(array)->ElementType;

if (obj == null)
Expand All @@ -454,8 +436,6 @@ private static void StelemRef(Array array, nint index, object? obj)
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
[MethodImpl(MethodImplOptions.NoInlining)]
private static void StelemRef_Helper(ref object? element, void* elementType, object obj)
{
Expand All @@ -470,20 +450,17 @@ private static void StelemRef_Helper(ref object? element, void* elementType, obj
}

[DebuggerHidden]
[StackTraceHidden]
[DebuggerStepThrough]
private static void StelemRef_Helper_NoCacheLookup(ref object? element, void* elementType, object obj)
{
Debug.Assert(obj != null);

obj = IsInstanceOfAny_NoCacheLookup(elementType, obj);
if (obj != null)
if (obj == null)
{
WriteBarrier(ref element, obj);
return;
ThrowArrayMismatchException();
}

throw new ArrayTypeMismatchException();
WriteBarrier(ref element, obj);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ internal static int RhEndNoGCRegion()

[RuntimeImport(Redhawk.BaseName, "RhpAssignRef")]
[MethodImpl(MethodImplOptions.InternalCall)]
internal static extern unsafe void RhpAssignRef(ref object address, object obj);
internal static extern unsafe void RhpAssignRef(ref object? address, object? obj);

[MethodImplAttribute(MethodImplOptions.InternalCall)]
[RuntimeImport(Redhawk.BaseName, "RhpGcSafeZeroMemory")]
Expand Down
115 changes: 61 additions & 54 deletions src/coreclr/nativeaot/Runtime.Base/src/System/Runtime/TypeCast.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ namespace System.Runtime
//
/////////////////////////////////////////////////////////////////////////////////////////////////////

[StackTraceHidden]
[DebuggerStepThrough]
[EagerStaticClassConstruction]
internal static class TypeCast
{
Expand Down Expand Up @@ -737,23 +739,69 @@ public static unsafe void CheckArrayStore(object array, object obj)
throw array.GetMethodTable()->GetClasslibException(ExceptionIDs.ArrayTypeMismatch);
}

internal struct ArrayElement
private static unsafe void ThrowIndexOutOfRangeException(object?[] array)
{
public object Value;
// Throw the index out of range exception defined by the classlib, using the input array's MethodTable*
// to find the correct classlib.
throw array.GetMethodTable()->GetClasslibException(ExceptionIDs.IndexOutOfRange);
}

private static unsafe void ThrowArrayMismatchException(object?[] array)
{
// Throw the array type mismatch exception defined by the classlib, using the input array's MethodTable*
// to find the correct classlib.
throw array.GetMethodTable()->GetClasslibException(ExceptionIDs.ArrayTypeMismatch);
}

//
// Array stelem/ldelema helpers with RyuJIT conventions
//

[RuntimeExport("RhpLdelemaRef")]
public static unsafe ref object? LdelemaRef(object?[] array, nint index, MethodTable* elementType)
{
Debug.Assert(array is null || array.GetMethodTable()->IsArray, "first argument must be an array");

#if INPLACE_RUNTIME
// This will throw NullReferenceException if obj is null.
if ((nuint)index >= (uint)array.Length)
ThrowIndexOutOfRangeException(array);

Debug.Assert(index >= 0);
ref object? element = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(array), index);
#else
if (array is null)
{
throw elementType->GetClasslibException(ExceptionIDs.NullReference);
}
if ((nuint)index >= (uint)array.Length)
{
throw elementType->GetClasslibException(ExceptionIDs.IndexOutOfRange);
}
ref object rawData = ref Unsafe.As<byte, object>(ref Unsafe.As<RawArrayData>(array).Data);
ref object element = ref Unsafe.Add(ref rawData, index);
#endif
MethodTable* arrayElemType = array.GetMethodTable()->RelatedParameterType;

if (elementType != arrayElemType)
ThrowArrayMismatchException(array);

return ref element;
}

[RuntimeExport("RhpStelemRef")]
public static unsafe void StelemRef(Array array, nint index, object obj)
public static unsafe void StelemRef(object?[] array, nint index, object? obj)
{
// This is supported only on arrays
Debug.Assert(array.GetMethodTable()->IsArray, "first argument must be an array");
Debug.Assert(array is null || array.GetMethodTable()->IsArray, "first argument must be an array");

#if INPLACE_RUNTIME
// this will throw appropriate exceptions if array is null or access is out of range.
ref object element = ref Unsafe.As<ArrayElement[]>(array)[index].Value;
// This will throw NullReferenceException if obj is null.
if ((nuint)index >= (uint)array.Length)
ThrowIndexOutOfRangeException(array);

Debug.Assert(index >= 0);
ref object? element = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(array), index);
#else
if (array is null)
{
Expand Down Expand Up @@ -796,7 +844,7 @@ public static unsafe void StelemRef(Array array, nint index, object obj)
}

[MethodImpl(MethodImplOptions.NoInlining)]
private static unsafe void StelemRef_Helper(ref object element, MethodTable* elementType, object obj)
private static unsafe void StelemRef_Helper(ref object? element, MethodTable* elementType, object obj)
{
CastResult result = s_castCache.TryGet((nuint)obj.GetMethodTable() + (int)AssignmentVariation.BoxedSource, (nuint)elementType);
if (result == CastResult.CanCast)
Expand All @@ -808,58 +856,17 @@ private static unsafe void StelemRef_Helper(ref object element, MethodTable* ele
StelemRef_Helper_NoCacheLookup(ref element, elementType, obj);
}

private static unsafe void StelemRef_Helper_NoCacheLookup(ref object element, MethodTable* elementType, object obj)
private static unsafe void StelemRef_Helper_NoCacheLookup(ref object? element, MethodTable* elementType, object obj)
{
object? castedObj = IsInstanceOfAny_NoCacheLookup(elementType, obj);
if (castedObj != null)
if (castedObj == null)
{
InternalCalls.RhpAssignRef(ref element, obj);
return;
// Throw the array type mismatch exception defined by the classlib, using the input array's
// MethodTable* to find the correct classlib.
throw elementType->GetClasslibException(ExceptionIDs.ArrayTypeMismatch);
}

// Throw the array type mismatch exception defined by the classlib, using the input array's
// MethodTable* to find the correct classlib.
throw elementType->GetClasslibException(ExceptionIDs.ArrayTypeMismatch);
}

[RuntimeExport("RhpLdelemaRef")]
public static unsafe ref object LdelemaRef(Array array, nint index, IntPtr elementType)
{
Debug.Assert(array is null || array.GetMethodTable()->IsArray, "first argument must be an array");

#if INPLACE_RUNTIME
// this will throw appropriate exceptions if array is null or access is out of range.
ref object element = ref Unsafe.As<ArrayElement[]>(array)[index].Value;
#else
if (array is null)
{
throw ((MethodTable*)elementType)->GetClasslibException(ExceptionIDs.NullReference);
}
if ((uint)index >= (uint)array.Length)
{
throw ((MethodTable*)elementType)->GetClasslibException(ExceptionIDs.IndexOutOfRange);
}
ref object rawData = ref Unsafe.As<byte, object>(ref Unsafe.As<RawArrayData>(array).Data);
ref object element = ref Unsafe.Add(ref rawData, index);
#endif

MethodTable* elemType = (MethodTable*)elementType;
MethodTable* arrayElemType = array.GetMethodTable()->RelatedParameterType;

if (elemType == arrayElemType)
{
return ref element;
}

return ref ThrowArrayMismatchException(array);
}

// This weird structure is for parity with CoreCLR - allows potentially to be tailcalled
private static unsafe ref object ThrowArrayMismatchException(Array array)
{
// Throw the array type mismatch exception defined by the classlib, using the input array's MethodTable*
// to find the correct classlib.
throw array.GetMethodTable()->GetClasslibException(ExceptionIDs.ArrayTypeMismatch);
InternalCalls.RhpAssignRef(ref element, obj);
}

private static unsafe object IsInstanceOfArray(MethodTable* pTargetType, object obj)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace System.Diagnostics
{
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method | AttributeTargets.Constructor, Inherited = false)]
public sealed class DebuggerStepThroughAttribute : Attribute
{
public DebuggerStepThroughAttribute() { }
}
}
Loading

0 comments on commit 310b824

Please sign in to comment.