Skip to content

Commit

Permalink
Add customization to SessionManager StartSession
Browse files Browse the repository at this point in the history
  • Loading branch information
Rangaraju authored and kyleknap committed Sep 11, 2018
1 parent dd7dbac commit 08cdb17
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 1 deletion.
21 changes: 21 additions & 0 deletions awscli/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import os
import platform
import zipfile
import signal
import contextlib

from botocore.compat import six
#import botocore.compat
Expand Down Expand Up @@ -318,3 +320,22 @@ def get_popen_kwargs_for_pager_cmd(pager_cmd=None):
pager_cmd = shlex.split(pager_cmd)
popen_kwargs['args'] = pager_cmd
return popen_kwargs


@contextlib.contextmanager
def ignore_user_entered_signals():
"""
Ignores user entered signals to avoid process getting killed.
"""
if is_windows:
signal_list = [signal.SIGINT]
else:
signal_list = [signal.SIGINT, signal.SIGQUIT, signal.SIGTSTP]
actual_signals = []
for user_signal in signal_list:
actual_signals.append(signal.signal(user_signal, signal.SIG_IGN))
try:
yield
finally:
for sig, user_signal in enumerate(signal_list):
signal.signal(user_signal, actual_signals[sig])
93 changes: 93 additions & 0 deletions awscli/customizations/sessionmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import logging
import json
import errno

from subprocess import check_call
from awscli.compat import ignore_user_entered_signals
from awscli.clidriver import ServiceOperation, CLIOperationCaller

logger = logging.getLogger(__name__)

ERROR_MESSAGE = (
'SessionManagerPlugin is not found. ',
'Please refer to SessionManager Documentation here: ',
'http://docs.aws.amazon.com/console/systems-manager/',
'session-manager-plugin-not-found'
)


def register_ssm_session(event_handlers):
event_handlers.register('building-command-table.ssm',
add_custom_start_session)


def add_custom_start_session(session, command_table, **kwargs):
command_table['start-session'] = StartSessionCommand(
name='start-session',
parent_name='ssm',
session=session,
operation_model=session.get_service_model(
'ssm').operation_model('StartSession'),
operation_caller=StartSessionCaller(session),
)


class StartSessionCommand(ServiceOperation):

def create_help_command(self):
help_command = super(
StartSessionCommand, self).create_help_command()
# Change the output shape because the command provides no output.
self._operation_model.output_shape = None
return help_command


class StartSessionCaller(CLIOperationCaller):
def invoke(self, service_name, operation_name, parameters,
parsed_globals):
client = self._session.create_client(
service_name, region_name=parsed_globals.region,
endpoint_url=parsed_globals.endpoint_url,
verify=parsed_globals.verify_ssl)
response = client.start_session(**parameters)
session_id = response['SessionId']
region_name = client.meta.region_name

try:
# ignore_user_entered_signals ignores these signals
# because if signals which kills the process are not
# captured would kill the foreground process but not the
# background one. Capturing these would prevents process
# from getting killed and these signals are input to plugin
# and handling in there
with ignore_user_entered_signals():
# call executable with necessary input
check_call(["session-manager-plugin",
json.dumps(response),
region_name,
"StartSession"])
return 0
except OSError as ex:
if ex.errno == errno.ENOENT:
logger.debug('SessionManagerPlugin is not present',
exc_info=True)
# start-session api call returns response and starts the
# session on ssm-agent and response is forwarded to
# session-manager-plugin. If plugin is not present, terminate
# is called so that service and ssm-agent terminates the
# session to avoid zombie session active on ssm-agent for
# default self terminate time
client.terminate_session(SessionId=session_id)
raise ValueError(''.join(ERROR_MESSAGE))
2 changes: 2 additions & 0 deletions awscli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from awscli.customizations.sagemaker import register_alias_sagemaker_runtime_command
from awscli.customizations.servicecatalog import register_servicecatalog_commands
from awscli.customizations.s3events import register_event_stream_arg
from awscli.customizations.sessionmanager import register_ssm_session


def awscli_initialize(event_handlers):
Expand Down Expand Up @@ -166,3 +167,4 @@ def awscli_initialize(event_handlers):
register_history_commands(event_handlers)
register_event_stream_arg(event_handlers)
dlm_initialize(event_handlers)
register_ssm_session(event_handlers)
12 changes: 12 additions & 0 deletions tests/functional/ssm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
65 changes: 65 additions & 0 deletions tests/functional/ssm/test_start_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import mock
import errno
import json

from awscli.testutils import BaseAWSCommandParamsTest
from awscli.testutils import BaseAWSHelpOutputTest


class TestSessionManager(BaseAWSCommandParamsTest):

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_success(self, mock_check_call):
cmdline = 'ssm start-session --target instance-id'
mock_check_call.return_value = 0
self.parsed_responses = [{
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"
}]
self.run_cmd(cmdline, expected_rc=0)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
self.assertEqual(self.operations_called[0][1],
{'Target': 'instance-id'})
actual_response = json.loads(mock_check_call.call_args[0][0][1])
self.assertEqual(
{"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"},
actual_response)

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_fails(self, mock_check_call):
cmdline = 'ssm start-session --target instance-id'
mock_check_call.side_effect = OSError(errno.ENOENT, 'some error')
self.parsed_responses = [{
"SessionId": "session-id"
}]
self.run_cmd(cmdline, expected_rc=255)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
self.assertEqual(self.operations_called[0][1],
{'Target': 'instance-id'})
self.assertEqual(self.operations_called[1][0].name,
'TerminateSession')
self.assertEqual(self.operations_called[1][1],
{'SessionId': 'session-id'})


class TestHelpOutput(BaseAWSHelpOutputTest):
def test_start_session_output(self):
self.driver.main(['ssm', 'start-session', 'help'])
self.assert_contains('Output\n======\n\nNone')
105 changes: 105 additions & 0 deletions tests/unit/customizations/test_sessionmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import mock
import errno
import json
import botocore.session

from awscli.customizations import sessionmanager
from awscli.testutils import unittest


class TestSessionManager(unittest.TestCase):

def setUp(self):
self.session = mock.Mock(botocore.session.Session)
self.client = mock.Mock()
self.region = 'us-west-2'
self.client.meta.region_name = self.region
self.session.create_client.return_value = self.client
self.caller = sessionmanager.StartSessionCaller(self.session)

def test_start_session_when_non_custom_start_session_fails(self):
self.client.start_session.side_effect = Exception('some exception')
params = {}
with self.assertRaisesRegexp(Exception, 'some exception'):
self.caller.invoke('ssm', 'StartSession', params, mock.Mock())

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_success_scenario(self, mock_check_call):
mock_check_call.return_value = 0

start_session_params = {
"Target": "i-123456789"
}

start_session_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"
}

self.client.start_session.return_value = start_session_response

rc = self.caller.invoke('ssm', 'StartSession',
start_session_params, mock.Mock())
self.assertEquals(rc, 0)
self.client.start_session.assert_called_with(**start_session_params)
mock_check_call_list = mock_check_call.call_args[0][0]
mock_check_call_list[1] = json.loads(mock_check_call_list[1])
self.assertEqual(
mock_check_call_list,
['session-manager-plugin',
start_session_response,
self.region,
'StartSession']
)

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_when_check_call_fails(self, mock_check_call):
mock_check_call.side_effect = OSError(errno.ENOENT, 'some error')

start_session_params = {
"Target": "i-123456789"
}

start_session_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"
}

terminate_session_params = {
"SessionId": "session-id"
}

self.client.start_session.return_value = start_session_response

with self.assertRaises(ValueError):
self.caller.invoke('ssm', 'StartSession',
start_session_params, mock.Mock())

self.client.start_session.assert_called_with(
**start_session_params)
self.client.terminate_session.assert_called_with(
**terminate_session_params)

mock_check_call_list = mock_check_call.call_args[0][0]
mock_check_call_list[1] = json.loads(mock_check_call_list[1])
self.assertEqual(
mock_check_call_list,
['session-manager-plugin',
start_session_response,
self.region,
'StartSession']
)
29 changes: 28 additions & 1 deletion tests/unit/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import os
import signal

from nose.tools import assert_equal
from botocore.compat import six

from awscli.compat import ensure_text_type
from awscli.compat import compat_shell_quote
from awscli.compat import get_popen_kwargs_for_pager_cmd
from awscli.testutils import mock, unittest
from awscli.compat import ignore_user_entered_signals
from awscli.testutils import mock, unittest, skip_if_windows


class TestEnsureText(unittest.TestCase):
Expand Down Expand Up @@ -110,3 +114,26 @@ def test_non_windows(self):
def test_non_windows_specific_pager(self):
kwargs = get_popen_kwargs_for_pager_cmd('more')
self.assertEqual({'args': ['more']}, kwargs)


class TestIgnoreUserSignals(unittest.TestCase):
@skip_if_windows("These signals are not supported for windows")
def test_ignore_signal_sigint(self):
with ignore_user_entered_signals():
try:
os.kill(os.getpid(), signal.SIGINT)
except KeyboardInterrupt:
self.fail('The ignore_user_entered_signals context '
'manager should have ignored')

@skip_if_windows("These signals are not supported for windows")
def test_ignore_signal_sigquit(self):
with ignore_user_entered_signals():
self.assertEqual(signal.getsignal(signal.SIGQUIT), signal.SIG_IGN)
os.kill(os.getpid(), signal.SIGQUIT)

@skip_if_windows("These signals are not supported for windows")
def test_ignore_signal_sigtstp(self):
with ignore_user_entered_signals():
self.assertEqual(signal.getsignal(signal.SIGTSTP), signal.SIG_IGN)
os.kill(os.getpid(), signal.SIGTSTP)

0 comments on commit 08cdb17

Please sign in to comment.