Skip to content

Commit

Permalink
[NRBF] More bug fixes (dotnet#107682)
Browse files Browse the repository at this point in the history
- Don't use `Debug.Fail` not followed by an exception (it may cause problems for apps deployed in Debug)
- avoid Int32 overflow
- throw for unexpected enum values just in case parsing has not rejected them
- validate the number of chars read by BinaryReader.ReadChars
- pass serialization record id to ex message
- return false rather than throw EndOfStreamException when provided Stream has not enough data
- don't restore the position in finally 
- limit max SZ and MD array length to Array.MaxLength, stop using LinkedList<T> as List<T> will be able to hold all elements now
- remove internal enum values that were always illegal, but needed to be handled everywhere
- Fix DebuggerDisplay
  • Loading branch information
adamsitnik authored Sep 12, 2024
1 parent 4930e1b commit 4cdbfdc
Show file tree
Hide file tree
Showing 17 changed files with 97 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace System.Formats.Nrbf;
/// <remarks>
/// ArrayInfo structures are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/8fac763f-e46d-43a1-b360-80eb83d2c5fb">[MS-NRBF] 2.4.2.1</see>.
/// </remarks>
[DebuggerDisplay("Length={Length}, {ArrayType}, rank={Rank}")]
[DebuggerDisplay("{ArrayType}, rank={Rank}")]
internal readonly struct ArrayInfo
{
internal const int MaxArrayLength = 2147483591; // Array.MaxLength
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,9 @@ internal ArraySinglePrimitiveRecord(ArrayInfo arrayInfo, IReadOnlyList<T> values
public override T[] GetArray(bool allowNulls = true)
=> (T[])(_arrayNullsNotAllowed ??= (Values is T[] array ? array : Values.ToArray()));

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
{
Debug.Fail("GetAllowedRecordType should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException();

private protected override void AddValue(object value)
{
Debug.Fail("AddValue should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
private protected override void AddValue(object value) => throw new InvalidOperationException();

internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int count)
{
Expand Down Expand Up @@ -94,7 +86,7 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
#if NET
reader.BaseStream.ReadExactly(resultAsBytes);
#else
byte[] bytes = ArrayPool<byte>.Shared.Rent(Math.Min(count * Unsafe.SizeOf<T>(), 256_000));
byte[] bytes = ArrayPool<byte>.Shared.Rent((int)Math.Min(requiredBytes, 256_000));

while (!resultAsBytes.IsEmpty)
{
Expand Down Expand Up @@ -159,31 +151,10 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
{
List<decimal> values = new();
#if NET
Span<byte> buffer = stackalloc byte[256];
for (int i = 0; i < count; i++)
{
int stringLength = reader.Read7BitEncodedInt();
if (!(stringLength > 0 && stringLength <= buffer.Length))
{
ThrowHelper.ThrowInvalidValue(stringLength);
}

reader.BaseStream.ReadExactly(buffer.Slice(0, stringLength));

if (!decimal.TryParse(buffer.Slice(0, stringLength), NumberStyles.Number, CultureInfo.InvariantCulture, out decimal value))
{
ThrowHelper.ThrowInvalidFormat();
}

values.Add(value);
}
#else
for (int i = 0; i < count; i++)
{
values.Add(reader.ParseDecimal());
}
#endif
return values;
}

Expand Down Expand Up @@ -244,12 +215,14 @@ private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int coun
{
values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64()));
}
else
else if (typeof(T) == typeof(TimeSpan))
{
Debug.Assert(typeof(T) == typeof(TimeSpan));

values.Add((T)(object)new TimeSpan(reader.ReadInt64()));
}
else
{
throw new InvalidOperationException();
}
}

return values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ internal sealed class ArraySingleStringRecord : SZArrayRecord<string?>
public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleString;

/// <inheritdoc />
public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String);
public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType);

private List<SerializationRecord> Records { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
totalElementCount *= lengths[i];

if (totalElementCount > uint.MaxValue)
if (totalElementCount > ArrayInfo.MaxArrayLength)
{
ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,7 @@ private BinaryLibraryRecord(SerializationRecordId libraryId, AssemblyNameInfo li

public override SerializationRecordType RecordType => SerializationRecordType.BinaryLibrary;

public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on BinaryLibraryRecord");
return TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());

internal string? RawLibraryName { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
case BinaryType.Class:
info[i] = (type, ClassTypeInfo.Decode(reader, options, recordMap));
break;
default:
// Other types have no additional data.
Debug.Assert(type is BinaryType.String or BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.Object);
case BinaryType.String:
case BinaryType.StringArray:
case BinaryType.Object:
case BinaryType.ObjectArray:
// These types have no additional data.
break;
default:
throw new InvalidOperationException();
}
}

Expand Down Expand Up @@ -97,7 +101,8 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
BinaryType.PrimitiveArray => (PrimitiveArray, default),
BinaryType.Class => (NonSystemClass, default),
BinaryType.SystemClass => (SystemClass, default),
_ => (ObjectArray, default)
BinaryType.ObjectArray => (ObjectArray, default),
_ => throw new InvalidOperationException()
};
}

Expand Down Expand Up @@ -144,15 +149,15 @@ internal TypeName GetArrayTypeName(ArrayInfo arrayInfo)

TypeName elementTypeName = binaryType switch
{
BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(PrimitiveType.String),
BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String),
BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.StringPrimitiveType),
BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType),
BinaryType.Primitive => TypeNameHelpers.GetPrimitiveTypeName((PrimitiveType)additionalInfo!),
BinaryType.PrimitiveArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName((PrimitiveType)additionalInfo!),
BinaryType.Object => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.ObjectPrimitiveType),
BinaryType.ObjectArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.ObjectPrimitiveType),
BinaryType.SystemClass => (TypeName)additionalInfo!,
BinaryType.Class => ((ClassTypeInfo)additionalInfo!).TypeName,
_ => throw new ArgumentOutOfRangeException(paramName: nameof(binaryType), actualValue: binaryType, message: null)
_ => throw new InvalidOperationException()
};

// In general, arrayRank == 1 may have two different meanings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,5 @@ private MessageEndRecord()

public override SerializationRecordId Id => SerializationRecordId.NoId;

public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on MessageEndRecord");
return TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,22 @@ public static bool StartsWithPayloadHeader(Stream stream)
return false;
}

try
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
{
#if NET
Span<byte> buffer = stackalloc byte[SerializedStreamHeaderRecord.Size];
stream.ReadExactly(buffer);
#else
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
{
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
throw new EndOfStreamException();
offset += read;
stream.Position = beginning;
return false;
}
#endif
return StartsWithPayloadHeader(buffer);
}
finally
{
stream.Position = beginning;
offset += read;
}

bool result = StartsWithPayloadHeader(buffer);
stream.Position = beginning;
return result;
}

/// <summary>
Expand Down Expand Up @@ -241,7 +235,8 @@ private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap rec
SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader),
SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader),
SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader),
_ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
SerializationRecordType.SystemClassWithMembersAndTypes => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
_ => throw new InvalidOperationException()
};

recordMap.Add(record);
Expand Down Expand Up @@ -269,8 +264,8 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader
PrimitiveType.Double => new MemberPrimitiveTypedRecord<double>(reader.ReadDouble()),
PrimitiveType.Decimal => new MemberPrimitiveTypedRecord<decimal>(reader.ParseDecimal()),
PrimitiveType.DateTime => new MemberPrimitiveTypedRecord<DateTime>(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64())),
// String is handled with a record, never on it's own
_ => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
PrimitiveType.TimeSpan => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
_ => throw new InvalidOperationException()
};
}

Expand All @@ -295,7 +290,8 @@ private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader
PrimitiveType.Double => Decode<double>(info, reader),
PrimitiveType.Decimal => Decode<decimal>(info, reader),
PrimitiveType.DateTime => Decode<DateTime>(info, reader),
_ => Decode<TimeSpan>(info, reader),
PrimitiveType.TimeSpan => Decode<TimeSpan>(info, reader),
_ => throw new InvalidOperationException()
};

static SerializationRecord Decode<T>(ArrayInfo info, BinaryReader reader) where T : unmanaged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,5 @@ internal abstract class NullsRecord : SerializationRecord

public override SerializationRecordId Id => SerializationRecordId.NoId;

public override TypeName TypeName
{
get
{
Debug.Fail($"TypeName should never be called on {GetType().Name}");
return TypeName.Parse(GetType().Name.AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(GetType().Name.AsSpan());
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ namespace System.Formats.Nrbf;
/// </remarks>
internal enum PrimitiveType : byte
{
/// <summary>
/// Used internally to express no value
/// </summary>
None = 0,
Boolean = 1,
Byte = 2,
Char = 3,
Expand All @@ -30,7 +26,19 @@ internal enum PrimitiveType : byte
DateTime = 13,
UInt16 = 14,
UInt32 = 15,
UInt64 = 16,
Null = 17,
String = 18
UInt64 = 16
// This internal enum no longer contains Null and String as they were always illegal:
// - In case of BinaryArray (NRBF 2.4.3.1):
// "If the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration
// value in AdditionalTypeInfo MUST NOT be Null (17) or String (18)."
// - In case of MemberPrimitiveTyped (NRBF 2.5.1):
// "PrimitiveTypeEnum (1 byte): A PrimitiveTypeEnumeration
// value that specifies the Primitive Type of data that is being transmitted.
// This field MUST NOT contain a value of 17 (Null) or 18 (String)."
// - In case of ArraySinglePrimitive (NRBF 2.4.3.3):
// "A PrimitiveTypeEnumeration value that identifies the Primitive Type
// of the items of the Array. The value MUST NOT be 17 (Null) or 18 (String)."
// - In case of MemberTypeInfo (NRBF 2.3.1.2):
// "When the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration
// value in AdditionalInfo MUST NOT be Null (17) or String (18)."
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ internal void Add(SerializationRecord record)
return;
}
#endif
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id));
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id._id));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace System.Formats.Nrbf;
internal sealed class RectangularArrayRecord : ArrayRecord
{
private readonly int[] _lengths;
private readonly ICollection<object> _values;
private readonly List<object> _values;
private TypeName? _typeName;

private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
Expand All @@ -24,18 +24,8 @@ private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
MemberTypeInfo = memberTypeInfo;
_lengths = lengths;

// A List<T> can hold as many objects as an array, so for multi-dimensional arrays
// with more elements than Array.MaxLength we use LinkedList.
// Testing that many elements takes a LOT of time, so to ensure that both code paths are tested,
// we always use LinkedList code path for Debug builds.
#if DEBUG
_values = new LinkedList<object>();
#else
_values = arrayInfo.TotalElementsCount <= ArrayInfo.MaxArrayLength
? new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()))
: new LinkedList<object>();
#endif

// ArrayInfo.GetSZArrayLength ensures to return a value <= Array.MaxLength
_values = new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()));
}

public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
Expand Down Expand Up @@ -108,6 +98,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
else if (ElementType == typeof(TimeSpan)) CopyTo<TimeSpan>(_values, result);
else if (ElementType == typeof(DateTime)) CopyTo<DateTime>(_values, result);
else if (ElementType == typeof(decimal)) CopyTo<decimal>(_values, result);
else throw new InvalidOperationException();
}
else
{
Expand All @@ -116,7 +107,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)

return result;

static void CopyTo<T>(ICollection<object> list, Array array)
static void CopyTo<T>(List<object> list, Array array)
{
ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array);
ref T firstElementRef = ref Unsafe.As<byte, T>(ref arrayDataRef);
Expand Down Expand Up @@ -176,7 +167,10 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr
PrimitiveType.Int64 => sizeof(long),
PrimitiveType.UInt64 => sizeof(ulong),
PrimitiveType.Double => sizeof(double),
_ => -1
PrimitiveType.TimeSpan => sizeof(ulong),
PrimitiveType.DateTime => sizeof(ulong),
PrimitiveType.Decimal => -1, // represented as variable-length string
_ => throw new InvalidOperationException()
};

if (sizeOfSingleValue > 0)
Expand Down Expand Up @@ -215,7 +209,8 @@ private static Type MapPrimitive(PrimitiveType primitiveType)
PrimitiveType.DateTime => typeof(DateTime),
PrimitiveType.UInt16 => typeof(ushort),
PrimitiveType.UInt32 => typeof(uint),
_ => typeof(ulong)
PrimitiveType.UInt64 => typeof(ulong),
_ => throw new InvalidOperationException()
};

private static Type MapPrimitiveArray(PrimitiveType primitiveType)
Expand All @@ -235,7 +230,8 @@ private static Type MapPrimitiveArray(PrimitiveType primitiveType)
PrimitiveType.DateTime => typeof(DateTime[]),
PrimitiveType.UInt16 => typeof(ushort[]),
PrimitiveType.UInt32 => typeof(uint[]),
_ => typeof(ulong[]),
PrimitiveType.UInt64 => typeof(ulong[]),
_ => throw new InvalidOperationException()
};

private static object? GetActualValue(object value)
Expand Down
Loading

0 comments on commit 4cdbfdc

Please sign in to comment.