Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/System.CommandLine.Tests/Invocation/InvocationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,101 @@ public async Task Nonterminating_option_actions_handle_exceptions_and_return_an_
returnCode.Should().Be(1);
}

[Theory] // https://github.com/dotnet/command-line-api/issues/2771
[InlineData(true)]
[InlineData(false)]
public async Task Nonterminating_option_action_is_invoked_when_command_has_no_action(bool invokeAsync)
{
bool optionActionWasCalled = false;
SynchronousTestAction optionAction = new(_ => optionActionWasCalled = true, terminating: false);

Option<bool> option = new("--test")
{
Action = optionAction
};
RootCommand command = new()
{
option
};

ParseResult parseResult = command.Parse("--test");

if (invokeAsync)
{
await parseResult.InvokeAsync();
}
else
{
parseResult.Invoke();
}

optionActionWasCalled.Should().BeTrue();
}

[Theory] // https://github.com/dotnet/command-line-api/issues/2772
[InlineData(true)]
[InlineData(false)]
public async Task Nonterminating_option_action_return_value_is_propagated(bool invokeAsync)
{
SynchronousTestAction optionAction = new(_ => { }, terminating: false, returnValue: 42);

Option<bool> option = new("--test")
{
Action = optionAction
};
RootCommand command = new()
{
option
};
command.SetAction(_ => { });

ParseResult parseResult = command.Parse("--test");

int result;
if (invokeAsync)
{
result = await parseResult.InvokeAsync();
}
else
{
result = parseResult.Invoke();
}

result.Should().Be(42);
}

[Theory] // https://github.com/dotnet/command-line-api/issues/2772
[InlineData(true)]
[InlineData(false)]
public async Task When_preaction_and_command_action_both_return_nonzero_then_preaction_value_wins(bool invokeAsync)
{
SynchronousTestAction optionAction = new(_ => { }, terminating: false, returnValue: 42);

Option<bool> option = new("--test")
{
Action = optionAction
};
RootCommand command = new()
{
option
};
command.SetAction(_ => 99);

ParseResult parseResult = command.Parse("--test");

int result;
if (invokeAsync)
{
result = await parseResult.InvokeAsync();
}
else
{
result = parseResult.Invoke();
}

result.Should().Be(42);
}

[Fact]
public async Task Command_InvokeAsync_with_cancelation_token_invokes_command_handler()
{
Expand Down
14 changes: 10 additions & 4 deletions src/System.CommandLine.Tests/TestActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ namespace System.CommandLine.Tests;
public class SynchronousTestAction : SynchronousCommandLineAction
{
private readonly Action<ParseResult> _invoke;
private readonly int _returnValue;

public SynchronousTestAction(
Action<ParseResult> invoke,
bool terminating = true,
bool clearsParseErrors = false)
bool clearsParseErrors = false,
int returnValue = 0)
{
ClearsParseErrors = clearsParseErrors;
_invoke = invoke;
_returnValue = returnValue;
Terminating = terminating;
}

Expand All @@ -28,21 +31,24 @@ public SynchronousTestAction(
public override int Invoke(ParseResult parseResult)
{
_invoke(parseResult);
return 0;
return _returnValue;
}
}

public class AsynchronousTestAction : AsynchronousCommandLineAction
{
private readonly Action<ParseResult> _invoke;
private readonly int _returnValue;

public AsynchronousTestAction(
Action<ParseResult> invoke,
bool terminating = true,
bool clearsParseErrors = false)
bool clearsParseErrors = false,
int returnValue = 0)
{
ClearsParseErrors = clearsParseErrors;
_invoke = invoke;
_returnValue = returnValue;
Terminating = terminating;
}

Expand All @@ -53,6 +59,6 @@ public AsynchronousTestAction(
public override Task<int> InvokeAsync(ParseResult parseResult, CancellationToken cancellationToken = default)
{
_invoke(parseResult);
return Task.FromResult(0);
return Task.FromResult(_returnValue);
}
}
106 changes: 65 additions & 41 deletions src/System.CommandLine/Invocation/InvocationPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,52 @@ internal static class InvocationPipeline
{
internal static async Task<int> InvokeAsync(ParseResult parseResult, CancellationToken cancellationToken)
{
if (parseResult.Action is null)
{
return ReturnCodeForMissingAction(parseResult);
}

ProcessTerminationHandler? terminationHandler = null;
using CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

try
{
int exitCode = 0;

if (parseResult.PreActions is not null)
{
for (int i = 0; i < parseResult.PreActions.Count; i++)
{
var action = parseResult.PreActions[i];
int preActionResult;

switch (action)
{
case SynchronousCommandLineAction syncAction:
syncAction.Invoke(parseResult);
preActionResult = syncAction.Invoke(parseResult);
break;
case AsynchronousCommandLineAction asyncAction:
await asyncAction.InvokeAsync(parseResult, cts.Token);
preActionResult = await asyncAction.InvokeAsync(parseResult, cts.Token);
break;
default:
preActionResult = 0;
break;
}

if (exitCode == 0)
{
exitCode = preActionResult;
}
}
}

if (parseResult.Action is null)
{
return exitCode != 0 ? exitCode : ReturnCodeForMissingAction(parseResult);
}

int actionResult;

switch (parseResult.Action)
{
case SynchronousCommandLineAction syncAction:
return syncAction.Invoke(parseResult);
actionResult = syncAction.Invoke(parseResult);
break;

case AsynchronousCommandLineAction asyncAction:
var startedInvocation = asyncAction.InvokeAsync(parseResult, cts.Token);
Expand All @@ -55,20 +69,23 @@ internal static async Task<int> InvokeAsync(ParseResult parseResult, Cancellatio

if (terminationHandler is null)
{
return await startedInvocation;
actionResult = await startedInvocation;
}
else
{
// Handlers may not implement cancellation.
// In such cases, when CancelOnProcessTermination is configured and user presses Ctrl+C,
// ProcessTerminationCompletionSource completes first, with the result equal to native exit code for given signal.
Task<int> firstCompletedTask = await Task.WhenAny(startedInvocation, terminationHandler.ProcessTerminationCompletionSource.Task);
return await firstCompletedTask; // return the result or propagate the exception
actionResult = await firstCompletedTask; // return the result or propagate the exception
}
break;

default:
throw new ArgumentOutOfRangeException(nameof(parseResult.Action));
}

return exitCode != 0 ? exitCode : actionResult;
}
catch (Exception ex) when (parseResult.InvocationConfiguration.EnableDefaultExceptionHandler)
{
Expand All @@ -82,48 +99,55 @@ internal static async Task<int> InvokeAsync(ParseResult parseResult, Cancellatio

internal static int Invoke(ParseResult parseResult)
{
switch (parseResult.Action)
try
{
case null:
return ReturnCodeForMissingAction(parseResult);
int exitCode = 0;

case SynchronousCommandLineAction syncAction:
try
if (parseResult.PreActions is not null)
{
#if DEBUG
for (var i = 0; i < parseResult.PreActions.Count; i++)
{
if (parseResult.PreActions is not null)
var action = parseResult.PreActions[i];

if (action is not SynchronousCommandLineAction)
{
#if DEBUG
for (var i = 0; i < parseResult.PreActions.Count; i++)
{
var action = parseResult.PreActions[i];

if (action is not SynchronousCommandLineAction)
{
parseResult.InvocationConfiguration.EnableDefaultExceptionHandler = false;
throw new Exception(
$"This should not happen. An instance of {nameof(AsynchronousCommandLineAction)} ({action}) was called within {nameof(InvocationPipeline)}.{nameof(Invoke)}. This is supposed to be detected earlier resulting in a call to {nameof(InvocationPipeline)}{nameof(InvokeAsync)}");
}
}
parseResult.InvocationConfiguration.EnableDefaultExceptionHandler = false;
throw new Exception(
$"This should not happen. An instance of {nameof(AsynchronousCommandLineAction)} ({action}) was called within {nameof(InvocationPipeline)}.{nameof(Invoke)}. This is supposed to be detected earlier resulting in a call to {nameof(InvocationPipeline)}{nameof(InvokeAsync)}");
}
}
#endif

for (var i = 0; i < parseResult.PreActions.Count; i++)
for (var i = 0; i < parseResult.PreActions.Count; i++)
{
if (parseResult.PreActions[i] is SynchronousCommandLineAction syncPreAction)
{
int preActionResult = syncPreAction.Invoke(parseResult);
if (exitCode == 0)
{
if (parseResult.PreActions[i] is SynchronousCommandLineAction syncPreAction)
{
syncPreAction.Invoke(parseResult);
}
exitCode = preActionResult;
}
}

return syncAction.Invoke(parseResult);
}
catch (Exception ex) when (parseResult.InvocationConfiguration.EnableDefaultExceptionHandler)
{
return DefaultExceptionHandler(ex, parseResult);
}
}

switch (parseResult.Action)
{
case null:
return exitCode != 0 ? exitCode : ReturnCodeForMissingAction(parseResult);

default:
throw new InvalidOperationException($"{nameof(AsynchronousCommandLineAction)} called within non-async invocation.");
case SynchronousCommandLineAction syncAction:
int actionResult = syncAction.Invoke(parseResult);
return exitCode != 0 ? exitCode : actionResult;

default:
throw new InvalidOperationException($"{nameof(AsynchronousCommandLineAction)} called within non-async invocation.");
}
}
catch (Exception ex) when (parseResult.InvocationConfiguration.EnableDefaultExceptionHandler)
{
return DefaultExceptionHandler(ex, parseResult);
}
}

Expand Down