Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// <auto-generated/>
#nullable restore

namespace TestLibrary;

internal interface IOrchestrator
{
public Task<int> Execute();
}
2 changes: 1 addition & 1 deletion AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace TestLibrary;

internal class Orchestrator
internal class Orchestrator : IOrchestrator
{
private readonly TestLibrary.A a;
private readonly TestLibrary.B b;
Expand Down
6 changes: 4 additions & 2 deletions AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
[Test]
public async Task OneInterface() {
var source = await ReadCSharpFile<OrchestratorSpec>(true);
var generated = await ReadCSharpFile<Orchestrator>(false);
var generatedClass = await ReadCSharpFile<Orchestrator>(false);
var generatedInterface = await ReadCSharpFile<IOrchestrator>(false);

await new VerifyCS.Test
{
Expand All @@ -35,7 +36,8 @@
Sources = { source },
GeneratedSources =
{
(typeof(Main), "Orchestrator.generated.cs", SourceText.From(generated, Encoding.UTF8, SourceHashAlgorithm.Sha256)),
(typeof(Main), "Orchestrator.generated.cs", SourceText.From(generatedClass, Encoding.UTF8, SourceHashAlgorithm.Sha256)),
(typeof(Main), "IOrchestrator.generated.cs", SourceText.From(generatedInterface, Encoding.UTF8, SourceHashAlgorithm.Sha256)),
},
},
}.RunAsync();
Expand All @@ -49,19 +51,19 @@
private static async Task<string> ReadCSharpFile<T>(bool isTestLibrary) {
var currentDirectory = GetCurrentDirectory();

var targetDirectory = isTestLibrary ? GetTestLibraryDirectory(currentDirectory) : currentDirectory;

Check warning on line 54 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Possible null reference argument for parameter 'currentDirectory' in 'DirectoryInfo Tests.GetTestLibraryDirectory(DirectoryInfo currentDirectory)'.

Check warning on line 54 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Possible null reference argument for parameter 'currentDirectory' in 'DirectoryInfo Tests.GetTestLibraryDirectory(DirectoryInfo currentDirectory)'.

var searchPattern = $"{typeof(T).Name}*.cs";
var file = targetDirectory.GetFiles(searchPattern).First();

Check warning on line 57 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Dereference of a possibly null reference.

Check warning on line 57 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Dereference of a possibly null reference.

using var fileReader = new StreamReader(file.OpenRead());
return await fileReader.ReadToEndAsync();
}

private static DirectoryInfo? GetCurrentDirectory() {
return Directory.GetParent(Directory.GetParent(AppDomain.CurrentDomain.BaseDirectory).Parent.Parent.FullName);

Check warning on line 64 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Dereference of a possibly null reference.

Check warning on line 64 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Dereference of a possibly null reference.

Check warning on line 64 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Dereference of a possibly null reference.

Check warning on line 64 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Dereference of a possibly null reference.

Check warning on line 64 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Dereference of a possibly null reference.
}

private static DirectoryInfo GetTestLibraryDirectory(DirectoryInfo currentDirectory) {
return currentDirectory.Parent.GetDirectories("TestLibrary").First();

Check warning on line 68 in AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

View workflow job for this annotation

GitHub Actions / Unit tests (9.0.x)

Dereference of a possibly null reference.
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.1</TargetFramework>
<TargetFramework>netstandard2.0</TargetFramework>
<IsRoslynComponent>true</IsRoslynComponent>
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
<Title>Async Task Orchestrator Generator</Title>
Expand Down
16 changes: 7 additions & 9 deletions AsyncTaskOrchestratorGenerator/Main.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;

Expand All @@ -26,15 +23,16 @@ private static bool IsSyntaxTargetForGeneration(SyntaxNode syntaxNode, Cancellat
return syntaxNode is TypeDeclarationSyntax;
}

private static (INamedTypeSymbol, SemanticModel) GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) {

return (context.TargetSymbol as INamedTypeSymbol, context.SemanticModel);
private static INamedTypeSymbol GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) {
return context.TargetSymbol as INamedTypeSymbol;
}

private static void Execute(SourceProductionContext context, (INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) {
var (source, className) = OutputGenerator.GenerateOutputs(typeInfo);
private static void Execute(SourceProductionContext context, INamedTypeSymbol typeSymbol) {
var (classSource, className) = OutputGenerator.GenerateClassOutputs(typeSymbol);
var (interfaceSource, interfaceName) = OutputGenerator.GenerateInterfaceOutputs(typeSymbol);

context.AddSource($"{className}.generated.cs", SourceText.From(source, Encoding.UTF8, SourceHashAlgorithm.Sha256));
context.AddSource($"{className}.generated.cs", SourceText.From(classSource, Encoding.UTF8, SourceHashAlgorithm.Sha256));
context.AddSource($"{interfaceName}.generated.cs", SourceText.From(interfaceSource, Encoding.UTF8, SourceHashAlgorithm.Sha256));
}
}
}
145 changes: 91 additions & 54 deletions AsyncTaskOrchestratorGenerator/OutputGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@ namespace AsyncTaskOrchestratorGenerator
{
internal static class OutputGenerator
{
public static (string source, string className) GenerateOutputs((INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) {
var type = typeInfo.typeSymbol;
var semanticModel = typeInfo.semanticModel;

var constructorArguments = type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments;
public static (string source, string className) GenerateClassOutputs(INamedTypeSymbol type) {
var constructorArguments = GetAttributeConstructorArguments(type);
var className = constructorArguments.First().Value.ToString();
var executeMethodName = constructorArguments.ElementAt(1).Value.ToString();
var interfaceName = $"I{className}";

var accessModifier = type.DeclaredAccessibility.ToString().ToLower();
var typeMembers = type.GetMembers();
Expand All @@ -36,7 +34,7 @@ public static (string source, string className) GenerateOutputs((INamedTypeSymbo

namespace {type.ContainingNamespace.ToDisplayString()};

{accessModifier} class {className}
{accessModifier} class {className} : {interfaceName}
{{
{string.Join(@"
", formattedFields)}
Expand All @@ -50,37 +48,40 @@ namespace {type.ContainingNamespace.ToDisplayString()};
return (source, className);
}

private static (ExecuteMethodSignatureData, Dictionary<string, TaskData>, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable<IFieldSymbol> fields, string executeMethodName) {
var executeMethod = type
.GetMembers()
.Where(m => m is IMethodSymbol)
.First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol;
var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements;
var variableStatements = statements.Remove(statements.Last());
public static (string source, string interfaceName) GenerateInterfaceOutputs(INamedTypeSymbol type) {
var constructorArguments = GetAttributeConstructorArguments(type);
var className = constructorArguments.First().Value.ToString();
var executeMethodName = constructorArguments.ElementAt(1).Value.ToString();
var interfaceName = $"I{className}";

var accessModifier = type.DeclaredAccessibility.ToString().ToLower();
var executeMethod = GetExecuteMethod(type);
var executeMethodAccessibility = executeMethod.DeclaredAccessibility.ToString().ToLower();
var formattedExecuteMethod = $"{executeMethodAccessibility} {executeMethod.ReturnType} {executeMethodName}();";

var variableData = variableStatements
.Select(s => s as LocalDeclarationStatementSyntax)
.SelectMany(v => v.Declaration.Variables)
.Select(declarationSyntax => {
var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax;
var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax;
var methodCallTypeName = methodAccessExpression.ToString().Split('.').First();
var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type;
var methodCallName = methodAccessExpression.ToString().Split('.').Last();
var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol;
var source =
$@"// <auto-generated/>
#nullable restore

var arguments = invocation.ArgumentList.Arguments;
var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First());
namespace {type.ContainingNamespace.ToDisplayString()};

return new TaskData
{
OutputName = declarationSyntax.Identifier.Text,
MethodCallName = methodAccessExpression.ToString(),
MethodCallReturnType = methodSymbol.ReturnType.ToString(),
DependenciesOutputNames = argumentTypeNames,
TaskName = $"{declarationSyntax.Identifier.Text}Task"
};
});
{accessModifier} interface {interfaceName}
{{
{formattedExecuteMethod}
}}
";
return (source, interfaceName);
}

private static System.Collections.Immutable.ImmutableArray<TypedConstant> GetAttributeConstructorArguments(INamedTypeSymbol type) {
return type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments;
}

private static (ExecuteMethodSignatureData, Dictionary<string, TaskData>, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable<IFieldSymbol> fields, string executeMethodName) {
var executeMethod = GetExecuteMethod(type);
var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements;
var variableStatements = statements.Remove(statements.Last());
var variableData = GetVariableData(fields, variableStatements);

var lastStatement = statements.Last() as ReturnStatementSyntax;
var invocation = lastStatement.Expression as InvocationExpressionSyntax;
Expand Down Expand Up @@ -110,36 +111,55 @@ private static (ExecuteMethodSignatureData, Dictionary<string, TaskData>, TaskDa
}, variableData.ToDictionary(taskData => taskData.OutputName), finalTaskData);
}

private static IEnumerable<TaskData> GetVariableData(IEnumerable<IFieldSymbol> fields, SyntaxList<StatementSyntax> variableStatements) {
return variableStatements
.Select(s => s as LocalDeclarationStatementSyntax)
.SelectMany(v => v.Declaration.Variables)
.Select(declarationSyntax => {
var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax;
var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax;
var methodCallTypeName = methodAccessExpression.ToString().Split('.').First();
var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type;
var methodCallName = methodAccessExpression.ToString().Split('.').Last();
var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol;

var arguments = invocation.ArgumentList.Arguments;
var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First());

return new TaskData
{
OutputName = declarationSyntax.Identifier.Text,
MethodCallName = methodAccessExpression.ToString(),
MethodCallReturnType = methodSymbol.ReturnType.ToString(),
DependenciesOutputNames = argumentTypeNames,
TaskName = $"{declarationSyntax.Identifier.Text}Task"
};
});
}

private static IMethodSymbol GetExecuteMethod(INamedTypeSymbol type) {
return type
.GetMembers()
.Where(m => m is IMethodSymbol)
.First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol;
}

private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureData, Dictionary<string, TaskData> data, TaskData finalTaskData) {
var formattedTaskDeclarations = data.Select(keyValue => {
var item = keyValue.Value;
var hasDependencies = item.DependenciesOutputNames.Any();
return hasDependencies ?
$@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);":
return hasDependencies ?
$@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);" :
$@"var {item.TaskName} = {item.MethodCallName}();";
});

var taskNames = data.Where(keyValue => !keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => keyValue.Value.TaskName);
var formattedTasksList = $@"var tasksToProcess = new List<Task> {{ {string.Join(@", ", taskNames)} }};";

var formattedHandleTaskCompletions = data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => {
var item = keyValue.Value;
var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName);
var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted"));
var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result"));
var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});";
var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});";

return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames})
{{
{formattedCallDependencies}
{formattedAddTaskToList}
}}";
});
var formattedHandleTaskCompletions = CreateFormattedHandleTaskCompletions(data);

var formattedWhenEach = $@"await foreach (var completed in Task.WhenEach(tasksToProcess))
{{
{ string.Join(@"
{string.Join(@"

", formattedHandleTaskCompletions)}
}}";
Expand All @@ -149,11 +169,11 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa
var formattedFinalResult = $@"var finalResult = await {finalTaskData.MethodCallName}({formattedResultDependencyTaskNames});

return finalResult;";

return $@"{signatureData.AccessModifier} async {signatureData.ReturnType} {signatureData.Name}()
{{
{string.Join(@"
", formattedTaskDeclarations) }
", formattedTaskDeclarations)}

{formattedTasksList}

Expand All @@ -163,6 +183,23 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa
}}";
}

private static IEnumerable<string> CreateFormattedHandleTaskCompletions(Dictionary<string, TaskData> data) {
return data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => {
var item = keyValue.Value;
var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName);
var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted"));
var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result"));
var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});";
var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});";

return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames})
{{
{formattedCallDependencies}
{formattedAddTaskToList}
}}";
});
}

private static string FormatConstructor(INamedTypeSymbol type, string className, IEnumerable<ISymbol> typeMembers) {
var constructor = typeMembers
.Where(m => m.Kind == SymbolKind.Method)
Expand Down