Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 75 additions & 58 deletions src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi

bool improvePointersToSpansAndRefs = this.canUseSpan;
FriendlyMethodBookkeeping bookkeeping = new();
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs, omitOptionalParams: false, bookkeeping))
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs, omitOptionalParams: false, promoteUnconvertibleHandles: true, bookkeeping))
{
yield return method;
}

if (this.Options.FriendlyOverloads.IncludePointerOverloads && improvePointersToSpansAndRefs && bookkeeping.NumSpanByteParameters > 0)
{
// If we could use Span and _did_ use span Span and the pointer overloads were requested, then Generate overloads that use pointer types instead of Span<byte>/ReadOnlySpan<byte>.
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs: false, omitOptionalParams: false))
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs: false, omitOptionalParams: false, promoteUnconvertibleHandles: true))
{
yield return method;
}
Expand All @@ -114,6 +114,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
bool avoidWinmdRootAlias,
bool improvePointersToSpansAndRefs,
bool omitOptionalParams,
bool promoteUnconvertibleHandles,
FriendlyMethodBookkeeping? bookkeeping = null)
{
#pragma warning disable SA1114 // Parameter list should follow declaration
Expand Down Expand Up @@ -152,6 +153,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
int numSpanByteParameters = 0;
SyntaxToken friendlyMethodName = externMethodDeclaration.Identifier;
bool emulateMemberFunctionCallConv = friendlyMethodName.ValueText.EndsWith(EmulateMemberFunctionCallConvSuffix);
bool hasUnconvertibleHandles = false;

foreach (ParameterHandle paramHandle in methodDefinition.GetParameters())
{
Expand Down Expand Up @@ -448,68 +450,74 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, paramAttributes, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod)
&& !(this.TryGetTypeDefFieldType(parameterHandleTypeInfo, out TypeHandleInfo? fieldType) && !this.IsSafeHandleCompatibleTypeDefFieldType(fieldType)))
{
IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local");
signatureChanged = true;
var isUnconvertibelHandle = this.RequestSafeHandle(releaseMethod) is null;
hasUnconvertibleHandles |= isUnconvertibelHandle;

IdentifierNameSyntax refAddedName = IdentifierName(externParam.Identifier.ValueText + "AddRef");
if (!isUnconvertibelHandle || promoteUnconvertibleHandles)
{
IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local");
signatureChanged = true;

// bool hParamNameAddRef = false;
leadingOutsideTryStatements.Add(LocalDeclarationStatement(
VariableDeclaration(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)), [VariableDeclarator(refAddedName.Identifier, EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression)))])));
IdentifierNameSyntax refAddedName = IdentifierName(externParam.Identifier.ValueText + "AddRef");

// HANDLE hTemplateFileLocal;
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(externParam.Type, [VariableDeclarator(typeDefHandleName.Identifier)])));
// bool hParamNameAddRef = false;
leadingOutsideTryStatements.Add(LocalDeclarationStatement(
VariableDeclaration(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)), [VariableDeclarator(refAddedName.Identifier, EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression)))])));

// throw new ArgumentNullException(nameof(hTemplateFile));
StatementSyntax nullHandleStatement = ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentNullException))).WithArgumentList(ArgumentList(Argument(NameOfExpression(IdentifierName(externParam.Identifier.ValueText))))));
if (isOptional)
{
// (HANDLE)new IntPtr(-1);
HashSet<IntPtr> invalidValues = this.GetInvalidHandleValues(parameterHandleTypeInfo.Handle);
IntPtr invalidValue = invalidValues.Count > 0 ? GetPreferredInvalidHandleValue(invalidValues) : IntPtr.Zero;
ExpressionSyntax invalidExpression = CastExpression(externParam.Type, IntPtrExpr(invalidValue));
// HANDLE hTemplateFileLocal;
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(externParam.Type, [VariableDeclarator(typeDefHandleName.Identifier)])));

// hTemplateFileLocal = invalid-handle-value;
nullHandleStatement = ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, typeDefHandleName, invalidExpression));
}
// throw new ArgumentNullException(nameof(hTemplateFile));
StatementSyntax nullHandleStatement = ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentNullException))).WithArgumentList(ArgumentList(Argument(NameOfExpression(IdentifierName(externParam.Identifier.ValueText))))));
if (isOptional)
{
// (HANDLE)new IntPtr(-1);
HashSet<IntPtr> invalidValues = this.GetInvalidHandleValues(parameterHandleTypeInfo.Handle);
IntPtr invalidValue = invalidValues.Count > 0 ? GetPreferredInvalidHandleValue(invalidValues) : IntPtr.Zero;
ExpressionSyntax invalidExpression = CastExpression(externParam.Type, IntPtrExpr(invalidValue));

// if (hTemplateFile is object)
leadingStatements.Add(IfStatement(
BinaryExpression(SyntaxKind.IsExpression, origName, PredefinedType(Token(SyntaxKind.ObjectKeyword))),
Block(
//// hTemplateFile.DangerousAddRef(ref hTemplateFileAddRef);
ExpressionStatement(InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
origName,
IdentifierName(nameof(SafeHandle.DangerousAddRef))),
[Argument(refAddedName).WithRefKindKeyword(TokenWithSpace(SyntaxKind.RefKeyword))])),
//// hTemplateFileLocal = (HANDLE)hTemplateFile.DangerousGetHandle();
ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
typeDefHandleName,
CastExpression(
externParam.Type.WithoutTrailingTrivia(),
InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousGetHandle))))))
.WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken)))),
//// else hTemplateFileLocal = default;
ElseClause(nullHandleStatement)));

// if (hTemplateFileAddRef)
// hTemplateFile.DangerousRelease();
finallyStatements.Add(
IfStatement(
refAddedName,
ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousRelease))))))
.WithCloseParenToken(TokenWithLineFeed(SyntaxKind.CloseParenToken)));

// Accept the SafeHandle instead.
parameters[paramIndex] = externParam
.WithType(IdentifierName(nameof(SafeHandle)).WithTrailingTrivia(TriviaList(Space)));
// hTemplateFileLocal = invalid-handle-value;
nullHandleStatement = ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, typeDefHandleName, invalidExpression));
}

// if (hTemplateFile is object)
leadingStatements.Add(IfStatement(
BinaryExpression(SyntaxKind.IsExpression, origName, PredefinedType(Token(SyntaxKind.ObjectKeyword))),
Block(
//// hTemplateFile.DangerousAddRef(ref hTemplateFileAddRef);
ExpressionStatement(InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
origName,
IdentifierName(nameof(SafeHandle.DangerousAddRef))),
[Argument(refAddedName).WithRefKindKeyword(TokenWithSpace(SyntaxKind.RefKeyword))])),
//// hTemplateFileLocal = (HANDLE)hTemplateFile.DangerousGetHandle();
ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
typeDefHandleName,
CastExpression(
externParam.Type.WithoutTrailingTrivia(),
InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousGetHandle))))))
.WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken)))),
//// else hTemplateFileLocal = default;
ElseClause(nullHandleStatement)));

// if (hTemplateFileAddRef)
// hTemplateFile.DangerousRelease();
finallyStatements.Add(
IfStatement(
refAddedName,
ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousRelease))))))
.WithCloseParenToken(TokenWithLineFeed(SyntaxKind.CloseParenToken)));

// Accept the SafeHandle instead.
parameters[paramIndex] = externParam
.WithType(IdentifierName(nameof(SafeHandle)).WithTrailingTrivia(TriviaList(Space)));

// hParamNameLocal;
arguments[paramIndex] = Argument(typeDefHandleName);
// hParamNameLocal;
arguments[paramIndex] = Argument(typeDefHandleName);
}
}
else if ((externParam.Type is PointerTypeSyntax { ElementType: TypeSyntax ptrElementType }
&& (!IsVoid(ptrElementType) || (improvePointersToSpansAndRefs && isArray))
Expand Down Expand Up @@ -1387,7 +1395,16 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
if (numOptionalParams > 0 && !omitOptionalParams && improvePointersToSpansAndRefs)
{
// Generate overloads for optional parameters.
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs, omitOptionalParams: true))
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs, omitOptionalParams: true, promoteUnconvertibleHandles))
{
yield return method;
}
}

if (promoteUnconvertibleHandles && hasUnconvertibleHandles && !omitOptionalParams)
{
// Generate overloads with raw unconvertible handles in the signature.
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs, omitOptionalParams: false, promoteUnconvertibleHandles: false))
{
yield return method;
}
Expand Down
5 changes: 4 additions & 1 deletion test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,10 @@ public async Task DelegatesGetStructsGenerated()
// WlanCloseHandle accepts an additional reserved parameter. We can still generate safe hanlde for WlanOpenHandle then
["WlanOpenHandle", "WlanOpenHandle", "uint dwClientVersion, out uint pdwNegotiatedVersion, out global::Windows.Win32.WlanCloseHandleSafeHandle phClientHandle"],
// Has an out reference of a handle that cannot be trivially converted to a SafeHandle
["Windows.Win32.NetworkManagement.WindowsFilteringPlatform.FwpmFilterCreateEnumHandle0", "FwpmFilterCreateEnumHandle0", "SafeHandle engineHandle, [Optional] winmdroot.NetworkManagement.WindowsFilteringPlatform.FWPM_FILTER_ENUM_TEMPLATE0? enumTemplate, out winmdroot.NetworkManagement.WindowsFilteringPlatform.FWPM_FILTER_ENUM_HANDLE enumHandle"]
["Windows.Win32.NetworkManagement.WindowsFilteringPlatform.FwpmFilterCreateEnumHandle0", "FwpmFilterCreateEnumHandle0", "SafeHandle engineHandle, [Optional] winmdroot.NetworkManagement.WindowsFilteringPlatform.FWPM_FILTER_ENUM_TEMPLATE0? enumTemplate, out winmdroot.NetworkManagement.WindowsFilteringPlatform.FWPM_FILTER_ENUM_HANDLE enumHandle"],
// Accepts a handle that cannot be trivially represented as a SafeHandle. Verify that overloads with both SafeHandle and raw handle exist
["Windows.Win32.NetworkManagement.WindowsFilteringPlatform.FwpmFilterEnum0", "FwpmFilterEnum0", "SafeHandle engineHandle, SafeHandle enumHandle, uint numEntriesRequested, out winmdroot.NetworkManagement.WindowsFilteringPlatform.FWPM_FILTER0** entries, out uint numEntriesReturned"],
["Windows.Win32.NetworkManagement.WindowsFilteringPlatform.FwpmFilterEnum0", "FwpmFilterEnum0", "SafeHandle engineHandle, winmdroot.NetworkManagement.WindowsFilteringPlatform.FWPM_FILTER_ENUM_HANDLE enumHandle, uint numEntriesRequested, out winmdroot.NetworkManagement.WindowsFilteringPlatform.FWPM_FILTER0** entries, out uint numEntriesReturned"],
];

[Theory]
Expand Down
Loading