Skip to content
Merged
7 changes: 3 additions & 4 deletions src/OpenAPI.WebApiGenerator/ApiGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,9 @@ private static void GenerateCode(SourceProductionContext context,
var responses = operation.Responses ??
throw new InvalidOperationException(
$"No responses defined for operation at {openApiOperationVisitor.Pointer}");
var responseBodyGenerators = responses.Select(pair =>
var responseBodyGenerators = responses.Select(content =>
{
var response = pair.Value;
var responseStatusCodePattern = pair.Key.ToPascalCase();
var response = content.Value;
var openApiResponseVisitor = openApiOperationVisitor.Visit(response);

var responseContent =
Expand All @@ -186,7 +185,7 @@ private static void GenerateCode(SourceProductionContext context,
}).ToList() ?? [];

return new ResponseContentGenerator(
responseStatusCodePattern,
content,
responseBodyGenerators,
responseHeaderGenerators);
}).ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public sealed class {{ClassName}}
public Uri? OpenApiSpecificationUri { get; init; }{{(authGenerator.HasSecuritySchemes ?
"""

/// <summary>
/// Security scheme options
/// </summary>
internal SecuritySchemeOptions SecuritySchemeOptions { get; set; } = new();
""" : "")}}
}
Expand Down
62 changes: 57 additions & 5 deletions src/OpenAPI.WebApiGenerator/CodeGeneration/AuthGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata;
using Microsoft.OpenApi;
using OpenAPI.WebApiGenerator.Extensions;

Expand Down Expand Up @@ -38,6 +37,9 @@ public AuthGenerator(OpenApiDocument openApiDocument)

namespace {{@namespace}};

/// <summary>
/// Defines security schemes that can be used by the operations
/// </summary>
internal static class SecuritySchemes
{{{_securitySchemes.AggregateToString(pair =>
{
Expand All @@ -47,10 +49,10 @@ internal static class SecuritySchemes
return scheme.Type == null ? string.Empty :
$$"""
internal const string {{className}}Key = "{{pair.Key}}";
{{scheme.Description.AsComment("summary", "para").Indent(4)}}
internal static class {{className}}
{{{new []
{
GenerateConst(nameof(scheme.Description), scheme.Description),
GenerateConst(nameof(scheme.Type), scheme.Type?.GetDisplayName()),
GenerateConst(nameof(scheme.Scheme), scheme.Scheme),
GenerateConst(nameof(scheme.BearerFormat), scheme.BearerFormat),
Expand All @@ -59,7 +61,7 @@ internal static class {{className}}
$"internal const bool {nameof(scheme.Deprecated)} = {scheme.Deprecated.ToString().ToLowerInvariant()};",
GenerateFlowsObject(nameof(scheme.Flows), scheme.Flows)
}.RemoveEmptyLines().AggregateToString().Indent(8)}}
}
}
""";
})}}
}
Expand Down Expand Up @@ -218,13 +220,18 @@ internal static class {{className}}

namespace {{@namespace}};

/// <summary>
/// Base class for handling security requirements for an operation
/// </summary>
internal abstract class BaseSecurityRequirementsFilter(WebApiConfiguration configuration) : IEndpointFilter
{
protected abstract SecurityRequirements Requirements { get; }
protected WebApiConfiguration Configuration { get; } = configuration;

protected abstract void HandleForbidden(HttpResponse response);
protected abstract void HandleUnauthorized(HttpResponse response);

/// <inheritdoc/>
public async ValueTask<object?> InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next)
{
var httpContext = context.HttpContext;
Expand Down Expand Up @@ -297,7 +304,15 @@ private static bool ClaimContainsScopes(ClaimsPrincipal? principal, SecuritySche
return scopes.All(scope => foundScopes.Contains(scope));
}

/// <summary>
/// A declaration of which security mechanisms can be used for an operation.
/// The list of values includes alternative Security Requirement Objects that can be used. Only one of the Security Requirement Objects need to be satisfied to authorize a request. To make security optional, an empty security requirement can be included in the list.
/// </summary>
internal class SecurityRequirements : List<SecurityRequirement>, IAuthorizationRequirement;

/// <summary>
/// Lists the required security schemes to execute an operation.
/// </summary>
internal class SecurityRequirement : Dictionary<string, string[]>;
}
#nullable restore
Expand All @@ -318,6 +333,9 @@ internal string GenerateAuthFilters(OpenApiOperation operation, ParameterGenerat
_requestFilters.Add(operation, [securityRequirementsFilterClassName]);
return
$$"""
/// <summary>
/// Filter for handling security requirements
/// </summary>
internal sealed class {{securityRequirementsFilterClassName}} : IEndpointFilter
{
public ValueTask<object?> InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next)
Expand All @@ -338,6 +356,9 @@ internal sealed class {{securityRequirementsFilterClassName}} : IEndpointFilter
: [securityRequirementsFilterClassName]);
return (hasSecuritySchemeParameters ?
$$"""
/// <summary>
/// Filter for extracting security scheme parameters
/// </summary>
internal sealed class {{securitySchemeParameterFilterClassName}} : IEndpointFilter
{
public ValueTask<object?> InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next)
Expand All @@ -358,6 +379,9 @@ internal sealed class {{securitySchemeParameterFilterClassName}} : IEndpointFilt
""" : string.Empty) +
$$"""

/// <summary>
/// Filter for handling security requirements
/// </summary>
internal sealed class {{securityRequirementsFilterClassName}}(Operation operation, WebApiConfiguration configuration) : BaseSecurityRequirementsFilter(configuration)
{
protected override SecurityRequirements Requirements { get; } = new()
Expand All @@ -372,8 +396,8 @@ internal sealed class {{securityRequirementsFilterClassName}}(Operation operatio
""")))}}
};

protected override void HandleUnauthorized(HttpResponse response) => operation.Validate(operation.HandleUnauthorized(), configuration).WriteTo(response);
protected override void HandleForbidden(HttpResponse response) => operation.Validate(operation.HandleForbidden(), configuration).WriteTo(response);
protected override void HandleUnauthorized(HttpResponse response) => operation.Validate(operation.HandleUnauthorized(), Configuration).WriteTo(response);
protected override void HandleForbidden(HttpResponse response) => operation.Validate(operation.HandleForbidden(), Configuration).WriteTo(response);
}
""";
}
Expand All @@ -389,12 +413,21 @@ internal sealed class {{securityRequirementsFilterClassName}}(Operation operatio
#nullable enable
namespace {{@namespace}};

/// <summary>
/// Options for security schemes
/// </summary>
internal sealed class SecuritySchemeOptions
{{{_securitySchemes.AggregateToString(pair =>
$$"""
{{pair.Value.Description.AsComment("summary", "para")}}
public SecuritySchemeOption {{pair.Key.ToPascalCase()}} { get; init; } = new();
""").Indent(4)}}

/// <summary>
/// Get scope options
/// </summary>
/// <param name="scheme">Name of security scheme</param>
/// <returns>Scope options for the security scheme</returns>
internal ScopeOptions GetScopeOptions(string scheme) =>
scheme switch
{{{_securitySchemes.AggregateToString(pair =>
Expand All @@ -404,20 +437,39 @@ internal ScopeOptions GetScopeOptions(string scheme) =>
_ => throw new InvalidOperationException($"Scheme {scheme} is unknown")
};

/// <summary>
/// Security scheme option
/// </summary>
internal sealed class SecuritySchemeOption
{
/// <summary>
/// Scope options
/// </summary>
public ScopeOptions Scope {get; init; } = new()
{
Claim = "scope",
Format = ScopeOptions.ClaimFormat.SpaceDelimited
};
}

/// <summary>
/// Scope options
/// </summary>
internal sealed class ScopeOptions
{
/// <summary>
/// Name of the claim
/// </summary>
public required string Claim { get; init; }

/// <summary>
/// Claim format
/// </summary>
public required ClaimFormat Format { get; init; }

/// <summary>
/// Claim formats
/// </summary>
internal enum ClaimFormat
{
SpaceDelimited,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ internal SourceCode GenerateHttpRequestExtensionsClass() =>

namespace {{{@namespace}}};

/// <summary>
/// Extension methods for http request objects
/// </summary>
internal static class {{{HttpRequestExtensionsClassName}}}
{
private const string ParameterValueParserVersion = "{{{openApiVersion.GetParameterVersion()}}}";
Expand All @@ -64,11 +67,10 @@ private static IParameter GetParameter(string parameterSpecificationAsJson) =>
/// <summary>
/// Binds an http parameter to a json type
/// </summary>
/// <param name="request"></param>
/// <param name="request">Request to bind from</param>
/// <param name="parameterSpecificationAsJson">OpenAPI parameter specification formatted as json</param>
/// <typeparam name="T">The type to bind</typeparam>
/// <returns>The bound instance</returns>
/// <exception cref="BadHttpRequestException"></exception>
internal static T Bind<T>(this HttpRequest request,
string parameterSpecificationAsJson)
where T : struct, IJsonValue<T>
Expand All @@ -82,8 +84,15 @@ _ when TryParse<T>(request, parameter, out var value) => value.Value,
};
}

/// <summary>
/// Binds an http body to a json type
/// </summary>
/// <param name="request">Request to bind from</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <typeparam name="T">The type to bind</typeparam>
/// <returns>An awaitable task to the bound instance</returns>
internal static async Task<T> BindBodyAsync<T>(this HttpRequest request,
CancellationToken cancellationToken)
CancellationToken cancellationToken)
where T : struct, IJsonValue<T>
{
var document = await JsonDocument.ParseAsync(request.Body,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,23 @@ internal SourceCode GenerateHttpResponseExtensionsClass() =>
using JsonObject = System.Text.Json.Nodes.JsonObject;

namespace {{{@namespace}}};


/// <summary>
/// Extension methods for http response objects
/// </summary>
internal static class {{{HttpResponseExtensionsClassName}}}
{
private static readonly ConcurrentDictionary<string, IParameterValueParser> ParserCache = new();
private const string ParameterValueParserVersion = "{{{openApiSpecVersion.GetParameterVersion()}}}";

/// <summary>
/// Write header to a response object
/// </summary>
/// <param name="response">The response object to write the header to</param>
/// <param name="headerSpecificationAsJson">OpenAPI specification for the header</param>
/// <param name="name">The header name</param>
/// <param name="value">The header value</param>
/// <typeparam name="TValue">The type of the header</typeparam>
internal static void WriteResponseHeader<TValue>(this HttpResponse response,
string headerSpecificationAsJson,
string name,
Expand All @@ -55,6 +66,12 @@ internal static void WriteResponseHeader<TValue>(this HttpResponse response,
response.Headers[name] = serializedValue;
}

/// <summary>
/// Write body to a response object
/// </summary>
/// <param name="response">The response object to write the body to</param>
/// <param name="value">The value of the body</param>
/// <typeparam name="TValue">The type of the body</typeparam>
internal static void WriteResponseBody<TValue>(this HttpResponse response, TValue value)
where TValue : struct, IJsonValue<TValue>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

internal sealed class JsonValidationExceptionGenerator(string @namespace)
{
private const string ClassName = "JsonValidationException";
internal const string ClassName = "JsonValidationException";
internal string CreateThrowJsonValidationExceptionInvocation(
string message,
string validationResultVariableName)
Expand All @@ -21,14 +21,23 @@ internal SourceCode GenerateJsonValidationExceptionClass() =>

namespace {{@namespace}};

/// <summary>
/// Exception thrown when validation of json objects fail
/// </summary>
internal sealed class {{ClassName}} : Exception
{
/// <summary>
/// Create json validation exception
/// </summary>
internal {{ClassName}}(string message, ImmutableList<ValidationResult> validationResult) : base(
GetValidationMessage(message, validationResult))
GetValidationMessage(message, validationResult))
{
ValidationResult = validationResult;
}

/// <summary>
/// The validation result
/// </summary>
internal ImmutableList<ValidationResult> ValidationResult { get; }

private static string GetValidationMessage(string message, ImmutableList<ValidationResult> validationResult)
Expand Down
Loading
Loading