Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ internal static SyntaxToken Token(SyntaxKind kind)

internal static BlockSyntax Block(params StatementSyntax[] statements) => SyntaxFactory.Block(OpenBrace, List(statements), CloseBrace);

internal static BlockSyntax Block(IEnumerable<StatementSyntax> statements) => SyntaxFactory.Block(OpenBrace, List(statements), CloseBrace);

internal static ImplicitArrayCreationExpressionSyntax ImplicitArrayCreationExpression(InitializerExpressionSyntax initializerExpression) => SyntaxFactory.ImplicitArrayCreationExpression(Token(SyntaxKind.NewKeyword), Token(SyntaxKind.OpenBracketToken), default, Token(SyntaxKind.CloseBracketToken), initializerExpression);

internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? declaration, ExpressionSyntax condition, SeparatedSyntaxList<ExpressionSyntax> incrementors, StatementSyntax statement)
Expand All @@ -100,10 +102,12 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla

internal static DeclarationExpressionSyntax DeclarationExpression(TypeSyntax type, VariableDesignationSyntax designation) => SyntaxFactory.DeclarationExpression(type, designation);

internal static VariableDeclaratorSyntax VariableDeclarator(SyntaxToken identifier) => SyntaxFactory.VariableDeclarator(identifier);
internal static VariableDeclaratorSyntax VariableDeclarator(SyntaxToken identifier, EqualsValueClauseSyntax? initializer = null) => SyntaxFactory.VariableDeclarator(identifier, argumentList: null, initializer: initializer);

internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space)));

internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type, params VariableDeclaratorSyntax[] variables) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space)), SeparatedList(variables));

internal static SizeOfExpressionSyntax SizeOfExpression(TypeSyntax type) => SyntaxFactory.SizeOfExpression(Token(SyntaxKind.SizeOfKeyword), Token(SyntaxKind.OpenParenToken), type, Token(SyntaxKind.CloseParenToken));

internal static MemberAccessExpressionSyntax MemberAccessExpression(SyntaxKind kind, ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.MemberAccessExpression(kind, expression, Token(GetMemberAccessExpressionOperatorTokenKind(kind)), name);
Expand Down Expand Up @@ -190,7 +194,7 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla

internal static InitializerExpressionSyntax InitializerExpression(SyntaxKind kind, SeparatedSyntaxList<ExpressionSyntax> expressions) => SyntaxFactory.InitializerExpression(kind, OpenBrace, expressions, CloseBrace);

internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(), null);
internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type, SeparatedSyntaxList<ArgumentSyntax> arguments = default) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(arguments), null);

internal static ArrayCreationExpressionSyntax ArrayCreationExpression(ArrayTypeSyntax type, InitializerExpressionSyntax? initializer = null) => SyntaxFactory.ArrayCreationExpression(Token(SyntaxKind.NewKeyword), type, initializer);

Expand Down Expand Up @@ -295,7 +299,7 @@ internal static SyntaxList<TNode> SingletonList<TNode>(TNode node)

internal static AttributeArgumentListSyntax AttributeArgumentList(SeparatedSyntaxList<AttributeArgumentSyntax> arguments = default) => SyntaxFactory.AttributeArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken));

internal static AttributeListSyntax AttributeList() => SyntaxFactory.AttributeList(Token(SyntaxKind.OpenBracketToken), null, SeparatedList<AttributeSyntax>(), TokenWithLineFeed(SyntaxKind.CloseBracketToken));
internal static AttributeListSyntax AttributeList(params SeparatedSyntaxList<AttributeSyntax> attributes) => SyntaxFactory.AttributeList(Token(SyntaxKind.OpenBracketToken), null, attributes, TokenWithLineFeed(SyntaxKind.CloseBracketToken));

internal static SyntaxList<TNode> List<TNode>()
where TNode : SyntaxNode => SyntaxFactory.List<TNode>();
Expand All @@ -305,7 +309,7 @@ internal static SyntaxList<TNode> List<TNode>(IEnumerable<TNode> nodes)

internal static ParameterListSyntax ParameterList() => SyntaxFactory.ParameterList(Token(SyntaxKind.OpenParenToken), SeparatedList<ParameterSyntax>(), Token(SyntaxKind.CloseParenToken));

internal static ArgumentListSyntax ArgumentList(SeparatedSyntaxList<ArgumentSyntax> arguments = default) => SyntaxFactory.ArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken));
internal static ArgumentListSyntax ArgumentList(params SeparatedSyntaxList<ArgumentSyntax> arguments) => SyntaxFactory.ArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken));

internal static AssignmentExpressionSyntax AssignmentExpression(SyntaxKind kind, ExpressionSyntax left, ExpressionSyntax right) => SyntaxFactory.AssignmentExpression(kind, left, Token(GetAssignmentExpressionOperatorTokenKind(kind)).WithLeadingTrivia(Space), right);

Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Features.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public partial class Generator
private readonly bool unscopedRefAttributePredefined;
private readonly bool canUseComVariant;
private readonly bool canUseMemberFunctionCallingConvention;
private readonly bool canUseMarshalInitHandle;
private readonly INamedTypeSymbol? runtimeFeatureClass;
private readonly bool generateSupportedOSPlatformAttributes;
private readonly bool generateSupportedOSPlatformAttributesOnInterfaces; // only supported on net6.0 (https://github.com/dotnet/runtime/pull/48838)
Expand Down
125 changes: 106 additions & 19 deletions src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -378,19 +378,54 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
.WithModifiers(TokenList(TokenWithSpace(SyntaxKind.OutKeyword)));

// HANDLE SomeLocal;
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type).AddVariables(
VariableDeclarator(typeDefHandleName.Identifier))));
leadingStatements.Add(
LocalDeclarationStatement(
VariableDeclaration(
pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type,
VariableDeclarator(typeDefHandleName.Identifier))));

ArgumentSyntax ownsHandleArgument = Argument(
NameColon(IdentifierName("ownsHandle")),
refKindKeyword: default,
LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression));

if (this.canUseMarshalInitHandle)
{
// Some = new SafeHandle(default, ownsHandle: true);
leadingStatements.Add(
ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
origName,
ObjectCreationExpression(safeHandleType, [Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression)), ownsHandleArgument]))));

// Marshal.InitHandle(Some, SomeLocal);
trailingStatements.Add(
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nameof(Marshal)),
IdentifierName("InitHandle")),
ArgumentList(
[
Argument(origName),
Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)),
]))));
}
else
{
// Some = new SafeHandle(SomeLocal, ownsHandle: true);
trailingStatements.Add(ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
origName,
ObjectCreationExpression(safeHandleType).AddArgumentListArguments(
Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)),
ownsHandleArgument))));
}

// Argument: &SomeLocal
arguments[paramIndex] = Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, typeDefHandleName));

// Some = new SafeHandle(SomeLocal, ownsHandle: true);
trailingStatements.Add(ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
origName,
ObjectCreationExpression(safeHandleType).AddArgumentListArguments(
Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)),
Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle")))))));
}
}
else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, paramAttributes, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod)
Expand Down Expand Up @@ -1108,7 +1143,46 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
&& returnTypeHandleInfo.Generator.TryGetHandleReleaseMethod(returnTypeHandleInfo.Handle, returnTypeAttributes, out string? returnReleaseMethod)
? this.RequestSafeHandle(returnReleaseMethod) : null;

if ((returnSafeHandleType is object || minorSignatureChange) && !signatureChanged)
IdentifierNameSyntax resultLocal = IdentifierName("__result");

if (this.canUseMarshalInitHandle && returnSafeHandleType is not null)
{
IdentifierNameSyntax resultSafeHandleLocal = IdentifierName("__resultSafeHandle");

// SafeHandle __resultSafeHandle = new SafeHandle(default, ownsHandle: true);
leadingStatements.Add(
LocalDeclarationStatement(
VariableDeclaration(
returnSafeHandleType,
VariableDeclarator(
resultSafeHandleLocal.Identifier,
EqualsValueClause(
ObjectCreationExpression(
returnSafeHandleType,
[
Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression)),
Argument(
NameColon(IdentifierName("ownsHandle")),
refKindKeyword: default,
LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression))
]))))));

// Marshal.InitHandle(__resultSafeHandle, __result);
trailingStatements.Add(
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nameof(Marshal)),
IdentifierName("InitHandle")),
ArgumentList(
[
Argument(resultSafeHandleLocal),
Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)),
]))));
}

if ((returnSafeHandleType is not null || minorSignatureChange) && !signatureChanged)
{
// The parameter types are all the same, but we need a friendly overload with a different return type.
// Our only choice is to rename the friendly overload.
Expand Down Expand Up @@ -1145,20 +1219,33 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
})
.WithArgumentList(FixTrivia(ArgumentList().AddArguments(arguments.ToArray())));
bool hasVoidReturn = externMethodReturnType is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.VoidKeyword } };
BlockSyntax? body = Block().AddStatements(leadingStatements.ToArray());
IdentifierNameSyntax resultLocal = IdentifierName("__result");
if (returnSafeHandleType is object)
BlockSyntax? body = Block(leadingStatements);
if (returnSafeHandleType is not null)
{
//// HANDLE result = invocation();
// HANDLE result = invocation();
body = body.AddStatements(LocalDeclarationStatement(VariableDeclaration(externMethodReturnType)
.AddVariables(VariableDeclarator(resultLocal.Identifier).WithInitializer(EqualsValueClause(externInvocation)))));

body = body.AddStatements(trailingStatements.ToArray());

//// return new SafeHandle(result, ownsHandle: true);
body = body.AddStatements(ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments(
Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)),
Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))));
ReturnStatementSyntax returnStatement;
if (this.canUseMarshalInitHandle)
{
// return __resultSafeHandle;
returnStatement = ReturnStatement(IdentifierName("__resultSafeHandle"));
}
else
{
// return new SafeHandle(result, ownsHandle: true);
returnStatement = ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments(
Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)),
Argument(
NameColon(IdentifierName("ownsHandle")),
refKindKeyword: default,
LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression))));
}

body = body.AddStatements(returnStatement);
}
else if (hasVoidReturn)
{
Expand Down
21 changes: 12 additions & 9 deletions src/Microsoft.Windows.CsWin32/Generator.Handle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,29 +123,32 @@ public partial class Generator
VariableDeclarator(invalidValueFieldName.Identifier).WithInitializer(EqualsValueClause(invalidHandleIntPtr))))
.AddModifiers(TokenWithSpace(SyntaxKind.PrivateKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword)));

SyntaxToken visibilityModifier = TokenWithSpace(this.Visibility);

// public SafeHandle() : base(INVALID_HANDLE_VALUE, true)
members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier)
.AddModifiers(TokenWithSpace(this.Visibility))
.AddModifiers(visibilityModifier)
.WithInitializer(ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList().AddArguments(
Argument(invalidValueFieldName),
Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression)))))
.WithBody(Block()));

// public SafeHandle(IntPtr preexistingHandle, bool ownsHandle = true) : base(INVALID_HANDLE_VALUE, ownsHandle) { this.SetHandle(preexistingHandle); }
const string preexistingHandleName = "preexistingHandle";
const string ownsHandleName = "ownsHandle";
IdentifierNameSyntax preexistingHandleName = IdentifierName("preexistingHandle");
IdentifierNameSyntax ownsHandleName = IdentifierName("ownsHandle");
members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier)
.AddModifiers(TokenWithSpace(this.Visibility))
.AddModifiers(visibilityModifier)
.AddParameterListParameters(
Parameter(Identifier(preexistingHandleName)).WithType(IntPtrTypeSyntax.WithTrailingTrivia(TriviaList(Space))),
Parameter(Identifier(ownsHandleName)).WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)))
Parameter(preexistingHandleName.Identifier).WithType(IntPtrTypeSyntax.WithTrailingTrivia(TriviaList(Space))),
Parameter(ownsHandleName.Identifier)
.WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)))
.WithDefault(EqualsValueClause(LiteralExpression(SyntaxKind.TrueLiteralExpression))))
.WithInitializer(ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList().AddArguments(
Argument(invalidValueFieldName),
Argument(IdentifierName(ownsHandleName)))))
Argument(ownsHandleName))))
.WithBody(Block().AddStatements(
ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("SetHandle")))
.WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName(preexistingHandleName)))))))));
.WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(preexistingHandleName))))))));

// public override bool IsInvalid => this.handle.ToInt64() == 0 || this.handle.ToInt64() == -1;
ExpressionSyntax thisHandleToInt64 = InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, thisHandle, IdentifierName(nameof(IntPtr.ToInt64))), ArgumentList());
Expand Down Expand Up @@ -290,7 +293,7 @@ public partial class Generator
IEnumerable<TypeSyntax> xmlDocParameterTypes = releaseMethodSignature.ParameterTypes.Select(p => p.ToTypeSyntax(this.externSignatureTypeSettings, GeneratingElement.HelperClassMember, default).Type);

ClassDeclarationSyntax safeHandleDeclaration = ClassDeclaration(Identifier(safeHandleClassName))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.PartialKeyword))
.AddModifiers(visibilityModifier, TokenWithSpace(SyntaxKind.PartialKeyword))
.WithBaseList(BaseList(SingletonSeparatedList<BaseTypeSyntax>(SimpleBaseType(SafeHandleTypeSyntax))))
.AddMembers(members.ToArray())
.AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute))
Expand Down
Loading
Loading