Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/10988.bugfix.rst
4 changes: 4 additions & 0 deletions CHANGES/2914.doc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Improved documentation for middleware by adding warnings and examples about
request body stream consumption. The documentation now clearly explains that
request body streams can only be read once and provides best practices for
sharing parsed request data between middleware and handlers -- by :user:`bdraco`.
1 change: 1 addition & 0 deletions CHANGES/6009.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed :py:attr:`~aiohttp.web.WebSocketResponse.prepared` property to correctly reflect the prepared state, especially during timeout scenarios -- by :user:`bdraco`
4 changes: 4 additions & 0 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ def can_prepare(self, request: BaseRequest) -> WebSocketReady:
else:
return WebSocketReady(True, protocol)

@property
def prepared(self) -> bool:
return self._writer is not None

@property
def closed(self) -> bool:
return self._closed
Expand Down
97 changes: 92 additions & 5 deletions docs/web_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,13 @@ A *middleware* is a coroutine that can modify either the request or
response. For example, here's a simple *middleware* which appends
``' wink'`` to the response::

from aiohttp.web import middleware
from aiohttp import web
from typing import Callable, Awaitable

async def middleware(request, handler):
async def middleware(
request: web.Request,
handler: Callable[[web.Request], Awaitable[web.StreamResponse]]
) -> web.StreamResponse:
resp = await handler(request)
resp.text = resp.text + ' wink'
return resp
Expand Down Expand Up @@ -619,18 +623,25 @@ post-processing like handling *CORS* and so on.
The following code demonstrates middlewares execution order::

from aiohttp import web
from typing import Callable, Awaitable

async def test(request):
async def test(request: web.Request) -> web.Response:
print('Handler function called')
return web.Response(text="Hello")

async def middleware1(request, handler):
async def middleware1(
request: web.Request,
handler: Callable[[web.Request], Awaitable[web.StreamResponse]]
) -> web.StreamResponse:
print('Middleware 1 called')
response = await handler(request)
print('Middleware 1 finished')
return response

async def middleware2(request, handler):
async def middleware2(
request: web.Request,
handler: Callable[[web.Request], Awaitable[web.StreamResponse]]
) -> web.StreamResponse:
print('Middleware 2 called')
response = await handler(request)
print('Middleware 2 finished')
Expand All @@ -649,6 +660,82 @@ Produced output::
Middleware 2 finished
Middleware 1 finished

Request Body Stream Consumption
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. warning::

When middleware reads the request body (using :meth:`~aiohttp.web.BaseRequest.read`,
:meth:`~aiohttp.web.BaseRequest.text`, :meth:`~aiohttp.web.BaseRequest.json`, or
:meth:`~aiohttp.web.BaseRequest.post`), the body stream is consumed. However, these
high-level methods cache their result, so subsequent calls from the handler or other
middleware will return the same cached value.

The important distinction is:

- High-level methods (:meth:`~aiohttp.web.BaseRequest.read`, :meth:`~aiohttp.web.BaseRequest.text`,
:meth:`~aiohttp.web.BaseRequest.json`, :meth:`~aiohttp.web.BaseRequest.post`) cache their
results internally, so they can be called multiple times and will return the same value.
- Direct stream access via :attr:`~aiohttp.web.BaseRequest.content` does NOT have this
caching behavior. Once you read from ``request.content`` directly (e.g., using
``await request.content.read()``), subsequent reads will return empty bytes.

Consider this middleware that logs request bodies::

from aiohttp import web
from typing import Callable, Awaitable

async def logging_middleware(
request: web.Request,
handler: Callable[[web.Request], Awaitable[web.StreamResponse]]
) -> web.StreamResponse:
# This consumes the request body stream
body = await request.text()
print(f"Request body: {body}")
return await handler(request)

async def handler(request: web.Request) -> web.Response:
# This will return the same value that was read in the middleware
# (i.e., the cached result, not an empty string)
body = await request.text()
return web.Response(text=f"Received: {body}")

In contrast, when accessing the stream directly (not recommended in middleware)::

async def stream_middleware(
request: web.Request,
handler: Callable[[web.Request], Awaitable[web.StreamResponse]]
) -> web.StreamResponse:
# Reading directly from the stream - this consumes it!
data = await request.content.read()
print(f"Stream data: {data}")
return await handler(request)

async def handler(request: web.Request) -> web.Response:
# This will return empty bytes because the stream was already consumed
data = await request.content.read()
# data will be b'' (empty bytes)

# However, high-level methods would still work if called for the first time:
# body = await request.text() # This would read from internal cache if available
return web.Response(text=f"Received: {data}")

When working with raw stream data that needs to be shared between middleware and handlers::

async def stream_parsing_middleware(
request: web.Request,
handler: Callable[[web.Request], Awaitable[web.StreamResponse]]
) -> web.StreamResponse:
# Read stream once and store the data
raw_data = await request.content.read()
request['raw_body'] = raw_data
return await handler(request)

async def handler(request: web.Request) -> web.Response:
# Access the stored data instead of reading the stream again
raw_data = request.get('raw_body', b'')
return web.Response(body=raw_data)

Example
^^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,11 @@ and :ref:`aiohttp-web-signals` handlers::
of closing.
:const:`~aiohttp.WSMsgType.CLOSE` message has been received from peer.

.. attribute:: prepared

Read-only :class:`bool` property, ``True`` if :meth:`prepare` has
been called, ``False`` otherwise.

.. attribute:: close_code

Read-only property, close code from peer. It is set to ``None`` on
Expand Down
3 changes: 3 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3408,6 +3408,9 @@ async def handler(request: web.Request) -> NoReturn:
pass

assert cm._session.closed
# Allow event loop to process transport cleanup
# on Python < 3.11
await asyncio.sleep(0)


async def test_aiohttp_request_ctx_manager_not_found() -> None:
Expand Down
7 changes: 6 additions & 1 deletion tests/test_client_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,8 +863,13 @@ async def test_client_middleware_retry_reuses_connection(
aiohttp_server: AiohttpServer,
) -> None:
"""Test that connections are reused when middleware performs retries."""
request_count = 0

async def handler(request: web.Request) -> web.Response:
nonlocal request_count
request_count += 1
if request_count == 1:
return web.Response(status=400) # First request returns 400 with no body
return web.Response(text="OK")

class TrackingConnector(TCPConnector):
Expand All @@ -891,7 +896,7 @@ async def __call__(
while True:
self.attempt_count += 1
response = await handler(request)
if retry_count == 0:
if response.status == 400 and retry_count == 0:
retry_count += 1
continue
return response
Expand Down
2 changes: 1 addition & 1 deletion tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,4 +670,4 @@ async def test_get_extra_info(
await ws.prepare(req)
ws._writer = ws_transport

assert ws.get_extra_info(valid_key, default_value) == expected_result
assert expected_result == ws.get_extra_info(valid_key, default_value)
113 changes: 113 additions & 0 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,3 +1332,116 @@ async def handler(request: web.Request) -> web.WebSocketResponse:
)
await client.server.close()
assert close_code == WSCloseCode.OK


async def test_websocket_prepare_timeout_close_issue(
loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
"""Test that WebSocket can handle prepare with early returns.

This is a regression test for issue #6009 where the prepared property
incorrectly checked _payload_writer instead of _writer.
"""

async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
assert ws.can_prepare(request)
await ws.prepare(request)
await ws.send_str("test")
await ws.close()
return ws

app = web.Application()
app.router.add_route("GET", "/ws", handler)
client = await aiohttp_client(app)

# Connect via websocket
ws = await client.ws_connect("/ws")
msg = await ws.receive()
assert msg.type is WSMsgType.TEXT
assert msg.data == "test"
await ws.close()


async def test_websocket_prepare_timeout_from_issue_reproducer(
loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
"""Test websocket behavior when prepare is interrupted.

This test verifies the fix for issue #6009 where close() would
fail after prepare() was interrupted.
"""
prepare_complete = asyncio.Event()
close_complete = asyncio.Event()

async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()

# Prepare the websocket
await ws.prepare(request)
prepare_complete.set()

# Send a message to confirm connection works
await ws.send_str("connected")

# Wait for client to close
msg = await ws.receive()
assert msg.type is WSMsgType.CLOSE
await ws.close()
close_complete.set()

return ws

app = web.Application()
app.router.add_route("GET", "/ws", handler)
client = await aiohttp_client(app)

# Connect and verify the connection works
ws = await client.ws_connect("/ws")
await prepare_complete.wait()

msg = await ws.receive()
assert msg.type is WSMsgType.TEXT
assert msg.data == "connected"

# Close the connection
await ws.close()
await close_complete.wait()


async def test_websocket_prepared_property(
loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
"""Test that WebSocketResponse.prepared property correctly reflects state."""
prepare_called = asyncio.Event()

async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()

# Initially not prepared
initial_state = ws.prepared
assert not initial_state

# After prepare() is called, should be prepared
await ws.prepare(request)
prepare_called.set()

# Check prepared state
prepared_state = ws.prepared
assert prepared_state

# Send a message to verify the connection works
await ws.send_str("test")
await ws.close()
return ws

app = web.Application()
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)

ws = await client.ws_connect("/")
await prepare_called.wait()
msg = await ws.receive()
assert msg.type is WSMsgType.TEXT
assert msg.data == "test"
await ws.close()
Loading