Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/azure-cli/azure/cli/command_modules/appservice/_help.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,8 @@
examples:
- name: Create a remote connection using a tcp tunnel to your web app
text: az webapp create-remote-connection --name MyWebApp --resource-group MyResourceGroup
- name: Create a remote connection to a specific instance
text: az webapp create-remote-connection --name MyWebApp --resource-group MyResourceGroup --instance 89c07485c4742abcde3f0e19ea4402a06e3b48145ed81e6468066f10a78074b1
"""

helps['webapp delete'] = """
Expand Down Expand Up @@ -2394,6 +2396,9 @@
- name: ssh into a web app
text: >
az webapp ssh -n MyUniqueAppName -g MyResourceGroup
- name: ssh into a specific instance of a web app
text: >
az webapp ssh -n MyUniqueAppName -g MyResourceGroup --instance 89c07485c4742abcde3f0e19ea4402a06e3b48145ed81e6468066f10a78074b1
"""

helps['webapp start'] = """
Expand Down
16 changes: 12 additions & 4 deletions src/azure-cli/azure/cli/command_modules/appservice/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,14 +977,22 @@ def load_arguments(self, _):
with self.argument_context('webapp ssh') as c:
c.argument('port', options_list=['--port', '-p'],
help='Port for the remote connection. Default: Random available port', type=int)
c.argument('timeout', options_list=['--timeout', '-t'], help='timeout in seconds. Defaults to none', type=int)
c.argument('instance', options_list=['--instance', '-i'], help='Webapp instance to connect to. Defaults to none.')
c.argument('timeout', options_list=['--timeout', '-t'],
help='Timeout in seconds. The tunnel will automatically close after this duration. '
'Defaults to none (keep open until manually closed).', type=int)
c.argument('instance', options_list=['--instance', '-i'],
help='Webapp instance to connect to. Use `az webapp list-instances` to get available instances. '
'If not specified, connects to an arbitrary instance.')

with self.argument_context('webapp create-remote-connection') as c:
c.argument('port', options_list=['--port', '-p'],
help='Port for the remote connection. Default: Random available port', type=int)
c.argument('timeout', options_list=['--timeout', '-t'], help='timeout in seconds. Defaults to none', type=int)
c.argument('instance', options_list=['--instance', '-i'], help='Webapp instance to connect to. Defaults to none.')
c.argument('timeout', options_list=['--timeout', '-t'],
help='Timeout in seconds. The tunnel will automatically close after this duration. '
'Defaults to none (keep open until manually closed).', type=int)
c.argument('instance', options_list=['--instance', '-i'],
help='Webapp instance to connect to. Use `az webapp list-instances` to get available instances. '
'If not specified, connects to an arbitrary instance.')

with self.argument_context('webapp vnet-integration') as c:
c.argument('name', arg_type=webapp_name_arg_type, id_part=None)
Expand Down
52 changes: 42 additions & 10 deletions src/azure-cli/azure/cli/command_modules/appservice/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -9407,6 +9407,8 @@ def get_tunnel(cmd, resource_group_name, name, port=None, slot=None, instance=No
def create_tunnel(cmd, resource_group_name, name, port=None, slot=None, timeout=None, instance=None):
tunnel_server = get_tunnel(cmd, resource_group_name, name, port, slot, instance)

_register_tunnel_cleanup(tunnel_server)

t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
t.daemon = True
t.start()
Expand All @@ -9425,16 +9427,23 @@ def create_tunnel(cmd, resource_group_name, name, port=None, slot=None, timeout=

logger.warning('Ctrl + C to close')

if timeout:
time.sleep(int(timeout))
else:
while t.is_alive():
time.sleep(5)
try:
if timeout:
time.sleep(int(timeout))
else:
while t.is_alive():
time.sleep(5)
except KeyboardInterrupt:
logger.warning('Shutting down tunnel...')
finally:
tunnel_server.close()


def create_tunnel_and_session(cmd, resource_group_name, name, port=None, slot=None, timeout=None, instance=None):
tunnel_server = get_tunnel(cmd, resource_group_name, name, port, slot, instance)

_register_tunnel_cleanup(tunnel_server)

t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
t.daemon = True
t.start()
Expand All @@ -9447,11 +9456,16 @@ def create_tunnel_and_session(cmd, resource_group_name, name, port=None, slot=No
s.daemon = True
s.start()

if timeout:
time.sleep(int(timeout))
else:
while s.is_alive() and t.is_alive():
time.sleep(5)
try:
if timeout:
time.sleep(int(timeout))
else:
while s.is_alive() and t.is_alive():
time.sleep(5)
except KeyboardInterrupt:
logger.warning('Shutting down tunnel...')
finally:
tunnel_server.close()


def perform_onedeploy_functionapp(cmd,
Expand Down Expand Up @@ -9918,6 +9932,24 @@ def _start_tunnel(tunnel_server):
tunnel_server.start_server()


def _register_tunnel_cleanup(tunnel_server):
"""Register signal handlers and atexit to ensure the tunnel is cleaned up."""
import atexit
import signal

def _cleanup():
tunnel_server.close()

atexit.register(_cleanup)

def _signal_handler(signum, frame): # pylint: disable=unused-argument
logger.warning('Received signal %s, shutting down tunnel...', signum)
tunnel_server.close()

signal.signal(signal.SIGINT, _signal_handler)
signal.signal(signal.SIGTERM, _signal_handler)


def _start_ssh_session(hostname, port, username, password):
tries = 0
while True:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
update_app_settings,
update_application_settings_polling,
update_webapp)
from azure.cli.command_modules.appservice.tunnel import TunnelServer

# pylint: disable=line-too-long
from azure.cli.core.profiles import ResourceType
Expand Down Expand Up @@ -639,6 +640,174 @@ def test_update_webapp_platform_release_channel_latest(self):
self.assertEqual(result.additional_properties["properties"]["platformReleaseChannel"], "Latest")


class TestTunnelServer(unittest.TestCase):
"""Tests for TunnelServer reliability and cleanup improvements."""

@mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket')
def test_tunnel_server_close_sets_closing_event(self, mock_socket_cls):
mock_sock = mock.MagicMock()
mock_socket_cls.return_value = mock_sock
mock_sock.getsockname.return_value = ('127.0.0.1', 12345)
server = TunnelServer.__new__(TunnelServer)
server.local_addr = '127.0.0.1'
server.local_port = 0
server.remote_addr = 'testapp.scm.azurewebsites.net'
server.auth_string = 'Basic dGVzdDp0ZXN0'
server.instance = None
server.client = None
server.ws = None
from threading import Event
server._closing = Event()
server.sock = mock_sock

server.close()
self.assertTrue(server._closing.is_set())
mock_sock.close.assert_called_once()

@mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket')
def test_tunnel_server_close_is_idempotent(self, mock_socket_cls):
mock_sock = mock.MagicMock()
mock_socket_cls.return_value = mock_sock
mock_sock.getsockname.return_value = ('127.0.0.1', 12345)
server = TunnelServer.__new__(TunnelServer)
server.local_addr = '127.0.0.1'
server.local_port = 0
server.remote_addr = 'testapp.scm.azurewebsites.net'
server.auth_string = 'Basic dGVzdDp0ZXN0'
server.instance = None
server.client = None
server.ws = None
from threading import Event
server._closing = Event()
server.sock = mock_sock

server.close()
server.close()
# Socket.close should only be called once
mock_sock.close.assert_called_once()

@mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket')
def test_tunnel_server_close_handles_ws_and_client(self, mock_socket_cls):
mock_sock = mock.MagicMock()
mock_socket_cls.return_value = mock_sock
mock_sock.getsockname.return_value = ('127.0.0.1', 12345)
server = TunnelServer.__new__(TunnelServer)
server.local_addr = '127.0.0.1'
server.local_port = 0
server.remote_addr = 'testapp.scm.azurewebsites.net'
server.auth_string = 'Basic dGVzdDp0ZXN0'
server.instance = None
server.client = mock.MagicMock()
server.ws = mock.MagicMock()
from threading import Event
server._closing = Event()
server.sock = mock_sock

server.close()
server.ws.close.assert_called_once()
server.client.close.assert_called_once()
mock_sock.close.assert_called_once()

@mock.patch('azure.cli.command_modules.appservice.tunnel.create_connection')
@mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket')
def test_create_websocket_connection_retries_on_failure(self, mock_socket_cls, mock_create_conn):
mock_sock = mock.MagicMock()
mock_socket_cls.return_value = mock_sock
mock_sock.getsockname.return_value = ('127.0.0.1', 12345)
server = TunnelServer.__new__(TunnelServer)
server.local_addr = '127.0.0.1'
server.local_port = 0
server.remote_addr = 'testapp.scm.azurewebsites.net'
server.auth_string = 'Basic dGVzdDp0ZXN0'
server.instance = None
server.client = None
server.ws = None
from threading import Event
server._closing = Event()
server.sock = mock_sock

mock_ws = mock.MagicMock()
# Fail twice, succeed on third
mock_create_conn.side_effect = [ConnectionError("fail1"), ConnectionError("fail2"), mock_ws]

with mock.patch.object(server._closing, 'wait'):
result = server._create_websocket_connection(
'wss://test/Tunnel.ashx', ['Authorization: Basic dGVzdDp0ZXN0'], 0)

self.assertEqual(result, mock_ws)
self.assertEqual(mock_create_conn.call_count, 3)

@mock.patch('azure.cli.command_modules.appservice.tunnel.create_connection')
@mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket')
def test_create_websocket_connection_raises_after_max_retries(self, mock_socket_cls, mock_create_conn):
mock_sock = mock.MagicMock()
mock_socket_cls.return_value = mock_sock
mock_sock.getsockname.return_value = ('127.0.0.1', 12345)
server = TunnelServer.__new__(TunnelServer)
server.local_addr = '127.0.0.1'
server.local_port = 0
server.remote_addr = 'testapp.scm.azurewebsites.net'
server.auth_string = 'Basic dGVzdDp0ZXN0'
server.instance = None
server.client = None
server.ws = None
from threading import Event
server._closing = Event()
server.sock = mock_sock

mock_create_conn.side_effect = ConnectionError("always fail")

with mock.patch.object(server._closing, 'wait'):
with self.assertRaises(CLIError) as ctx:
server._create_websocket_connection(
'wss://test/Tunnel.ashx', ['Authorization: Basic dGVzdDp0ZXN0'], 0)

self.assertIn('Failed to establish WebSocket tunnel connection', str(ctx.exception))

@mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket')
def test_keepalive_ping_stops_on_event(self, mock_socket_cls):
from threading import Event
mock_sock = mock.MagicMock()
mock_socket_cls.return_value = mock_sock
mock_sock.getsockname.return_value = ('127.0.0.1', 12345)
server = TunnelServer.__new__(TunnelServer)
server.local_addr = '127.0.0.1'
server.local_port = 0
server.remote_addr = 'testapp.scm.azurewebsites.net'
server.auth_string = 'Basic dGVzdDp0ZXN0'
server.instance = None
server.client = None
server.ws = None
server._closing = Event()
server.sock = mock_sock

mock_ws = mock.MagicMock()
mock_ws.connected = True
stop_event = Event()
# Signal stop immediately so the keepalive loop runs once at most
stop_event.set()
server._send_keepalive_pings(mock_ws, 1, stop_event)
# Should not crash; ws.ping may or may not have been called depending on timing


class TestTunnelSignalCleanup(unittest.TestCase):
"""Tests for signal handler registration and cleanup in create_tunnel / create_tunnel_and_session."""

@mock.patch('signal.signal')
@mock.patch('atexit.register')
def test_register_tunnel_cleanup_registers_handlers(self, mock_atexit, mock_signal):
from azure.cli.command_modules.appservice.custom import _register_tunnel_cleanup
import signal

mock_tunnel = mock.MagicMock()
_register_tunnel_cleanup(mock_tunnel)

mock_atexit.assert_called_once()
signal_calls = {call[0][0] for call in mock_signal.call_args_list}
self.assertIn(signal.SIGINT, signal_calls)
self.assertIn(signal.SIGTERM, signal_calls)


class FakedResponse: # pylint: disable=too-few-public-methods
def __init__(self, status_code):
self.status_code = status_code
Expand Down
Loading
Loading