Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions kubernetes/base/stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
ERROR_CHANNEL = 3
RESIZE_CHANNEL = 4

V4_CHANNEL_PROTOCOL = "v4.channel.k8s.io"
V5_CHANNEL_PROTOCOL = "v5.channel.k8s.io"

class _IgnoredIO:
def write(self, _x):
pass
Expand All @@ -59,13 +62,16 @@ def __init__(self, configuration, url, headers, capture_all, binary=False):
"""
self._connected = False
self._channels = {}
self._closed_channels = set()
self.subprotocol = None
self.binary = binary
self.newline = '\n' if not self.binary else b'\n'
if capture_all:
self._all = StringIO() if not self.binary else BytesIO()
else:
self._all = _IgnoredIO()
self.sock = create_websocket(configuration, url, headers)
self.subprotocol = getattr(self.sock, 'subprotocol', None)
self._connected = True
self._returncode = None

Expand Down Expand Up @@ -93,6 +99,7 @@ def readline_channel(self, channel, timeout=None):
timeout = float("inf")
start = time.time()
while self.is_open() and time.time() - start < timeout:
# Always try to drain the channel first
if channel in self._channels:
data = self._channels[channel]
if self.newline in data:
Expand All @@ -104,6 +111,14 @@ def readline_channel(self, channel, timeout=None):
else:
del self._channels[channel]
return ret

if channel in self._closed_channels:
if channel in self._channels:
ret = self._channels[channel]
del self._channels[channel]
return ret
return b"" if self.binary else ""

self.update(timeout=(timeout - time.time() + start))

def write_channel(self, channel, data):
Expand All @@ -119,6 +134,14 @@ def write_channel(self, channel, data):
payload = channel_prefix + data
self.sock.send(payload, opcode=opcode)

def close_channel(self, channel):
"""Close a channel (v5 protocol only)."""
if self.subprotocol != V5_CHANNEL_PROTOCOL:
return
data = bytes([255, channel])
self.sock.send(data, opcode=ABNF.OPCODE_BINARY)
self._closed_channels.add(channel)

def peek_stdout(self, timeout=0):
"""Same as peek_channel with channel=1."""
return self.peek_channel(STDOUT_CHANNEL, timeout=timeout)
Expand Down Expand Up @@ -200,13 +223,24 @@ def update(self, timeout=0):
return
elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT:
data = frame.data
if six.PY3 and not self.binary:
data = data.decode("utf-8", "replace")
if len(data) > 1:
if len(data) > 0:
# Parse channel from raw bytes to support v5 CLOSE signal AND avoid charset issues
channel = data[0]
if six.PY3 and not self.binary:
channel = ord(channel)
# In Py3, iterating bytes gives int, but indexing bytes gives int.
# websocket-client frame.data might be bytes.

if channel == 255 and self.subprotocol == V5_CHANNEL_PROTOCOL: # v5 CLOSE
if len(data) > 1:
# data[1] is already int in Py3 bytes
close_chan = data[1]
self._closed_channels.add(close_chan)
return

data = data[1:]
# Decode data if expected text
if not self.binary:
data = data.decode("utf-8", "replace")

if data:
if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]:
# keeping all messages in the order they received
Expand Down Expand Up @@ -469,7 +503,7 @@ def create_websocket(configuration, url, headers=None):
header.append("sec-websocket-protocol: %s" %
headers['sec-websocket-protocol'])
else:
header.append("sec-websocket-protocol: v4.channel.k8s.io")
header.append("sec-websocket-protocol: %s,%s" % (V5_CHANNEL_PROTOCOL, V4_CHANNEL_PROTOCOL))

if url.startswith('wss://') and configuration.verify_ssl:
ssl_opts = {
Expand Down
116 changes: 115 additions & 1 deletion kubernetes/base/stream/ws_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@
# limitations under the License.

import unittest
from unittest.mock import MagicMock, patch

from .ws_client import get_websocket_url
from . import ws_client as ws_client_module
from .ws_client import get_websocket_url, WSClient, V5_CHANNEL_PROTOCOL, V4_CHANNEL_PROTOCOL
from .ws_client import websocket_proxycare
from kubernetes.client.configuration import Configuration
import os
import socket
import threading
import pytest
from kubernetes import stream, client, config
import websocket

try:
import urllib3
Expand Down Expand Up @@ -123,6 +126,117 @@ def test_websocket_proxycare(self):
assert dictval(connect_opts, 'http_proxy_auth') == expect_auth
assert dictval(connect_opts, 'http_no_proxy') == expect_noproxy


class WSClientProtocolTest(unittest.TestCase):
"""Tests for WSClient V5 protocol handling"""

def setUp(self):
# Mock configuration to avoid real connections in WSClient.__init__
self.config_mock = MagicMock()
self.config_mock.assert_hostname = False
self.config_mock.api_key = {}
self.config_mock.proxy = None
self.config_mock.ssl_ca_cert = None
self.config_mock.cert_file = None
self.config_mock.key_file = None
self.config_mock.verify_ssl = True

def test_create_websocket_header(self):
"""Verify sec-websocket-protocol header requests v5 first"""
# Patch WebSocket class in the module
with patch.object(ws_client_module, 'WebSocket', autospec=True) as mock_ws_cls:
mock_ws = mock_ws_cls.return_value

WSClient(self.config_mock, "ws://test", headers=None, capture_all=True)

mock_ws.connect.assert_called_once()
call_args = mock_ws.connect.call_args
# connect(url, **options)
# check kwargs for 'header'
kwargs = call_args[1]
self.assertIn('header', kwargs)
expected_header = f"sec-websocket-protocol: {V5_CHANNEL_PROTOCOL},{V4_CHANNEL_PROTOCOL}"
self.assertIn(expected_header, kwargs['header'])

def test_close_channel_v5(self):
"""Verify close_channel sends correct frame when v5 is negotiated"""
with patch.object(ws_client_module, 'create_websocket') as mock_create:
mock_ws = MagicMock()
mock_ws.subprotocol = V5_CHANNEL_PROTOCOL
mock_ws.connected = True
mock_create.return_value = mock_ws

client = WSClient(self.config_mock, "ws://test", headers=None, capture_all=True)
client.close_channel(0)

mock_ws.send.assert_called_with(b'\xff\x00', opcode=websocket.ABNF.OPCODE_BINARY)

def test_close_channel_v4(self):
"""Verify close_channel does nothing when v4 is negotiated"""
with patch.object(ws_client_module, 'create_websocket') as mock_create:
mock_ws = MagicMock()
mock_ws.subprotocol = V4_CHANNEL_PROTOCOL
mock_ws.connected = True
mock_create.return_value = mock_ws

client = WSClient(self.config_mock, "ws://test", headers=None, capture_all=True)
client.close_channel(0)

mock_ws.send.assert_not_called()

def test_update_receives_close_v5(self):
"""Verify update processes close signal when v5 is negotiated"""
with patch.object(ws_client_module, 'create_websocket') as mock_create, \
patch('select.select') as mock_select:

mock_ws = MagicMock()
mock_ws.subprotocol = V5_CHANNEL_PROTOCOL
mock_ws.connected = True
mock_ws.sock.fileno.return_value = 10

# Setup frame with close signal for channel 0
frame = MagicMock()
frame.data = b'\xff\x00'
mock_ws.recv_data_frame.return_value = (websocket.ABNF.OPCODE_BINARY, frame)

mock_create.return_value = mock_ws
# Make select return ready
mock_select.return_value = ([mock_ws.sock], [], [])

client = WSClient(self.config_mock, "ws://test", headers=None, capture_all=True)
client.update(timeout=0)

self.assertIn(0, client._closed_channels)

def test_update_ignores_close_signal_v4(self):
"""Verify update treats 0xFF as regular data (or ignores signal interpretation) when v4"""
with patch.object(ws_client_module, 'create_websocket') as mock_create, \
patch('select.select') as mock_select:

mock_ws = MagicMock()
mock_ws.subprotocol = V4_CHANNEL_PROTOCOL
mock_ws.connected = True
mock_ws.sock.fileno.return_value = 10

# Setup frame that looks like close signal but should be treated as data
frame = MagicMock()
frame.data = b'\xff\x00'
mock_ws.recv_data_frame.return_value = (websocket.ABNF.OPCODE_BINARY, frame)

mock_create.return_value = mock_ws
mock_select.return_value = ([mock_ws.sock], [], [])

client = WSClient(self.config_mock, "ws://test", headers=None, capture_all=True, binary=True) # binary=True to avoid decode errors
client.update(timeout=0)

# Should NOT be in closed channels
self.assertNotIn(0, client._closed_channels)
# Should be in data channels (channel 255 with data \x00)
# Code: channel = data[0] (255), data = data[1:] (\x00)
# if channel (255) not in _channels...
self.assertIn(255, client._channels)
self.assertEqual(client._channels[255], b'\x00')

@pytest.fixture(scope="module")
def dummy_proxy():
#Dummy Proxy
Expand Down
57 changes: 57 additions & 0 deletions kubernetes/e2e_test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,63 @@ def test_pod_apis(self):
resp = api.delete_namespaced_pod(name=name, body={},
namespace='default')

def test_pod_exec_close_channel(self):
"""Test sending CLOSE signal for a channel (v5 protocol)."""
client = api_client.ApiClient(configuration=self.config)
api = core_v1_api.CoreV1Api(client)

name = 'busybox-test-' + short_uuid()
pod_manifest = manifest_with_command(
name, "while true;do date;sleep 5; done")

resp = api.create_namespaced_pod(body=pod_manifest, namespace='default')
self.assertEqual(name, resp.metadata.name)

# Wait for pod to be running
timeout = time.time() + 60
while True:
resp = api.read_namespaced_pod(name=name, namespace='default')
if resp.status.phase == 'Running':
break
if time.time() > timeout:
self.fail("Timeout waiting for pod to be running")
time.sleep(1)

# Use cat to echo stdin to stdout.
# When stdin is closed, cat should exit, terminating the command.
resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
command=['/bin/sh', '-c', 'cat'],
stderr=True, stdin=True,
stdout=True, tty=False,
_preload_content=False)

if resp.subprotocol != "v5.channel.k8s.io":
resp.close()
api.delete_namespaced_pod(name=name, body={}, namespace='default')
self.skipTest("Skipping test: v5.channel.k8s.io subprotocol not negotiated")

try:
resp.write_stdin("test-close\n")
line = resp.readline_stdout(timeout=5)
self.assertEqual("test-close", line)

# Close stdin (channel 0)
# This should send EOF to cat, causing it to exit.
resp.close_channel(0)

# Wait for process to exit
resp.update(timeout=5)
start = time.time()
while resp.is_open() and time.time() - start < 10:
resp.update(timeout=1)

self.assertFalse(resp.is_open(), "Connection should close after cat exits")
self.assertEqual(resp.returncode, 0)
finally:
if resp.is_open():
resp.close()
api.delete_namespaced_pod(name=name, body={}, namespace='default')

def test_exit_code(self):
client = api_client.ApiClient(configuration=self.config)
api = core_v1_api.CoreV1Api(client)
Expand Down