Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/10941.bugfix.rst
1 change: 1 addition & 0 deletions CHANGES/10943.bugfix.rst
1 change: 1 addition & 0 deletions CHANGES/10946.feature.rst
1 change: 1 addition & 0 deletions CHANGES/10952.feature.rst
2 changes: 0 additions & 2 deletions aiohttp/client_middleware_digest_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,6 @@ async def __call__(
# Check if we need to authenticate
if not self._authenticate(response):
break
elif retry_count < 1:
response.release() # Release the response to enable connection reuse on retry

# At this point, response is guaranteed to be defined
assert response is not None
Expand Down
5 changes: 4 additions & 1 deletion aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,10 @@ async def write_with_length(

try:
while True:
chunk = await self._iter.__anext__()
if sys.version_info >= (3, 10):
chunk = await anext(self._iter)
else:
chunk = await self._iter.__anext__()
if remaining_bytes is None:
await writer.write(chunk)
# If we have a content length limit
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ async def close(self) -> None:
self._resolver = None # type: ignore[assignment] # Clear reference to resolver
return
# Otherwise cancel our dedicated resolver
self._resolver.cancel()
if self._resolver is not None:
self._resolver.cancel()
self._resolver = None # type: ignore[assignment] # Clear reference


Expand Down
10 changes: 4 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from hashlib import md5, sha1, sha256
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, Generator, Iterator
from typing import Any, AsyncIterator, Callable, Generator, Iterator
from unittest import mock
from uuid import uuid4

Expand Down Expand Up @@ -338,15 +338,13 @@ def parametrize_zlib_backend(


@pytest.fixture()
def cleanup_payload_pending_file_closes(
async def cleanup_payload_pending_file_closes(
loop: asyncio.AbstractEventLoop,
) -> Generator[None, None, None]:
) -> AsyncIterator[None]:
"""Ensure all pending file close operations complete during test teardown."""
yield
if payload._CLOSE_FUTURES:
# Only wait for futures from the current loop
loop_futures = [f for f in payload._CLOSE_FUTURES if f.get_loop() is loop]
if loop_futures:
loop.run_until_complete(
asyncio.gather(*loop_futures, return_exceptions=True)
)
await asyncio.gather(*loop_futures, return_exceptions=True)
125 changes: 125 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4382,3 +4382,128 @@ async def handler(request: web.Request) -> web.Response:
response.raise_for_status()

assert len(client._session.connector._conns) == 1


async def test_post_content_exception_connection_kept(
aiohttp_client: AiohttpClient,
) -> None:
"""Test that connections are kept after content.set_exception() with POST."""

async def handler(request: web.Request) -> web.Response:
await request.read()
return web.Response(
body=b"x" * 1000
) # Larger response to ensure it's not pre-buffered

app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)

# POST request with body - connection should be closed after content exception
resp = await client.post("/", data=b"request body")

with pytest.raises(RuntimeError):
async with resp:
assert resp.status == 200
resp.content.set_exception(RuntimeError("Simulated error"))
await resp.read()

assert resp.closed

# Wait for any pending operations to complete
await resp.wait_for_close()

assert client._session.connector is not None
# Connection is kept because content.set_exception() is a client-side operation
# that doesn't affect the underlying connection state
assert len(client._session.connector._conns) == 1


async def test_network_error_connection_closed(
aiohttp_client: AiohttpClient,
) -> None:
"""Test that connections are closed after network errors."""

async def handler(request: web.Request) -> NoReturn:
# Read the request body
await request.read()

# Start sending response but close connection before completing
response = web.StreamResponse()
response.content_length = 1000 # Promise 1000 bytes
await response.prepare(request)

# Send partial data then force close the connection
await response.write(b"x" * 100) # Only send 100 bytes
# Force close the transport to simulate network error
assert request.transport is not None
request.transport.close()
assert False, "Will not return"

app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)

# POST request that will fail due to network error
with pytest.raises(aiohttp.ClientPayloadError):
resp = await client.post("/", data=b"request body")
async with resp:
await resp.read() # This should fail

# Give event loop a chance to process connection cleanup
await asyncio.sleep(0)

assert client._session.connector is not None
# Connection should be closed due to network error
assert len(client._session.connector._conns) == 0


async def test_client_side_network_error_connection_closed(
aiohttp_client: AiohttpClient,
) -> None:
"""Test that connections are closed after client-side network errors."""
handler_done = asyncio.Event()

async def handler(request: web.Request) -> NoReturn:
# Read the request body
await request.read()

# Start sending a large response
response = web.StreamResponse()
response.content_length = 10000 # Promise 10KB
await response.prepare(request)

# Send some data
await response.write(b"x" * 1000)

# Keep the response open - we'll interrupt from client side
await asyncio.wait_for(handler_done.wait(), timeout=5.0)
assert False, "Will not return"

app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)

# POST request that will fail due to client-side network error
with pytest.raises(aiohttp.ClientPayloadError):
resp = await client.post("/", data=b"request body")
async with resp:
# Simulate client-side network error by closing the transport
# This simulates connection reset, network failure, etc.
assert resp.connection is not None
assert resp.connection.protocol is not None
assert resp.connection.protocol.transport is not None
resp.connection.protocol.transport.close()

# This should fail with connection error
await resp.read()

# Signal handler to finish
handler_done.set()

# Give event loop a chance to process connection cleanup
await asyncio.sleep(0)

assert client._session.connector is not None
# Connection should be closed due to client-side network error
assert len(client._session.connector._conns) == 0
1 change: 0 additions & 1 deletion tests/test_client_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,6 @@ async def __call__(
response = await handler(request)
if retry_count == 0:
retry_count += 1
response.release() # Release the response to enable connection reuse
continue
return response

Expand Down
37 changes: 37 additions & 0 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,3 +660,40 @@ async def test_dns_resolver_manager_missing_loop_data() -> None:

# Verify no exception was raised
assert loop not in manager._loop_data


@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required")
@pytest.mark.usefixtures("check_no_lingering_resolvers")
async def test_async_resolver_close_multiple_times() -> None:
"""Test that AsyncResolver.close() can be called multiple times without error."""
with patch("aiodns.DNSResolver") as mock_dns_resolver:
mock_resolver = Mock()
mock_resolver.cancel = Mock()
mock_dns_resolver.return_value = mock_resolver

# Create a resolver with custom args (dedicated resolver)
resolver = AsyncResolver(nameservers=["8.8.8.8"])

# Close it once
await resolver.close()
mock_resolver.cancel.assert_called_once()

# Close it again - should not raise AttributeError
await resolver.close()
# cancel should still only be called once
mock_resolver.cancel.assert_called_once()


@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required")
@pytest.mark.usefixtures("check_no_lingering_resolvers")
async def test_async_resolver_close_with_none_resolver() -> None:
"""Test that AsyncResolver.close() handles None resolver gracefully."""
with patch("aiodns.DNSResolver"):
# Create a resolver with custom args (dedicated resolver)
resolver = AsyncResolver(nameservers=["8.8.8.8"])

# Manually set resolver to None to simulate edge case
resolver._resolver = None # type: ignore[assignment]

# This should not raise AttributeError
await resolver.close()
3 changes: 3 additions & 0 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,6 +1945,9 @@ async def handler(request: web.Request) -> web.Response:
await resp.read()
assert resp.closed

# Wait for any pending operations to complete
await resp.wait_for_close()

assert session._connector is not None
assert len(session._connector._conns) == 1

Expand Down
Loading