Skip to content

Commit

Permalink
Add support for GetPinnableReference for the new marshaller shapes (d…
Browse files Browse the repository at this point in the history
…otnet#71412)

* Add support for static and instance GetPinnableReference in codegen.

* Add tests and collections support
  • Loading branch information
jkoritzinsky authored Jul 2, 2022
1 parent f28904a commit 0112d5f
Show file tree
Hide file tree
Showing 9 changed files with 395 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,11 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo

IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);

if (marshallerData.Shape.HasFlag(MarshallerShape.StatelessPinnableReference))
{
marshallingGenerator = new StaticPinnableManagedValueMarshaller(marshallingGenerator, marshallerData.MarshallerType.Syntax);
}

return marshalInfo.IsPinnableManagedType
? new PinnableManagedValueMarshaller(marshallingGenerator)
: marshallingGenerator;
Expand Down Expand Up @@ -308,6 +313,11 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(

IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);

if (marshallerData.Shape.HasFlag(MarshallerShape.StatelessPinnableReference))
{
marshallingGenerator = new StaticPinnableManagedValueMarshaller(marshallingGenerator, marshallerTypeSyntax);
}

// Elements in the collection must be blittable to use the pinnable marshaller.
return marshalInfo.IsPinnableManagedType && elementIsBlittable
? new PinnableManagedValueMarshaller(marshallingGenerator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,17 @@ public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo inf

public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context)
{
return Array.Empty<StatementSyntax>();
if (!_shape.HasFlag(MarshallerShape.StatefulPinnableReference))
yield break;

string unusedIdentifier = context.GetAdditionalIdentifier(info, "unused");
yield return FixedStatement(
VariableDeclaration(
PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))),
SingletonSeparatedList(
VariableDeclarator(unusedIdentifier)
.WithInitializer(EqualsValueClause(IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)))))),
EmptyStatement());
}

public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Microsoft.Interop
{
public sealed class StaticPinnableManagedValueMarshaller : IMarshallingGenerator
{
private readonly IMarshallingGenerator _innerMarshallingGenerator;
private readonly TypeSyntax _getPinnableReferenceType;

public StaticPinnableManagedValueMarshaller(IMarshallingGenerator innerMarshallingGenerator, TypeSyntax getPinnableReferenceType)
{
_innerMarshallingGenerator = innerMarshallingGenerator;
_getPinnableReferenceType = getPinnableReferenceType;
}

public bool IsSupported(TargetFramework target, Version version)
=> _innerMarshallingGenerator.IsSupported(target, version);

public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, StubCodeContext context)
{
if (IsPinningPathSupported(info, context))
{
if (AsNativeType(info) is PointerTypeSyntax pointerType
&& pointerType.ElementType is PredefinedTypeSyntax predefinedType
&& predefinedType.Keyword.IsKind(SyntaxKind.VoidKeyword))
{
return ValueBoundaryBehavior.NativeIdentifier;
}

// Cast to native type if it is not void*
return ValueBoundaryBehavior.CastNativeIdentifier;
}

return _innerMarshallingGenerator.GetValueBoundaryBehavior(info, context);
}

public TypeSyntax AsNativeType(TypePositionInfo info)
{
return _innerMarshallingGenerator.AsNativeType(info);
}

public SignatureBehavior GetNativeSignatureBehavior(TypePositionInfo info)
{
return _innerMarshallingGenerator.GetNativeSignatureBehavior(info);
}

public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
if (IsPinningPathSupported(info, context))
{
return GeneratePinningPath(info, context);
}

return _innerMarshallingGenerator.Generate(info, context);
}

public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context)
{
return _innerMarshallingGenerator.SupportsByValueMarshalKind(marshalKind, context);
}

public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context)
{
if (IsPinningPathSupported(info, context))
{
return false;
}

return _innerMarshallingGenerator.UsesNativeIdentifier(info, context);
}
private static bool IsPinningPathSupported(TypePositionInfo info, StubCodeContext context)
{
return context.SingleFrameSpansNativeContext && !info.IsByRef && !info.IsManagedReturnPosition;
}

private IEnumerable<StatementSyntax> GeneratePinningPath(TypePositionInfo info, StubCodeContext context)
{
if (context.CurrentStage == StubCodeContext.Stage.Pin)
{
(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info);

// fixed (void* <nativeIdentifier> = &<getPinnableReferenceType>.GetPinnableReference(<managedIdentifier>))
yield return FixedStatement(
VariableDeclaration(
PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))),
SingletonSeparatedList(
VariableDeclarator(Identifier(nativeIdentifier))
.WithInitializer(EqualsValueClause(
PrefixUnaryExpression(SyntaxKind.AddressOfExpression,
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
_getPinnableReferenceType,
IdentifierName(ShapeMemberNames.GetPinnableReference)),
ArgumentList(SingletonSeparatedList(
Argument(IdentifierName(managedIdentifier))))))
))
)
),
EmptyStatement());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ public partial class Collections
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")]
public static partial int Sum([MarshalUsing(typeof(ListMarshaller<,>))] List<int> values, int numValues);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_values")]
public static partial int DoubleValues([MarshalUsing(typeof(ListMarshallerWithPinning<,>))] List<BlittableIntWrapper> values, int length);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")]
public static partial int SumInArray([MarshalUsing(typeof(ListMarshaller<,>))] in List<int> values, int numValues);

Expand Down Expand Up @@ -48,6 +51,15 @@ public void BlittableElementColllectionMarshalledToNativeAsExpected()
Assert.Equal(list.Sum(), NativeExportsNE.Collections.Sum(list, list.Count));
}

[Fact]
public void BlittableElementColllectionMarshalledToNativeWithPinningAsExpected()
{
var data = new List<int> { 1, 5, 79, 165, 32, 3 };
var list = data.Select(i => new BlittableIntWrapper { i = i }).ToList();
NativeExportsNE.Collections.DoubleValues(list, list.Count);
Assert.Equal(data.Select(i => i * 2), list.Select(wrapper => wrapper.i));
}

[Fact]
public void NullBlittableElementColllectionMarshalledToNativeAsExpected()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ public static partial void NegateBools(
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_int_ref")]
public static partial IntWrapper DoubleIntRef(IntWrapper pInt);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_int_ref")]
public static partial IntWrapperWithoutGetPinnableReference DoubleIntRef(IntWrapperWithoutGetPinnableReference pInt);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "return_zero")]
[return: MarshalUsing(typeof(IntGuaranteedUnmarshal))]
public static partial int GuaranteedUnmarshal([MarshalUsing(typeof(ExceptionOnUnmarshal))] out int ret);
Expand Down Expand Up @@ -78,6 +81,11 @@ public static partial void NegateBools(
[return: MarshalAs(UnmanagedType.U1)]
public static partial bool AndBoolsRef([MarshalUsing(typeof(BoolStructMarshallerStateful))] in BoolStruct boolStruct);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_int_ref")]
public static partial IntWrapperWithoutGetPinnableReference DoubleIntRef([MarshalUsing(typeof(IntWrapperWithoutGetPinnableReferenceStatefulMarshaller))] IntWrapperWithoutGetPinnableReference pInt);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_int_ref")]
public static partial IntWrapperWithoutGetPinnableReference DoubleIntRefNoAlloc([MarshalUsing(typeof(IntWrapperWithoutGetPinnableReferenceStatefulNoAllocMarshaller))] IntWrapperWithoutGetPinnableReference pInt);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_int_ref")]
[return: MarshalUsing(typeof(IntWrapperMarshallerStateful))]
Expand Down Expand Up @@ -149,7 +157,7 @@ public void NonBlittableStructWithoutAllocation()
}

[Fact]
public void GetPinnableReferenceMarshalling()
public void ManagedTypeGetPinnableReferenceMarshalling()
{
int originalValue = 42;
var wrapper = new IntWrapper { i = originalValue };
Expand All @@ -160,6 +168,18 @@ public void GetPinnableReferenceMarshalling()
Assert.Equal(originalValue * 2, retVal.i);
}

[Fact]
public void MarshallerStaticGetPinnableReferenceMarshalling()
{
int originalValue = 42;
var wrapper = new IntWrapperWithoutGetPinnableReference { i = originalValue };

var retVal = NativeExportsNE.Stateless.DoubleIntRef(wrapper);

Assert.Equal(originalValue * 2, wrapper.i);
Assert.Equal(originalValue * 2, retVal.i);
}

[Fact]
public void NonBlittableStructRef()
{
Expand Down Expand Up @@ -310,6 +330,30 @@ public void NonBlittableType_Stateful_Marshalling_Free()
Assert.Equal(originalValue * 2, retVal.i);
}

[Fact]
public void StatefulMarshallerStaticGetPinnableReferenceMarshalling()
{
int originalValue = 42;
var wrapper = new IntWrapperWithoutGetPinnableReference { i = originalValue };

var retVal = NativeExportsNE.Stateful.DoubleIntRef(wrapper);

Assert.Equal(originalValue * 2, wrapper.i);
Assert.Equal(originalValue * 2, retVal.i);
}

[Fact]
public void StatefulMarshallerInstanceGetPinnableReferenceMarshalling()
{
int originalValue = 42;
var wrapper = new IntWrapperWithoutGetPinnableReference { i = originalValue };

var retVal = NativeExportsNE.Stateful.DoubleIntRefNoAlloc(wrapper);

Assert.Equal(originalValue * 2, wrapper.i);
Assert.Equal(originalValue * 2, retVal.i);
}

private static string ReverseChars(string value)
{
if (value == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,15 @@ public struct Native { }
public const int BufferSize = 0x100;
public static Native ConvertToUnmanaged(S s, System.Span<byte> buffer) => default;
}
";

public static string InPinnable = @"
[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedIn, typeof(Marshaller))]
public static unsafe class Marshaller
{
public static byte* ConvertToUnmanaged(S s) => default;
public static ref byte GetPinnableReference(S s) => throw null;
}
";
private static string Out = @"
[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedOut, typeof(Marshaller))]
Expand Down Expand Up @@ -815,6 +824,9 @@ public struct Native { }
public static string StackallocByValueInParameter => BasicParameterByValue("S")
+ NonBlittableUserDefinedType()
+ InBuffer;
public static string PinByValueInParameter => BasicParameterByValue("S")
+ NonBlittableUserDefinedType()
+ InPinnable;

public static string StackallocParametersAndModifiersNoRef = BasicParametersAndModifiersNoRef("S")
+ NonBlittableUserDefinedType()
Expand Down Expand Up @@ -843,6 +855,34 @@ public void FromManaged(S s) {}
public Native ToUnmanaged() => default;
}
}
";

public static string InStatelessPinnable = @"
[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedIn, typeof(M))]
public static class Marshaller
{
public unsafe struct M
{
public void FromManaged(S s) {}
public byte* ToUnmanaged() => default;
public static ref byte GetPinnableReference(S s) => throw null;
}
}
";

public static string InPinnable = @"
[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedIn, typeof(M))]
public static class Marshaller
{
public unsafe struct M
{
public void FromManaged(S s) {}
public byte* ToUnmanaged() => default;
public ref byte GetPinnableReference() => throw null;
}
}
";

private static string InBuffer = @"
Expand Down Expand Up @@ -1014,6 +1054,12 @@ public void FromUnmanaged(Native n) {}
public static string StackallocByValueInParameter => BasicParameterByValue("S")
+ NonBlittableUserDefinedType()
+ InBuffer;
public static string PinByValueInParameter => BasicParameterByValue("S")
+ NonBlittableUserDefinedType()
+ InStatelessPinnable;
public static string MarshallerPinByValueInParameter => BasicParameterByValue("S")
+ NonBlittableUserDefinedType()
+ InPinnable;

public static string StackallocParametersAndModifiersNoRef = BasicParametersAndModifiersNoRef("S")
+ NonBlittableUserDefinedType()
Expand Down Expand Up @@ -1418,6 +1464,17 @@ static unsafe class Marshaller<T, [ElementUnmanagedType] TUnmanagedElement>
public static System.ReadOnlySpan<T> GetManagedValuesSource(TestCollection<T> managed) => throw null;
public static System.Span<TUnmanagedElement> GetUnmanagedValuesDestination(byte* unmanaged, int numElements) => throw null;
}
";
public const string InPinnable = @"
[CustomMarshaller(typeof(TestCollection<>), Scenario.ManagedToUnmanagedIn, typeof(Marshaller<,>))]
static unsafe class Marshaller<T, [ElementUnmanagedType] TUnmanagedElement>
{
public static byte* AllocateContainerForUnmanagedElements(TestCollection<T> managed, out int numElements) => throw null;
public static System.ReadOnlySpan<T> GetManagedValuesSource(TestCollection<T> managed) => throw null;
public static System.Span<TUnmanagedElement> GetUnmanagedValuesDestination(byte* unmanaged, int numElements) => throw null;
public static ref byte GetPinnableReference(TestCollection<T> managed) => throw null;
}
";
public const string InBuffer = @"
[CustomMarshaller(typeof(TestCollection<>), Scenario.ManagedToUnmanagedIn, typeof(Marshaller<,>))]
Expand Down Expand Up @@ -1475,6 +1532,11 @@ public static string ByValue(string elementType) => BasicParameterByValue($"Test
+ TestCollection()
+ In;

public static string ByValueWithPinning<T>() => ByValueWithPinning(typeof(T).ToString());
public static string ByValueWithPinning(string elementType) => BasicParameterByValue($"TestCollection<{elementType}>", DisableRuntimeMarshalling)
+ TestCollection()
+ InPinnable;

public static string ByValueCallerAllocatedBuffer<T>() => ByValueCallerAllocatedBuffer(typeof(T).ToString());
public static string ByValueCallerAllocatedBuffer(string elementType) => BasicParameterByValue($"TestCollection<{elementType}>", DisableRuntimeMarshalling)
+ TestCollection()
Expand Down
Loading

0 comments on commit 0112d5f

Please sign in to comment.