Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions awscli/customizations/ecs/executecommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import logging
import json
import errno
import os

from subprocess import check_call
from subprocess import check_call, check_output
from awscli.compat import ignore_user_entered_signals
from awscli.clidriver import ServiceOperation, CLIOperationCaller
from awscli.customizations.sessionmanager import VersionRequirement

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,6 +76,9 @@ def build_ssm_request_paramaters(response, client):


class ExecuteCommandCaller(CLIOperationCaller):
LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR = "1.2.497.0"
DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE"

def invoke(self, service_name, operation_name, parameters, parsed_globals):
try:
# making an execute-command call to connect to an
Expand All @@ -83,7 +88,9 @@ def invoke(self, service_name, operation_name, parameters, parsed_globals):
# before calling execute-command to ensure that
# session-manager-plugin is installed
# before execute-command-command is made
check_call(["session-manager-plugin"])
plugin_version = check_output(
["session-manager-plugin", "--version"], text=True
)
client = self._session.create_client(
service_name, region_name=parsed_globals.region,
endpoint_url=parsed_globals.endpoint_url,
Expand All @@ -94,6 +101,15 @@ def invoke(self, service_name, operation_name, parameters, parsed_globals):
if self._session.profile is not None else ''
endpoint_url = client.meta.endpoint_url
ssm_request_params = build_ssm_request_paramaters(response, client)
start_session_response = json.dumps(response['session'])
ssm_env_name = self.DEFAULT_SSM_ENV_NAME
env = os.environ.copy()
version_requirement = VersionRequirement(
min_version=self.LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR
)
if version_requirement.meets_requirement(plugin_version):
env[ssm_env_name] = start_session_response
start_session_response = ssm_env_name
# ignore_user_entered_signals ignores these signals
# because if signals which kills the process are not
# captured would kill the foreground process but not the
Expand All @@ -103,12 +119,12 @@ def invoke(self, service_name, operation_name, parameters, parsed_globals):
with ignore_user_entered_signals():
# call executable with necessary input
check_call(["session-manager-plugin",
json.dumps(response['session']),
start_session_response,
region_name,
"StartSession",
profile_name,
json.dumps(ssm_request_params),
endpoint_url])
endpoint_url], env=env)
return 0
except OSError as ex:
if ex.errno == errno.ENOENT:
Expand Down
96 changes: 74 additions & 22 deletions tests/unit/customizations/ecs/test_executecommand_startsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,28 @@ def setUp(self):
}

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_when_calls_fails_from_ecs(self, mock_check_call):
@mock.patch('awscli.customizations.ecs.executecommand.check_output')
def test_when_calls_fails_from_ecs(self, mock_check_output, mock_check_call):
self.client.execute_command.side_effect = Exception('some exception')
mock_check_call.return_value = 0
mock_check_output.return_value = "1.2.0.0\n"
with self.assertRaisesRegex(Exception, 'some exception'):
self.caller.invoke('ecs', 'ExecuteCommand', {}, mock.Mock())

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_when_session_manager_plugin_not_installed(self, mock_check_call):
mock_check_call.side_effect = [OSError(errno.ENOENT, 'some error'), 0]
@mock.patch('awscli.customizations.ecs.executecommand.check_output')
def test_when_session_manager_plugin_not_installed(
self, mock_check_output
):
mock_check_output.side_effect = OSError(errno.ENOENT, 'some error')

with self.assertRaises(ValueError):
self.caller.invoke('ecs', 'ExecuteCommand', {}, mock.Mock())

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_execute_command_success(self, mock_check_call):
@mock.patch('awscli.customizations.ecs.executecommand.check_output')
def test_execute_command_success(self, mock_check_output, mock_check_call):
mock_check_call.return_value = 0
mock_check_output.return_value = "1.2.0.0\n"

self.client.execute_command.return_value = \
self.execute_command_response
Expand All @@ -142,12 +148,16 @@ def test_execute_command_success(self, mock_check_call):
self.profile,
json.dumps(self.ssm_request_parameters),
self.endpoint_url
]
],
)

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_when_describe_task_fails(self, mock_check_call):
@mock.patch('awscli.customizations.ecs.executecommand.check_output')
def test_when_describe_task_fails(
self, mock_check_output, mock_check_call
):
mock_check_call.return_value = 0
mock_check_output.return_value = "1.2.0.0\n"

self.client.execute_command.return_value = \
self.execute_command_response
Expand All @@ -162,8 +172,12 @@ def test_when_describe_task_fails(self, mock_check_call):
assert_called_with(**self.execute_command_params)

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_when_describe_task_returns_no_tasks(self, mock_check_call):
@mock.patch('awscli.customizations.ecs.executecommand.check_output')
def test_when_describe_task_returns_no_tasks(
self, mock_check_output, mock_check_call
):
mock_check_call.return_value = 0
mock_check_output.return_value = "1.2.0.0\n"

self.client.execute_command.return_value = \
self.execute_command_response
Expand All @@ -178,8 +192,10 @@ def test_when_describe_task_returns_no_tasks(self, mock_check_call):
assert_called_with(**self.execute_command_params)

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_when_check_call_fails(self, mock_check_call):
mock_check_call.side_effect = [0, Exception('some Exception')]
@mock.patch('awscli.customizations.ecs.executecommand.check_output')
def test_when_check_call_fails(self, mock_check_output, mock_check_call):
mock_check_call.side_effect = Exception('some Exception')
mock_check_output.return_value = "1.2.0.0\n"

self.client.execute_command.return_value = \
self.execute_command_response
Expand All @@ -189,15 +205,51 @@ def test_when_check_call_fails(self, mock_check_call):
self.caller.invoke('ecs', 'ExecuteCommand',
self.execute_command_params, mock.Mock())

mock_check_call_list = mock_check_call.call_args[0][0]
mock_check_call_list[1] = json.loads(mock_check_call_list[1])
self.assertEqual(
mock_check_call_list,
['session-manager-plugin',
self.execute_command_response["session"],
self.region,
'StartSession',
self.profile,
json.dumps(self.ssm_request_parameters),
self.endpoint_url],
)
mock_check_call_list = mock_check_call.call_args[0][0]
mock_check_call_list[1] = json.loads(mock_check_call_list[1])
self.assertEqual(
mock_check_call_list,
['session-manager-plugin',
self.execute_command_response["session"],
self.region,
'StartSession',
self.profile,
json.dumps(self.ssm_request_parameters),
self.endpoint_url],
)

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
@mock.patch('awscli.customizations.ecs.executecommand.check_output')
def test_execute_command_uses_env_var_with_new_plugin(
self, mock_check_output, mock_check_call
):
mock_check_call.return_value = 0
mock_check_output.return_value = "1.2.500.0\n"

self.client.execute_command.return_value = \
self.execute_command_response
self.client.describe_tasks.return_value = self.describe_tasks_response
ssm_env_name = "AWS_SSM_START_SESSION_RESPONSE"

rc = self.caller.invoke('ecs', 'ExecuteCommand',
self.execute_command_params, mock.Mock())

self.assertEqual(rc, 0)
mock_check_call_list = mock_check_call.call_args[0][0]
self.assertEqual(
mock_check_call_list,
[
'session-manager-plugin',
ssm_env_name,
self.region,
'StartSession',
self.profile,
json.dumps(self.ssm_request_parameters),
self.endpoint_url,
],
)
env = mock_check_call.call_args[1]["env"]
self.assertEqual(
env[ssm_env_name],
json.dumps(self.execute_command_response["session"]),
)