5 changes: 5 additions & 0 deletions CHANGES/9798.feature.rst
@@ -0,0 +1,5 @@
Allow the user to set the zlib compression backend -- by :user:`TimMenninger`

This change allows the user to call :func:`aiohttp.set_zlib_backend()` with the
zlib compression module of their choice. Default behavior continues to use
the builtin ``zlib`` library.
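
A minimal usage sketch of the feature described above; it assumes the optional ``isal`` package (providing ``isal.isal_zlib``, one of the backends the docs below note as tested) is installed::

    import isal.isal_zlib as isal_zlib

    import aiohttp

    # Route aiohttp's zlib-based compression through ISA-L instead of the
    # built-in zlib module.
    aiohttp.set_zlib_backend(isal_zlib)
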
2 changes: 2 additions & 0 deletions aiohttp/__init__.py
@@ -47,6 +47,7 @@
WSServerHandshakeError,
request,
)
from .compression_utils import set_zlib_backend
from .connector import AddrInfoType, SocketFactoryType
from .cookiejar import CookieJar, DummyCookieJar
from .formdata import FormData
@@ -165,6 +166,7 @@
"BasicAuth",
"ChainMapProxy",
"ETag",
"set_zlib_backend",
# http
"HttpVersion",
"HttpVersion10",
18 changes: 13 additions & 5 deletions aiohttp/_websocket/reader_py.py
@@ -243,15 +243,23 @@ def _feed_data(self, data: bytes) -> None:
self._decompressobj = ZLibDecompressor(
suppress_deflate_header=True
)
# XXX: It's possible that the zlib backend (isal is known to
# do this, maybe others too?) will return max_length bytes,
# but internally buffer more data such that the payload is
# >max_length, so we request one extra byte and, if we receive
# it, the message is too big.
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size
assembled_payload + WS_DEFLATE_TRAILING,
(
self._max_msg_size + 1
if self._max_msg_size
else self._max_msg_size
),
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Decompressed message size {self._max_msg_size + left}"
f" exceeds limit {self._max_msg_size}",
f"Decompressed message exceeds size limit {self._max_msg_size}",
)
elif type(assembled_payload) is bytes:
payload_merged = assembled_payload
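The ``max_length + 1`` trick described in the comment above can be illustrated outside the websocket code path. The sketch below uses the stdlib ``zlib`` directly; the function name and limit are hypothetical, not aiohttp API::

    import zlib

    def decompress_bounded(compressed: bytes, limit: int) -> bytes:
        d = zlib.decompressobj()
        # Request limit + 1 bytes: receiving that extra byte proves the payload
        # exceeds the limit even if the backend buffered the remainder internally.
        out = d.decompress(compressed, limit + 1)
        if len(out) > limit:
            raise ValueError(f"decompressed size exceeds limit of {limit} bytes")
        return out

    print(decompress_bounded(zlib.compress(b"x" * 10), 64))  # b'xxxxxxxxxx'
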
9 changes: 5 additions & 4 deletions aiohttp/_websocket/writer.py
@@ -2,13 +2,12 @@

import asyncio
import random
import zlib
from functools import partial
from typing import Any, Final, Optional, Union

from ..base_protocol import BaseProtocol
from ..client_exceptions import ClientConnectionResetError
from ..compression_utils import ZLibCompressor
from ..compression_utils import ZLibBackend, ZLibCompressor
from .helpers import (
MASK_LEN,
MSG_SIZE,
@@ -95,7 +94,9 @@ async def send_frame(
message = (
await compressobj.compress(message)
+ compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
ZLibBackend.Z_FULL_FLUSH
if self.notakeover
else ZLibBackend.Z_SYNC_FLUSH
)
).removesuffix(WS_DEFLATE_TRAILING)
# Its critical that we do not return control to the event
@@ -160,7 +161,7 @@ async def send_frame(

def _make_compress_obj(self, compress: int) -> ZLibCompressor:
return ZLibCompressor(
level=zlib.Z_BEST_SPEED,
level=ZLibBackend.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
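For context, the frame compression above roughly corresponds to the standalone sketch below, written against the stdlib ``zlib``; ``WS_DEFLATE_TRAILING`` is assumed to be the usual permessage-deflate trailer and ``notakeover`` mirrors the writer's flag::

    import zlib

    WS_DEFLATE_TRAILING = b"\x00\x00\xff\xff"  # assumed value of aiohttp's constant

    def deflate_frame(payload: bytes, notakeover: bool) -> bytes:
        comp = zlib.compressobj(level=zlib.Z_BEST_SPEED, wbits=-zlib.MAX_WBITS)
        # Z_FULL_FLUSH discards the compression context when takeover is disabled.
        mode = zlib.Z_FULL_FLUSH if notakeover else zlib.Z_SYNC_FLUSH
        data = comp.compress(payload) + comp.flush(mode)
        # permessage-deflate frames are sent without the trailing 00 00 ff ff.
        return data.removesuffix(WS_DEFLATE_TRAILING)
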
3 changes: 1 addition & 2 deletions aiohttp/abc.py
@@ -1,6 +1,5 @@
import logging
import socket
import zlib
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel
@@ -217,7 +216,7 @@ async def drain(self) -> None:

@abstractmethod
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
self, encoding: str = "deflate", strategy: Optional[int] = None
) -> None:
"""Enable HTTP body compression"""

139 changes: 118 additions & 21 deletions aiohttp/compression_utils.py
@@ -2,7 +2,7 @@
import sys
import zlib
from concurrent.futures import Executor
from typing import Optional, cast
from typing import Any, Final, Optional, Protocol, TypedDict, cast

if sys.version_info >= (3, 12):
from collections.abc import Buffer
@@ -24,14 +24,113 @@
MAX_SYNC_CHUNK_SIZE = 1024


class ZLibCompressObjProtocol(Protocol):
def compress(self, data: Buffer) -> bytes: ...
def flush(self, mode: int = ..., /) -> bytes: ...


class ZLibDecompressObjProtocol(Protocol):
def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
def flush(self, length: int = ..., /) -> bytes: ...

@property
def eof(self) -> bool: ...


class ZLibBackendProtocol(Protocol):
MAX_WBITS: int
Z_FULL_FLUSH: int
Z_SYNC_FLUSH: int
Z_BEST_SPEED: int
Z_FINISH: int

def compressobj(
self,
level: int = ...,
method: int = ...,
wbits: int = ...,
memLevel: int = ...,
strategy: int = ...,
zdict: Optional[Buffer] = ...,
) -> ZLibCompressObjProtocol: ...
def decompressobj(
self, wbits: int = ..., zdict: Buffer = ...
) -> ZLibDecompressObjProtocol: ...

def compress(
self, data: Buffer, /, level: int = ..., wbits: int = ...
) -> bytes: ...
def decompress(
self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
) -> bytes: ...


class CompressObjArgs(TypedDict, total=False):
wbits: int
strategy: int
level: int


class ZLibBackendWrapper:
def __init__(self, _zlib_backend: ZLibBackendProtocol):
self._zlib_backend: ZLibBackendProtocol = _zlib_backend

@property
def name(self) -> str:
return getattr(self._zlib_backend, "__name__", "undefined")

@property
def MAX_WBITS(self) -> int:
return self._zlib_backend.MAX_WBITS

@property
def Z_FULL_FLUSH(self) -> int:
return self._zlib_backend.Z_FULL_FLUSH

@property
def Z_SYNC_FLUSH(self) -> int:
return self._zlib_backend.Z_SYNC_FLUSH

@property
def Z_BEST_SPEED(self) -> int:
return self._zlib_backend.Z_BEST_SPEED

@property
def Z_FINISH(self) -> int:
return self._zlib_backend.Z_FINISH

def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
return self._zlib_backend.compressobj(*args, **kwargs)

def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
return self._zlib_backend.decompressobj(*args, **kwargs)

def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
return self._zlib_backend.compress(data, *args, **kwargs)

def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
return self._zlib_backend.decompress(data, *args, **kwargs)

# Everything not explicitly listed in the Protocol we just pass through
def __getattr__(self, attrname: str) -> Any:
return getattr(self._zlib_backend, attrname)


ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)


def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
ZLibBackend._zlib_backend = new_zlib_backend


def encoding_to_mode(
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
) -> int:
if encoding == "gzip":
return 16 + zlib.MAX_WBITS
return 16 + ZLibBackend.MAX_WBITS

return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS
return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS


class ZlibBaseHandler:
@@ -53,7 +152,7 @@ def __init__(
suppress_deflate_header: bool = False,
level: Optional[int] = None,
wbits: Optional[int] = None,
strategy: int = zlib.Z_DEFAULT_STRATEGY,
strategy: Optional[int] = None,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
@@ -66,12 +165,15 @@ def __init__(
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
if level is None:
self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy)
else:
self._compressor = zlib.compressobj(
wbits=self._mode, strategy=strategy, level=level
)
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)

kwargs: CompressObjArgs = {}
kwargs["wbits"] = self._mode
if strategy is not None:
kwargs["strategy"] = strategy
if level is not None:
kwargs["level"] = level
self._compressor = self._zlib_backend.compressobj(**kwargs)
self._compress_lock = asyncio.Lock()

def compress_sync(self, data: Buffer) -> bytes:
@@ -100,8 +202,10 @@ async def compress(self, data: Buffer) -> bytes:
)
return self.compress_sync(data)

def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
return self._compressor.flush(mode)
def flush(self, mode: Optional[int] = None) -> bytes:
return self._compressor.flush(
mode if mode is not None else self._zlib_backend.Z_FINISH
)


class ZLibDecompressor(ZlibBaseHandler):
@@ -117,7 +221,8 @@ def __init__(
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
self._decompressor = zlib.decompressobj(wbits=self._mode)
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)

def decompress_sync(self, data: Buffer, max_length: int = 0) -> bytes:
return self._decompressor.decompress(data, max_length)
@@ -149,14 +254,6 @@ def flush(self, length: int = 0) -> bytes:
def eof(self) -> bool:
return self._decompressor.eof

@property
def unconsumed_tail(self) -> bytes:
return self._decompressor.unconsumed_tail

@property
def unused_data(self) -> bytes:
return self._decompressor.unused_data


class BrotliDecompressor:
# Supports both 'brotlipy' and 'Brotli' packages
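Anything exposing the attributes named in ``ZLibBackendProtocol`` can be handed to ``set_zlib_backend()``. A minimal sketch of a conforming backend, built on the stdlib ``zlib`` purely for illustration (the namespace object and its name are hypothetical, not part of this change)::

    import zlib
    from types import SimpleNamespace

    import aiohttp

    # A module-like object that satisfies ZLibBackendProtocol by delegating
    # everything to the stdlib zlib.
    stdlib_backend = SimpleNamespace(
        __name__="stdlib-zlib",          # surfaced via ZLibBackendWrapper.name
        MAX_WBITS=zlib.MAX_WBITS,
        Z_FULL_FLUSH=zlib.Z_FULL_FLUSH,
        Z_SYNC_FLUSH=zlib.Z_SYNC_FLUSH,
        Z_BEST_SPEED=zlib.Z_BEST_SPEED,
        Z_FINISH=zlib.Z_FINISH,
        compressobj=zlib.compressobj,
        decompressobj=zlib.decompressobj,
        compress=zlib.compress,
        decompress=zlib.decompress,
    )

    aiohttp.set_zlib_backend(stdlib_backend)
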
3 changes: 1 addition & 2 deletions aiohttp/http_writer.py
@@ -2,7 +2,6 @@

import asyncio
import sys
import zlib
from typing import ( # noqa
Any,
Awaitable,
@@ -85,7 +84,7 @@ def enable_chunking(self) -> None:
self.chunked = True

def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
self, encoding: str = "deflate", strategy: Optional[int] = None
) -> None:
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

3 changes: 1 addition & 2 deletions aiohttp/multipart.py
@@ -5,7 +5,6 @@
import sys
import uuid
import warnings
import zlib
from collections import deque
from types import TracebackType
from typing import (
@@ -1032,7 +1031,7 @@ def enable_encoding(self, encoding: str) -> None:
self._encoding = "quoted-printable"

def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
self, encoding: str = "deflate", strategy: Optional[int] = None
) -> None:
self._compress = ZLibCompressor(
encoding=encoding,
5 changes: 2 additions & 3 deletions aiohttp/web_response.py
@@ -6,7 +6,6 @@
import math
import time
import warnings
import zlib
from concurrent.futures import Executor
from http import HTTPStatus
from typing import (
@@ -83,7 +82,7 @@ class StreamResponse(BaseClass, HeadersMixin, CookieMixin):
_keep_alive: Optional[bool] = None
_chunked: bool = False
_compression: bool = False
_compression_strategy: int = zlib.Z_DEFAULT_STRATEGY
_compression_strategy: Optional[int] = None
_compression_force: Optional[ContentCoding] = None
_req: Optional["BaseRequest"] = None
_payload_writer: Optional[AbstractStreamWriter] = None
@@ -184,7 +183,7 @@ def enable_chunked_encoding(self) -> None:
def enable_compression(
self,
force: Optional[ContentCoding] = None,
strategy: int = zlib.Z_DEFAULT_STRATEGY,
strategy: Optional[int] = None,
) -> None:
"""Enables response compression encoding."""
self._compression = True
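With ``strategy`` now defaulting to ``None``, the active backend's own default strategy is used unless a value is supplied. A hedged handler sketch (the handler name and constant choice are illustrative)::

    import zlib

    from aiohttp import web

    async def handler(request: web.Request) -> web.StreamResponse:
        resp = web.Response(text="hello " * 1000)
        # Passing an explicit constant still works; omitting strategy (or
        # passing None) defers to the configured zlib backend's default.
        resp.enable_compression(strategy=zlib.Z_DEFAULT_STRATEGY)
        return resp
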
24 changes: 24 additions & 0 deletions docs/client_reference.rst
@@ -2145,6 +2145,30 @@ Utilities

.. versionadded:: 3.0

.. function:: set_zlib_backend(lib)

Sets the compression backend for zlib-based operations.

This function allows you to override the default zlib backend
used internally by passing a module that implements the standard
compression interface.

At a minimum, the module should implement the interface offered by the
latest version of :mod:`zlib`.

:param types.ModuleType lib: A module that implements the zlib-compatible compression API.

Example usage::

import zlib_ng.zlib_ng as zng
import aiohttp

aiohttp.set_zlib_backend(zng)

.. note:: aiohttp has been tested internally with :mod:`zlib`, :mod:`zlib_ng.zlib_ng`, and :mod:`isal.isal_zlib`.

.. versionadded:: 3.12

FormData
^^^^^^^^

2 changes: 2 additions & 0 deletions docs/conf.py
@@ -84,6 +84,8 @@
"aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
"aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/latest/", None),
"isal": ("https://python-isal.readthedocs.io/en/stable/", None),
"zlib_ng": ("https://python-zlib-ng.readthedocs.io/en/stable/", None),
}

# Add any paths that contain templates here, relative to this directory.