Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/10744.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved performance of the WebSocket reader with large messages -- by :user:`bdraco`.
25 changes: 14 additions & 11 deletions aiohttp/_websocket/reader_c.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ cdef class WebSocketReader:
cdef int _opcode
cdef bint _frame_fin
cdef int _frame_opcode
cdef object _frame_payload
cdef unsigned long long _frame_payload_len
cdef list _payload_fragments
cdef Py_ssize_t _frame_payload_len

cdef bytes _tail
cdef bint _has_mask
cdef bytes _frame_mask
cdef unsigned long long _payload_length
cdef unsigned int _payload_length_flag
cdef Py_ssize_t _payload_bytes_to_read
cdef unsigned int _payload_len_flag
cdef int _compressed
cdef object _decompressobj
cdef bint _compress
Expand All @@ -97,17 +97,20 @@ cdef class WebSocketReader:
cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *

@cython.locals(
start_pos="unsigned int",
data_len="unsigned int",
length="unsigned int",
chunk_size="unsigned int",
chunk_len="unsigned int",
data_length="unsigned int",
start_pos=Py_ssize_t,
data_len=Py_ssize_t,
length=Py_ssize_t,
chunk_size=Py_ssize_t,
chunk_len=Py_ssize_t,
data_len=Py_ssize_t,
data_cstr="const unsigned char *",
first_byte="unsigned char",
second_byte="unsigned char",
end_pos="unsigned int",
f_start_pos=Py_ssize_t,
f_end_pos=Py_ssize_t,
has_mask=bint,
fin=bint,
had_fragments=Py_ssize_t,
payload_bytearray=bytearray,
)
cpdef void _feed_data(self, bytes data) except *
103 changes: 57 additions & 46 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,14 @@ def __init__(
self._opcode: int = OP_CODE_NOT_SET
self._frame_fin = False
self._frame_opcode: int = OP_CODE_NOT_SET
self._frame_payload: Union[bytes, bytearray] = b""
self._payload_fragments: list[bytes] = []
self._frame_payload_len = 0

self._tail: bytes = b""
self._has_mask = False
self._frame_mask: Optional[bytes] = None
self._payload_length = 0
self._payload_length_flag = 0
self._payload_bytes_to_read = 0
self._payload_len_flag = 0
self._compressed: int = COMPRESSED_NOT_SET
self._decompressobj: Optional[ZLibDecompressor] = None
self._compress = compress
Expand Down Expand Up @@ -336,13 +336,13 @@ def _feed_data(self, data: bytes) -> None:
data, self._tail = self._tail + data, b""

start_pos: int = 0
data_length = len(data)
data_len = len(data)
data_cstr = data

while True:
# read header
if self._state == READ_HEADER:
if data_length - start_pos < 2:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
Expand Down Expand Up @@ -401,77 +401,88 @@ def _feed_data(self, data: bytes) -> None:
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_length_flag = length
self._payload_len_flag = length
self._state = READ_PAYLOAD_LENGTH

# read payload length
if self._state == READ_PAYLOAD_LENGTH:
length_flag = self._payload_length_flag
if length_flag == 126:
if data_length - start_pos < 2:
len_flag = self._payload_len_flag
if len_flag == 126:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
self._payload_length = first_byte << 8 | second_byte
elif length_flag > 126:
if data_length - start_pos < 8:
self._payload_bytes_to_read = first_byte << 8 | second_byte
elif len_flag > 126:
if data_len - start_pos < 8:
break
self._payload_length = UNPACK_LEN3(data, start_pos)[0]
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
start_pos += 8
else:
self._payload_length = length_flag
self._payload_bytes_to_read = len_flag

self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

# read payload mask
if self._state == READ_PAYLOAD_MASK:
if data_length - start_pos < 4:
if data_len - start_pos < 4:
break
self._frame_mask = data_cstr[start_pos : start_pos + 4]
start_pos += 4
self._state = READ_PAYLOAD

if self._state == READ_PAYLOAD:
chunk_len = data_length - start_pos
if self._payload_length >= chunk_len:
end_pos = data_length
self._payload_length -= chunk_len
chunk_len = data_len - start_pos
if self._payload_bytes_to_read >= chunk_len:
f_end_pos = data_len
self._payload_bytes_to_read -= chunk_len
else:
end_pos = start_pos + self._payload_length
self._payload_length = 0

if self._frame_payload_len:
if type(self._frame_payload) is not bytearray:
self._frame_payload = bytearray(self._frame_payload)
self._frame_payload += data_cstr[start_pos:end_pos]
else:
# Fast path for the first frame
self._frame_payload = data_cstr[start_pos:end_pos]

self._frame_payload_len += end_pos - start_pos
start_pos = end_pos

if self._payload_length != 0:
f_end_pos = start_pos + self._payload_bytes_to_read
self._payload_bytes_to_read = 0

had_fragments = self._frame_payload_len
self._frame_payload_len += f_end_pos - start_pos
f_start_pos = start_pos
start_pos = f_end_pos

if self._payload_bytes_to_read != 0:
# If we don't have a complete frame, we need to save the
# data for the next call to feed_data.
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
break

if self._has_mask:
payload: Union[bytes, bytearray]
if had_fragments:
# We have to join the payload fragments get the payload
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
if self._has_mask:
assert self._frame_mask is not None
payload_bytearray = bytearray()
payload_bytearray.join(self._payload_fragments)
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = b"".join(self._payload_fragments)
self._payload_fragments.clear()
elif self._has_mask:
assert self._frame_mask is not None
if type(self._frame_payload) is not bytearray:
self._frame_payload = bytearray(self._frame_payload)
websocket_mask(self._frame_mask, self._frame_payload)
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
if type(payload_bytearray) is not bytearray: # pragma: no branch
# Cython will do the conversion for us
# but we need to do it for Python and we
# will always get here in Python
payload_bytearray = bytearray(payload_bytearray)
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = data_cstr[f_start_pos:f_end_pos]

self._handle_frame(
self._frame_fin,
self._frame_opcode,
self._frame_payload,
self._compressed,
self._frame_fin, self._frame_opcode, payload, self._compressed
)
self._frame_payload = b""
self._frame_payload_len = 0
self._state = READ_HEADER

# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
self._tail = (
data_cstr[start_pos:data_length] if start_pos < data_length else b""
)
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
66 changes: 66 additions & 0 deletions tests/test_benchmarks_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,69 @@ async def run_websocket_benchmark() -> None:
@benchmark
def _run() -> None:
loop.run_until_complete(run_websocket_benchmark())


@pytest.mark.usefixtures("parametrize_zlib_backend")
def test_client_send_large_websocket_compressed_messages(
loop: asyncio.AbstractEventLoop,
aiohttp_client: AiohttpClient,
benchmark: BenchmarkFixture,
) -> None:
"""Benchmark send of compressed WebSocket binary messages."""
message_count = 10
raw_message = b"x" * 2**19 # 512 KiB

async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
for _ in range(message_count):
await ws.receive()
await ws.close()
return ws

app = web.Application()
app.router.add_route("GET", "/", handler)

async def run_websocket_benchmark() -> None:
client = await aiohttp_client(app)
resp = await client.ws_connect("/", compress=15)
for _ in range(message_count):
await resp.send_bytes(raw_message)
await resp.close()

@benchmark
def _run() -> None:
loop.run_until_complete(run_websocket_benchmark())


@pytest.mark.usefixtures("parametrize_zlib_backend")
def test_client_receive_large_websocket_compressed_messages(
loop: asyncio.AbstractEventLoop,
aiohttp_client: AiohttpClient,
benchmark: BenchmarkFixture,
) -> None:
"""Benchmark receive of compressed WebSocket binary messages."""
message_count = 10
raw_message = b"x" * 2**19 # 512 KiB

async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
for _ in range(message_count):
await ws.send_bytes(raw_message)
await ws.close()
return ws

app = web.Application()
app.router.add_route("GET", "/", handler)

async def run_websocket_benchmark() -> None:
client = await aiohttp_client(app)
resp = await client.ws_connect("/", compress=15)
for _ in range(message_count):
await resp.receive()
await resp.close()

@benchmark
def _run() -> None:
loop.run_until_complete(run_websocket_benchmark())
Loading