Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES/10847.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Implemented shared DNS resolver management to fix excessive resolver object creation
when using multiple client sessions. The new ``_DNSResolverManager`` singleton ensures
only one ``DNSResolver`` object is created for default configurations, significantly
reducing resource usage and improving performance for applications using multiple
client sessions simultaneously -- by :user:`bdraco`.
1 change: 1 addition & 0 deletions CHANGES/10902.feature.rst
1 change: 1 addition & 0 deletions CHANGES/9212.breaking.rst
1 change: 1 addition & 0 deletions CHANGES/9212.packaging.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Removed remaining `make_mocked_coro` in the test suite -- by :user:`polkapolka`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ Pavol Vargovčík
Pawel Kowalski
Pawel Miech
Pepe Osca
Phebe Polk
Philipp A.
Pierre-Louis Peeters
Pieter van Beek
Expand Down
8 changes: 4 additions & 4 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
Expand Down Expand Up @@ -194,7 +195,7 @@ class _RequestOptions(TypedDict, total=False):
auto_decompress: Union[bool, None]
max_line_size: Union[int, None]
max_field_size: Union[int, None]
middlewares: Optional[Tuple[ClientMiddlewareType, ...]]
middlewares: Optional[Sequence[ClientMiddlewareType]]


@frozen_dataclass_decorator
Expand Down Expand Up @@ -295,7 +296,7 @@ def __init__(
max_line_size: int = 8190,
max_field_size: int = 8190,
fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8",
middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None,
middlewares: Optional[Sequence[ClientMiddlewareType]] = None,
) -> None:
# We initialise _connector to None immediately, as it's referenced in __del__()
# and could cause issues if an exception occurs during initialisation.
Expand Down Expand Up @@ -455,7 +456,7 @@ async def _request(
auto_decompress: Optional[bool] = None,
max_line_size: Optional[int] = None,
max_field_size: Optional[int] = None,
middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None,
middlewares: Optional[Sequence[ClientMiddlewareType]] = None,
) -> ClientResponse:
# NOTE: timeout clamps existing connect and read timeouts. We cannot
# set the default to None because we need to detect if the user wants
Expand Down Expand Up @@ -648,7 +649,6 @@ async def _request(
trust_env=self.trust_env,
)

# Core request handler - now includes connection logic
async def _connect_and_send_request(
req: ClientRequest,
) -> ClientResponse:
Expand Down
7 changes: 2 additions & 5 deletions aiohttp/client_middlewares.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Client middleware support."""

from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Sequence

from .client_reqrep import ClientRequest, ClientResponse

Expand All @@ -17,7 +17,7 @@

def build_client_middlewares(
handler: ClientHandlerType,
middlewares: tuple[ClientMiddlewareType, ...],
middlewares: Sequence[ClientMiddlewareType],
) -> ClientHandlerType:
"""
Apply middlewares to request handler.
Expand All @@ -28,9 +28,6 @@ def build_client_middlewares(
This implementation avoids using partial/update_wrapper to minimize overhead
and doesn't cache to avoid holding references to stateful middleware.
"""
if not middlewares:
return handler

# Optimize for single middleware case
if len(middlewares) == 1:
middleware = middlewares[0]
Expand Down
86 changes: 84 additions & 2 deletions aiohttp/resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import socket
from typing import Any, List, Tuple, Type, Union
import weakref
from typing import Any, List, Optional, Tuple, Type, Union

from .abc import AbstractResolver, ResolveResult

Expand Down Expand Up @@ -88,7 +89,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
if aiodns is None:
raise RuntimeError("Resolver requires aiodns library")

self._resolver = aiodns.DNSResolver(*args, **kwargs)
self._loop = asyncio.get_running_loop()
self._manager: Optional[_DNSResolverManager] = None
# If custom args are provided, create a dedicated resolver instance
# This means each AsyncResolver with custom args gets its own
# aiodns.DNSResolver instance
if args or kwargs:
self._resolver = aiodns.DNSResolver(*args, **kwargs)
return
# Use the shared resolver from the manager for default arguments
self._manager = _DNSResolverManager()
self._resolver = self._manager.get_resolver(self, self._loop)

async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
Expand Down Expand Up @@ -142,7 +153,78 @@ async def resolve(
return hosts

async def close(self) -> None:
if self._manager:
# Release the resolver from the manager if using the shared resolver
self._manager.release_resolver(self, self._loop)
self._manager = None # Clear reference to manager
self._resolver = None # type: ignore[assignment] # Clear reference to resolver
return
# Otherwise cancel our dedicated resolver
self._resolver.cancel()
self._resolver = None # type: ignore[assignment] # Clear reference


class _DNSResolverManager:
"""Manager for aiodns.DNSResolver objects.

This class manages shared aiodns.DNSResolver instances
with no custom arguments across different event loops.
"""

_instance: Optional["_DNSResolverManager"] = None

def __new__(cls) -> "_DNSResolverManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._init()
return cls._instance

def _init(self) -> None:
# Use WeakKeyDictionary to allow event loops to be garbage collected
self._loop_data: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop,
tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]],
] = weakref.WeakKeyDictionary()

def get_resolver(
self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
) -> "aiodns.DNSResolver":
"""Get or create the shared aiodns.DNSResolver instance for a specific event loop.

Args:
client: The AsyncResolver instance requesting the resolver.
This is required to track resolver usage.
loop: The event loop to use for the resolver.
"""
# Create a new resolver and client set for this loop if it doesn't exist
if loop not in self._loop_data:
resolver = aiodns.DNSResolver(loop=loop)
client_set: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet()
self._loop_data[loop] = (resolver, client_set)
else:
# Get the existing resolver and client set
resolver, client_set = self._loop_data[loop]

# Register this client with the loop
client_set.add(client)
return resolver

def release_resolver(
self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
) -> None:
"""Release the resolver for an AsyncResolver client when it's closed.

Args:
client: The AsyncResolver instance to release.
loop: The event loop the resolver was using.
"""
# Remove client from its loop's tracking
resolver, client_set = self._loop_data[loop]
client_set.discard(client)
# If no more clients for this loop, cancel and remove its resolver
if not client_set:
resolver.cancel()
del self._loop_data[loop]


_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
Expand Down
25 changes: 4 additions & 21 deletions aiohttp/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import contextlib
import gc
import inspect
import ipaddress
import os
import socket
Expand Down Expand Up @@ -42,7 +41,6 @@
from .abc import AbstractCookieJar, AbstractStreamWriter
from .client_reqrep import ClientResponse
from .client_ws import ClientWebSocketResponse
from .helpers import sentinel
from .http import HttpVersion, RawRequestMessage
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import LooseHeaders, StrOrURL
Expand Down Expand Up @@ -682,10 +680,10 @@ def make_mocked_request(

if writer is None:
writer = mock.Mock()
writer.write_headers = make_mocked_coro(None)
writer.write = make_mocked_coro(None)
writer.write_eof = make_mocked_coro(None)
writer.drain = make_mocked_coro(None)
writer.write_headers = mock.AsyncMock(return_value=None)
writer.write = mock.AsyncMock(return_value=None)
writer.write_eof = mock.AsyncMock(return_value=None)
writer.drain = mock.AsyncMock(return_value=None)
writer.transport = transport

protocol.transport = transport
Expand All @@ -701,18 +699,3 @@ def make_mocked_request(
req._match_info = match_info

return req


def make_mocked_coro(
return_value: Any = sentinel, raise_exception: Any = sentinel
) -> Any:
"""Creates a coroutine mock."""

async def mock_coro(*args: Any, **kwargs: Any) -> Any:
if raise_exception is not sentinel:
raise raise_exception
if not inspect.isawaitable(return_value):
return return_value
await return_value

return mock.Mock(wraps=mock_coro)
Loading
Loading