Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/10662.packaging.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Removed non SPDX-license description from ``setup.cfg`` -- by :user:`devanshu-ziphq`.
6 changes: 6 additions & 0 deletions CHANGES/9732.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Added client middleware support -- by :user:`bdraco` and :user:`Dreamsorcerer`.

This change allows users to add middleware to the client session and requests, enabling features like
authentication, logging, and request/response modification without modifying the core
request logic. Additionally, the ``session`` attribute was added to ``ClientRequest``,
allowing middleware to access the session for making additional requests.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Denilson Amorim
Denis Matiychuk
Denis Moshensky
Dennis Kliban
Devanshu Koyalkar
Dima Veselov
Dimitar Dimitrov
Diogo Dutra da Mata
Expand Down
4 changes: 4 additions & 0 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
WSServerHandshakeError,
request,
)
from .client_middlewares import ClientHandlerType, ClientMiddlewareType
from .compression_utils import set_zlib_backend
from .connector import AddrInfoType, SocketFactoryType
from .cookiejar import CookieJar, DummyCookieJar
Expand Down Expand Up @@ -157,6 +158,9 @@
"NamedPipeConnector",
"WSServerHandshakeError",
"request",
# client_middleware
"ClientMiddlewareType",
"ClientHandlerType",
# cookiejar
"CookieJar",
"DummyCookieJar",
Expand Down
81 changes: 56 additions & 25 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
WSMessageTypeError,
WSServerHandshakeError,
)
from .client_middlewares import ClientMiddlewareType, build_client_middlewares
from .client_reqrep import (
SSL_ALLOWED_TYPES,
ClientRequest,
Expand Down Expand Up @@ -193,6 +194,7 @@ class _RequestOptions(TypedDict, total=False):
auto_decompress: Union[bool, None]
max_line_size: Union[int, None]
max_field_size: Union[int, None]
middlewares: Optional[Tuple[ClientMiddlewareType, ...]]


@frozen_dataclass_decorator
Expand Down Expand Up @@ -260,6 +262,7 @@ class ClientSession:
"_default_proxy",
"_default_proxy_auth",
"_retry_connection",
"_middlewares",
)

def __init__(
Expand Down Expand Up @@ -292,6 +295,7 @@ def __init__(
max_line_size: int = 8190,
max_field_size: int = 8190,
fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8",
middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None,
) -> None:
# We initialise _connector to None immediately, as it's referenced in __del__()
# and could cause issues if an exception occurs during initialisation.
Expand Down Expand Up @@ -376,6 +380,7 @@ def __init__(
self._default_proxy = proxy
self._default_proxy_auth = proxy_auth
self._retry_connection: bool = True
self._middlewares = middlewares

def __init_subclass__(cls: Type["ClientSession"]) -> None:
raise TypeError(
Expand Down Expand Up @@ -450,6 +455,7 @@ async def _request(
auto_decompress: Optional[bool] = None,
max_line_size: Optional[int] = None,
max_field_size: Optional[int] = None,
middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None,
) -> ClientResponse:
# NOTE: timeout clamps existing connect and read timeouts. We cannot
# set the default to None because we need to detect if the user wants
Expand Down Expand Up @@ -642,32 +648,33 @@ async def _request(
trust_env=self.trust_env,
)

# connection timeout
try:
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
# Core request handler - now includes connection logic
async def _connect_and_send_request(
req: ClientRequest,
) -> ClientResponse:
# connection timeout
assert self._connector is not None
try:
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
)
except asyncio.TimeoutError as exc:
raise ConnectionTimeoutError(
f"Connection timeout to host {req.url}"
) from exc

assert conn.protocol is not None
conn.protocol.set_response_params(
timer=timer,
skip_payload=req.method in EMPTY_BODY_METHODS,
read_until_eof=read_until_eof,
auto_decompress=auto_decompress,
read_timeout=real_timeout.sock_read,
read_bufsize=read_bufsize,
timeout_ceil_threshold=self._connector._timeout_ceil_threshold,
max_line_size=max_line_size,
max_field_size=max_field_size,
)
except asyncio.TimeoutError as exc:
raise ConnectionTimeoutError(
f"Connection timeout to host {url}"
) from exc

assert conn.transport is not None

assert conn.protocol is not None
conn.protocol.set_response_params(
timer=timer,
skip_payload=method in EMPTY_BODY_METHODS,
read_until_eof=read_until_eof,
auto_decompress=auto_decompress,
read_timeout=real_timeout.sock_read,
read_bufsize=read_bufsize,
timeout_ceil_threshold=self._connector._timeout_ceil_threshold,
max_line_size=max_line_size,
max_field_size=max_field_size,
)

try:
try:
resp = await req.send(conn)
try:
Expand All @@ -678,6 +685,30 @@ async def _request(
except BaseException:
conn.close()
raise
return resp

# Apply middleware (if any) - per-request middleware overrides session middleware
effective_middlewares = (
self._middlewares if middlewares is None else middlewares
)

if effective_middlewares:
handler = build_client_middlewares(
_connect_and_send_request, effective_middlewares
)
else:
handler = _connect_and_send_request

try:
resp = await handler(req)
# Client connector errors should not be retried
except (
ConnectionTimeoutError,
ClientConnectorError,
ClientConnectorCertificateError,
ClientConnectorSSLError,
):
raise
except (ClientOSError, ServerDisconnectedError):
if retry_persistent_connection:
retry_persistent_connection = False
Expand Down
58 changes: 58 additions & 0 deletions aiohttp/client_middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Client middleware support."""

from collections.abc import Awaitable, Callable

from .client_reqrep import ClientRequest, ClientResponse

__all__ = ("ClientMiddlewareType", "ClientHandlerType", "build_client_middlewares")

# Type alias for client request handlers - functions that process requests and return responses
ClientHandlerType = Callable[[ClientRequest], Awaitable[ClientResponse]]

# Type for client middleware - similar to server but uses ClientRequest/ClientResponse
ClientMiddlewareType = Callable[
[ClientRequest, ClientHandlerType], Awaitable[ClientResponse]
]


def build_client_middlewares(
handler: ClientHandlerType,
middlewares: tuple[ClientMiddlewareType, ...],
) -> ClientHandlerType:
"""
Apply middlewares to request handler.

The middlewares are applied in reverse order, so the first middleware
in the list wraps all subsequent middlewares and the handler.

This implementation avoids using partial/update_wrapper to minimize overhead
and doesn't cache to avoid holding references to stateful middleware.
"""
if not middlewares:
return handler

# Optimize for single middleware case
if len(middlewares) == 1:
middleware = middlewares[0]

async def single_middleware_handler(req: ClientRequest) -> ClientResponse:
return await middleware(req, handler)

return single_middleware_handler

# Build the chain for multiple middlewares
current_handler = handler

for middleware in reversed(middlewares):
# Create a new closure that captures the current state
def make_wrapper(
mw: ClientMiddlewareType, next_h: ClientHandlerType
) -> ClientHandlerType:
async def wrapped(req: ClientRequest) -> ClientResponse:
return await mw(req, next_h)

return wrapped

current_handler = make_wrapper(middleware, current_handler)

return current_handler
15 changes: 15 additions & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ class ClientRequest:
auth = None
response = None

# These class defaults help create_autospec() work correctly.
# If autospec is improved in future, maybe these can be removed.
url = URL()
method = "GET"

__writer: Optional["asyncio.Task[None]"] = None # async task for streaming data
_continue = None # waiter future for '100 Continue' response

Expand Down Expand Up @@ -362,6 +367,16 @@ def request_info(self) -> RequestInfo:
RequestInfo, (self.url, self.method, headers, self.original_url)
)

@property
def session(self) -> "ClientSession":
"""Return the ClientSession instance.

This property provides access to the ClientSession that initiated
this request, allowing middleware to make additional requests
using the same session.
"""
return self._session

def update_host(self, url: URL) -> None:
"""Update destination host, port and connection type (ssl)."""
# get host/port
Expand Down
Loading
Loading