Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import functools
import os
import random
import socket
import sys
Expand All @@ -14,6 +15,7 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, cast

import aiofastnet
import aiohappyeyeballs
from aiohappyeyeballs import AddrInfoType, SocketFactoryType
from multidict import CIMultiDict
Expand Down Expand Up @@ -1256,7 +1258,11 @@ async def _wrap_create_connection(
and sys.version_info >= (3, 11)
):
kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout
return await self._loop.create_connection(*args, **kwargs, sock=sock)

if os.environ.get("USE_AIOFN", 0):
return await aiofastnet.create_connection(self._loop, *args, **kwargs, sock=sock)
else:
return await self._loop.create_connection(*args, **kwargs, sock=sock)
except cert_errors as exc:
raise ClientConnectorCertificateError(req.connection_key, exc) from exc
except ssl_errors as exc:
Expand Down Expand Up @@ -1330,23 +1336,25 @@ async def _start_tls_connection(
):
try:
# ssl_shutdown_timeout is only available in Python 3.11+
if sys.version_info >= (3, 11) and self._ssl_shutdown_timeout:
tls_transport = await self._loop.start_tls(
underlying_transport,
tls_proto,
sslcontext,
server_hostname=req.server_hostname or req.url.raw_host,
ssl_handshake_timeout=timeout.total,
ssl_shutdown_timeout=self._ssl_shutdown_timeout,
)
if os.environ.get("USE_AIOFN"):
if self._ssl_shutdown_timeout:
start_tls = functools.partial(aiofastnet.start_tls, self._loop, ssl_shutdown_timeout=self._ssl_shutdown_timeout)
else:
start_tls = functools.partial(aiofastnet.start_tls, self._loop)
else:
tls_transport = await self._loop.start_tls(
if sys.version_info >= (3, 11) and self._ssl_shutdown_timeout:
start_tls = functools.partial(self._loop.start_tls, ssl_shutdown_timeout=self._ssl_shutdown_timeout)
else:
start_tls = self._loop.start_tls

tls_transport = await start_tls(
underlying_transport,
tls_proto,
sslcontext,
server_hostname=req.server_hostname or req.url.raw_host,
ssl_handshake_timeout=timeout.total,
ssl_handshake_timeout=timeout.total
)

except BaseException:
# We need to close the underlying transport since
# `start_tls()` probably failed before it had a
Expand Down
6 changes: 5 additions & 1 deletion aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from stat import S_ISREG
from types import MappingProxyType
from typing import IO, TYPE_CHECKING, Any, Final, Optional
import aiofastnet

from . import hdrs
from .abc import AbstractStreamWriter
Expand Down Expand Up @@ -131,7 +132,10 @@ async def _sendfile(
assert transport is not None

try:
await loop.sendfile(transport, fobj, offset, count)
if os.environ.get("USE_AIOFN"):
await aiofastnet.sendfile(loop, transport, fobj, offset, count)
else:
await loop.sendfile(transport, fobj, offset, count)
except NotImplementedError:
return await self._sendfile_fallback(writer, fobj, offset, count)

Expand Down
11 changes: 9 additions & 2 deletions aiohttp/web_runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import os
import signal
import socket
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Generic, TypeVar

import aiofastnet
from yarl import URL

from .abc import AbstractAccessLogger, AbstractStreamWriter
Expand Down Expand Up @@ -130,7 +133,8 @@ async def start(self) -> None:
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
self._server = await loop.create_server(
create_server = partial(aiofastnet.create_server, loop) if os.environ.get("USE_AIOFN") else loop.create_server
self._server = await create_server(
server,
self._host,
self._port,
Expand All @@ -139,6 +143,7 @@ async def start(self) -> None:
reuse_address=self._reuse_address,
reuse_port=self._reuse_port,
)

if self._server.sockets:
self._bound_port = self._server.sockets[0].getsockname()[1]
else:
Expand Down Expand Up @@ -244,7 +249,9 @@ async def start(self) -> None:
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
self._server = await loop.create_server(
create_server = partial(aiofastnet.create_server, loop) if os.environ.get("USE_AIOFN") else loop.create_server

self._server = await create_server(
server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
)

Expand Down
106 changes: 106 additions & 0 deletions examples/ktls_static_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
import argparse
import asyncio
import os

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'os' is not used.
import pathlib
import ssl
import tempfile
from logging import basicConfig, getLogger
import uvloop

from aiohttp import web


HOST = "0.0.0.0"
TLS_PORT = 8443
KTLS_PORT = 8444
FILE_SIZE = 2 * 1024 * 1024 * 1024
STATIC_DIR = pathlib.Path(tempfile.gettempdir()) / "aiohttp-ktls-static"
HUGE_FILE = STATIC_DIR / "huge.bin"


def make_huge_file() -> pathlib.Path:
STATIC_DIR.mkdir(parents=True, exist_ok=True)
if not HUGE_FILE.exists() or HUGE_FILE.stat().st_size != FILE_SIZE:
with HUGE_FILE.open("wb") as f:
f.truncate(FILE_SIZE)
return HUGE_FILE


def make_ssl_context(*, enable_ktls: bool) -> ssl.SSLContext:
here = pathlib.Path(__file__).parent
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(here / "server.crt", here / "server.key")

if enable_ktls:
ssl_context.options |= ssl.OP_ENABLE_KTLS

return ssl_context


async def huge_file(request: web.Request) -> web.FileResponse:
return web.FileResponse(make_huge_file())


def make_app() -> web.Application:
app = web.Application()
app.router.add_get("/huge.bin", huge_file)
return app


async def main(args) -> None:
huge_path = make_huge_file()

if args.asyncio_debug:
asyncio.get_running_loop().set_debug(True)

runner = web.AppRunner(make_app())
await runner.setup()

plain_tls_site = web.TCPSite(
runner,
args.host,
args.tls_port,
ssl_context=make_ssl_context(enable_ktls=False),
)
ktls_site = web.TCPSite(
runner,
args.host,
args.ktls_port,
ssl_context=make_ssl_context(enable_ktls=True),
)

try:
await plain_tls_site.start()
await ktls_site.start()

print(f"Serving {huge_path} ({FILE_SIZE} bytes)")
print(f"TLS without KTLS: https://{args.host}:{plain_tls_site.port}/huge.bin")
print(f"TLS with KTLS: https://{args.host}:{ktls_site.port}/huge.bin")

await asyncio.Event().wait()
finally:
await runner.cleanup()


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=(
"Serve one 50 MiB file from two HTTPS ports, one with KTLS enabled."
)
)
parser.add_argument("--host", default=HOST)
parser.add_argument("--tls-port", type=int, default=TLS_PORT)
parser.add_argument("--ktls-port", type=int, default=KTLS_PORT)
parser.add_argument("--uvloop", action="store_true", help="Use uvloop")
parser.add_argument("--asyncio-debug", action="store_true", help="Enable loop debugging")
parser.add_argument("--level", type=str, default="INFO", help="Logging level")

args = parser.parse_args()

if args.uvloop:
uvloop.install()

basicConfig(level=args.level)

asyncio.run(main(args))
1 change: 1 addition & 0 deletions requirements/runtime-deps.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
aiodns >= 3.3.0
aiohappyeyeballs >= 2.5.0
aiosignal >= 1.4.0
aiofastnet >= 0.6.0
async-timeout >= 4.0, < 6.0 ; python_version < '3.11'
backports.zstd; platform_python_implementation == 'CPython' and python_version < '3.14'
Brotli >= 1.2; platform_python_implementation == 'CPython'
Expand Down
2 changes: 2 additions & 0 deletions requirements/runtime-deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ aiohappyeyeballs==2.6.1
# via -r requirements/runtime-deps.in
aiosignal==1.4.0
# via -r requirements/runtime-deps.in
aiofastnet >= 0.6.0
# via -r requirements/runtime-deps.in
async-timeout==5.0.1 ; python_version < "3.11"
# via -r requirements/runtime-deps.in
backports-zstd==1.3.0 ; platform_python_implementation == "CPython" and python_version < "3.14"
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def blockbuster(request: pytest.FixtureRequest) -> Iterator[None]:
# synchronization in async code.
# Allow lock.acquire calls to prevent these false positives
bb.functions["threading.Lock.acquire"].deactivate()

# aiofastnet is using sendfile on a non-blocking socket.
# blockbuster triggers anyway. Seems like a false positive
bb.functions["os.sendfile"].deactivate()
yield


Expand Down
Loading