Skip to content

Commit

Permalink
Implement abstraction for marshalling direction in the generator APIs (
Browse files Browse the repository at this point in the history
…dotnet#78196)

* Use MarshalDirection to provide a nice abstraction for determining whether we're marshalling a parameter/return value/etc from managed to unmanaged or vice versa. This abstraction will be useful when enabling unmanaged->managed stubs as we won't need to go update every marshalling generator to correctly understand what to do.

Also rename some members from "in/out/ref" to use the direction-based names.

* PR feedback
  • Loading branch information
jkoritzinsky authored Nov 19, 2022
1 parent 13ff1c0 commit 6519ec2
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Microsoft.Interop
{
internal static class ComInterfaceGeneratorHelpers
{
public static MarshallingGeneratorFactoryKey<(TargetFramework, Version)> CreateGeneratorFactory(StubEnvironment env)
public static MarshallingGeneratorFactoryKey<(TargetFramework, Version)> CreateGeneratorFactory(StubEnvironment env, MarshalDirection direction)
{
IMarshallingGeneratorFactory generatorFactory;

Expand Down Expand Up @@ -44,7 +44,17 @@ internal static class ComInterfaceGeneratorHelpers
generatorFactory = new AttributedMarshallingModelGeneratorFactory(
generatorFactory,
elementFactory,
new AttributedMarshallingModelOptions(runtimeMarshallingDisabled, MarshalMode.ManagedToUnmanagedIn, MarshalMode.ManagedToUnmanagedRef, MarshalMode.ManagedToUnmanagedOut));
new AttributedMarshallingModelOptions(
runtimeMarshallingDisabled,
direction == MarshalDirection.ManagedToUnmanaged
? MarshalMode.ManagedToUnmanagedIn
: MarshalMode.UnmanagedToManagedOut,
direction == MarshalDirection.ManagedToUnmanaged
? MarshalMode.ManagedToUnmanagedRef
: MarshalMode.UnmanagedToManagedRef,
direction == MarshalDirection.ManagedToUnmanaged
? MarshalMode.ManagedToUnmanagedOut
: MarshalMode.UnmanagedToManagedIn));

generatorFactory = new ByValueContentsMarshalKindValidator(generatorFactory);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ internal sealed record IncrementalStubGenerationContext(
MethodSignatureDiagnosticLocations DiagnosticLocation,
SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> CallingConvention,
VirtualMethodIndexData VtableIndexData,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> GeneratorFactory,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> ManagedToUnmanagedGeneratorFactory,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> UnmanagedToManagedGeneratorFactory,
ManagedTypeInfo TypeKeyType,
ManagedTypeInfo TypeKeyOwner,
SequenceEqualImmutableArray<Diagnostic> Diagnostics);
Expand Down Expand Up @@ -301,7 +302,8 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
new MethodSignatureDiagnosticLocations(syntax),
new SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax>(callConv, SyntaxEquivalentComparer.Instance),
virtualMethodIndexData,
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged),
typeKeyType,
typeKeyOwner,
new SequenceEqualImmutableArray<Diagnostic>(generatorDiagnostics.Diagnostics.ToImmutableArray()));
Expand Down Expand Up @@ -337,16 +339,16 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateMan

// Generate stub code
var stubGenerator = new ManagedToNativeVTableMethodGenerator(
methodStub.GeneratorFactory.Key.TargetFramework,
methodStub.GeneratorFactory.Key.TargetFrameworkVersion,
methodStub.ManagedToUnmanagedGeneratorFactory.Key.TargetFramework,
methodStub.ManagedToUnmanagedGeneratorFactory.Key.TargetFrameworkVersion,
methodStub.SignatureContext.ElementTypeInformation,
methodStub.VtableIndexData.SetLastError,
methodStub.VtableIndexData.ImplicitThisParameter,
(elementInfo, ex) =>
{
diagnostics.ReportMarshallingNotSupported(methodStub.DiagnosticLocation, elementInfo, ex.NotSupportedDetails);
},
methodStub.GeneratorFactory.GeneratorFactory);
methodStub.ManagedToUnmanagedGeneratorFactory.GeneratorFactory);

BlockSyntax code = stubGenerator.GenerateStubBody(
methodStub.VtableIndexData.Index,
Expand All @@ -370,19 +372,6 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateMan
methodStub.Diagnostics.Array.AddRange(diagnostics.Diagnostics));
}

private static bool ShouldVisitNode(SyntaxNode syntaxNode)
{
// We only support C# method declarations.
if (syntaxNode.Language != LanguageNames.CSharp
|| !syntaxNode.IsKind(SyntaxKind.MethodDeclaration))
{
return false;
}

// Filter out methods with no attributes early.
return ((MethodDeclarationSyntax)syntaxNode).AttributeLists.Count > 0;
}

private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method)
{
// Verify the method has no generic types or defined implementation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace Microsoft.Interop
{
public readonly record struct AttributedMarshallingModelOptions(bool RuntimeMarshallingDisabled, MarshalMode InMode, MarshalMode RefMode, MarshalMode OutMode);
public readonly record struct AttributedMarshallingModelOptions(bool RuntimeMarshallingDisabled, MarshalMode ManagedToUnmanagedMode, MarshalMode BidirectionalMode, MarshalMode UnmanagedToManagedMode);

public class AttributedMarshallingModelGeneratorFactory : IMarshallingGeneratorFactory
{
Expand Down Expand Up @@ -126,7 +126,7 @@ ExpressionSyntax GetExpressionForParam(TypePositionInfo paramInfo, out bool isIn
{
if (marshallingInfo is NativeLinearCollectionMarshallingInfo collectionInfo)
{
CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(collectionInfo.Marshallers, info);
CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(collectionInfo.Marshallers, info, context);
type = marshallerData.CollectionElementType;
marshallingInfo = marshallerData.CollectionElementMarshallingInfo;
}
Expand Down Expand Up @@ -200,16 +200,15 @@ private bool ValidateRuntimeMarshallingOptions(CustomTypeMarshallerData marshall
return false;
}

private CustomTypeMarshallerData GetMarshallerDataForTypePositionInfo(CustomTypeMarshallers marshallers, TypePositionInfo info)
private CustomTypeMarshallerData GetMarshallerDataForTypePositionInfo(CustomTypeMarshallers marshallers, TypePositionInfo info, StubCodeContext context)
{
if (info.IsManagedReturnPosition)
return marshallers.GetModeOrDefault(Options.OutMode);
MarshalDirection elementDirection = MarshallerHelpers.GetMarshalDirection(info, context);

return info.RefKind switch
return elementDirection switch
{
RefKind.None or RefKind.In => marshallers.GetModeOrDefault(Options.InMode),
RefKind.Ref => marshallers.GetModeOrDefault(Options.RefMode),
RefKind.Out => marshallers.GetModeOrDefault(Options.OutMode),
MarshalDirection.ManagedToUnmanaged => marshallers.GetModeOrDefault(Options.ManagedToUnmanagedMode),
MarshalDirection.Bidirectional => marshallers.GetModeOrDefault(Options.BidirectionalMode),
MarshalDirection.UnmanagedToManaged => marshallers.GetModeOrDefault(Options.UnmanagedToManagedMode),
_ => throw new UnreachableException()
};
}
Expand All @@ -218,7 +217,7 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo
{
ValidateCustomNativeTypeMarshallingSupported(info, context, marshalInfo);

CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(marshalInfo.Marshallers, info);
CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(marshalInfo.Marshallers, info, context);
if (!ValidateRuntimeMarshallingOptions(marshallerData))
{
throw new MarshallingNotSupportedException(info, context)
Expand Down Expand Up @@ -378,9 +377,10 @@ private static TypeSyntax ReplacePlaceholderSyntaxWithUnmanagedTypeSyntax(

private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info, StubCodeContext context, NativeMarshallingAttributeInfo marshalInfo)
{
MarshalDirection elementDirection = MarshallerHelpers.GetMarshalDirection(info, context);
// Marshalling out or return parameter, but no out marshaller is specified
if ((info.RefKind == RefKind.Out || info.IsManagedReturnPosition)
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.OutMode))
if (elementDirection == MarshalDirection.UnmanagedToManaged
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.UnmanagedToManagedMode))
{
throw new MarshallingNotSupportedException(info, context)
{
Expand All @@ -389,7 +389,7 @@ private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info,
}

// Marshalling ref parameter, but no ref marshaller is specified
if (info.RefKind == RefKind.Ref && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.RefMode))
if (elementDirection == MarshalDirection.Bidirectional && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.BidirectionalMode))
{
throw new MarshallingNotSupportedException(info, context)
{
Expand All @@ -398,20 +398,8 @@ private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info,
}

// Marshalling in parameter, but no in marshaller is specified
if (info.RefKind == RefKind.In
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.InMode))
{
throw new MarshallingNotSupportedException(info, context)
{
NotSupportedDetails = SR.Format(SR.ManagedToUnmanagedMissingRequiredMarshaller, marshalInfo.EntryPointType.FullTypeName)
};
}

// Marshalling by value, but no in marshaller is specified
if (!info.IsByRef
&& !info.IsManagedReturnPosition
&& context.SingleFrameSpansNativeContext
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.InMode))
if (elementDirection == MarshalDirection.ManagedToUnmanaged
&& !marshalInfo.Marshallers.IsDefinedOrDefault(Options.ManagedToUnmanagedMode))
{
throw new MarshallingNotSupportedException(info, context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
yield break;
}

MarshalDirection elementMarshalling = MarshallerHelpers.GetMarshalDirection(info, context);

switch (context.CurrentStage)
{
case StubCodeContext.Stage.Setup:
break;
case StubCodeContext.Stage.Marshal:
if (info.RefKind == RefKind.Ref)
if (elementMarshalling is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional && info.IsByRef)
{
yield return ExpressionStatement(
AssignmentExpression(
Expand All @@ -82,11 +84,14 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont

break;
case StubCodeContext.Stage.Unmarshal:
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifier),
IdentifierName(nativeIdentifier)));
if (elementMarshalling is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional && info.IsByRef)
{
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifier),
IdentifierName(nativeIdentifier)));
}
break;
default:
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, Stu

public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
MarshalDirection elementMarshalDirection = MarshallerHelpers.GetMarshalDirection(info, context);
(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info);
switch (context.CurrentStage)
{
case StubCodeContext.Stage.Setup:
break;
case StubCodeContext.Stage.Marshal:
// <nativeIdentifier> = (<nativeType>)(<managedIdentifier> ? _trueValue : _falseValue);
if (info.RefKind != RefKind.Out)
if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
{
yield return ExpressionStatement(
AssignmentExpression(
Expand All @@ -75,7 +76,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont

break;
case StubCodeContext.Stage.Unmarshal:
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In))
if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
{
// <managedIdentifier> = <nativeIdentifier> == _trueValue;
// or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,30 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
yield break;
}

MarshalDirection elementMarshalDirection = MarshallerHelpers.GetMarshalDirection(info, context);

switch (context.CurrentStage)
{
case StubCodeContext.Stage.Setup:
break;
case StubCodeContext.Stage.Marshal:
if ((info.IsByRef && info.RefKind != RefKind.Out) || !context.SingleFrameSpansNativeContext)
if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
{
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(nativeIdentifier),
IdentifierName(managedIdentifier)));
// There's an implicit conversion from char to ushort,
// so we simplify the generated code to just pass the char value directly
if (info.IsByRef)
{
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(nativeIdentifier),
IdentifierName(managedIdentifier)));
}
}

break;
case StubCodeContext.Stage.Unmarshal:
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In))
if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
{
yield return ExpressionStatement(
AssignmentExpression(
Expand Down
Loading

0 comments on commit 6519ec2

Please sign in to comment.