Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Changes
=======

0.9.0 (unreleased)
------------------

* Added an opt-in ``trust_env`` parameter to :class:`~zyte_api.AsyncZyteAPI`
and :class:`~zyte_api.ZyteAPI`, and an opt-in ``--trust-env`` CLI flag, to
allow honoring environment-based network settings (e.g. ``HTTP_PROXY`` and
``HTTPS_PROXY``).

0.8.2 (2026-02-10)
------------------

Expand Down
20 changes: 19 additions & 1 deletion tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import asyncio
from typing import TYPE_CHECKING, Any
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, patch

import pytest

from zyte_api import AggressiveRetryFactory, AsyncZyteAPI, RequestError
from zyte_api._utils import create_session
from zyte_api.aio.client import AsyncClient
from zyte_api.apikey import NoApiKey
from zyte_api.errors import ParsedError
Expand Down Expand Up @@ -54,6 +55,23 @@ def test_api_key(client_cls):
client_cls()


@pytest.mark.asyncio
async def test_session_inherits_client_trust_env(mockserver):
client = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/"), trust_env=True)
async with client.session() as session:
assert session._session._trust_env is True


@pytest.mark.asyncio
async def test_get_creates_session_with_client_trust_env(mockserver):
client = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/"), trust_env=True)
with patch(
"zyte_api._async.create_session", wraps=create_session
) as create_session_mock:
await client.get({"url": "https://a.example"})
assert create_session_mock.call_args.kwargs["trust_env"] is True


@pytest.mark.parametrize(
("client_cls", "get_method"),
(
Expand Down
7 changes: 6 additions & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@


def run_zyte_api(args, env, mockserver):
base_env = {
key: value
for key, value in environ.items()
if key not in {"ZYTE_API_KEY", "ZYTE_API_ETH_KEY"}
}
with NamedTemporaryFile("w") as url_list:
url_list.write("https://a.example\n")
url_list.flush()
Expand All @@ -29,7 +34,7 @@ def run_zyte_api(args, env, mockserver):
],
capture_output=True,
check=False,
env={**environ, **env},
env={**base_env, **env},
)


Expand Down
16 changes: 15 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

from zyte_api import RequestError
from zyte_api.__main__ import run
from zyte_api.__main__ import _get_argument_parser, run

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -108,6 +108,7 @@ async def test_run(queries, expected_response, store_errors, exception):
api_url = "https://example.com"
api_key = "fake_key"
retry_errors = True
trust_env = True

# Create a mock for AsyncZyteAPI
async_client_mock = Mock()
Expand Down Expand Up @@ -138,8 +139,15 @@ async def test_run(queries, expected_response, store_errors, exception):
api_key=api_key,
retry_errors=retry_errors,
store_errors=store_errors,
trust_env=trust_env,
)

assert async_client_mock.call_args.kwargs["trust_env"] is True
create_session_mock.assert_called_once_with(
connection_pool_size=n_conn,
trust_env=True,
)

assert get_json_content(temporary_file) == expected_response
tmp_path.unlink()

Expand Down Expand Up @@ -218,6 +226,12 @@ def test_empty_input(mockserver):
assert result.stderr == b"No input queries found. Is the input file empty?\n"


def test_trust_env_flag_parsing() -> None:
parser = _get_argument_parser()
args = parser.parse_args(["--trust-env", "--api-key", "a", "README.rst"])
assert args.trust_env is True


def test_intype_txt_implicit(mockserver):
result = _run(input_="https://a.example", mockserver=mockserver)
assert not result.returncode
Expand Down
8 changes: 7 additions & 1 deletion tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from types import GeneratorType
from typing import TYPE_CHECKING, Any
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, patch

import pytest

Expand All @@ -19,6 +19,12 @@ def test_api_key():
ZyteAPI()


def test_trust_env_is_forwarded():
with patch("zyte_api._sync.AsyncZyteAPI") as async_client:
ZyteAPI(api_key="a", trust_env=True)
assert async_client.call_args.kwargs["trust_env"] is True


def test_get(mockserver):
client = ZyteAPI(api_key="a", api_url=mockserver.urljoin("/"))
expected_result = {
Expand Down
18 changes: 17 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@ async def test_create_session_custom_connector():
custom_connector = TCPConnector(limit=1850)
session = create_session(connector=custom_connector)
assert session.connector == custom_connector
await session.close()


@pytest.mark.asyncio
async def test_create_session_trust_env_disabled_by_default():
session = create_session()
assert session._trust_env is False
await session.close()


@pytest.mark.asyncio
async def test_create_session_trust_env_can_be_enabled():
session = create_session(trust_env=True)
assert session._trust_env is True
await session.close()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -121,4 +136,5 @@ async def test_deprecated_create_session():
DeprecationWarning,
match=r"^zyte_api\.aio\.client\.create_session is deprecated",
):
_create_session()
session = _create_session()
await session.close()
68 changes: 48 additions & 20 deletions zyte_api/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import logging
import random
import sys
from contextlib import nullcontext
from pathlib import Path
from typing import IO, Any, Literal
from warnings import warn

Expand Down Expand Up @@ -42,6 +44,7 @@ async def run(
retry_errors: bool = True,
store_errors: bool | None = None,
eth_key: str | None = None,
trust_env: bool = False,
) -> None:
if stop_on_errors is not _UNSET:
warn(
Expand All @@ -65,9 +68,15 @@ def write_output(content: Any) -> None:
elif eth_key:
auth_kwargs["eth_key"] = eth_key
client = AsyncZyteAPI(
n_conn=n_conn, api_url=api_url, retrying=retrying, **auth_kwargs
n_conn=n_conn,
api_url=api_url,
retrying=retrying,
trust_env=trust_env,
**auth_kwargs,
)
async with create_session(connection_pool_size=n_conn) as session:
async with create_session(
connection_pool_size=n_conn, trust_env=trust_env
) as session:
result_iter = client.iter(
queries=queries,
session=session,
Expand Down Expand Up @@ -128,7 +137,6 @@ def _get_argument_parser(program_name: str = "zyte-api") -> argparse.ArgumentPar
)
p.add_argument(
"INPUT",
type=argparse.FileType("r", encoding="utf8"),
help=(
"Path to an input file (see 'Command-line client > Input file' in "
"the docs for details)."
Expand All @@ -151,8 +159,7 @@ def _get_argument_parser(program_name: str = "zyte-api") -> argparse.ArgumentPar
p.add_argument(
"--output",
"-o",
default=sys.stdout,
type=argparse.FileType("w", encoding="utf8"),
default=None,
help=(
"Path for the output file. Results are written into the output "
"file in JSON Lines format.\n"
Expand Down Expand Up @@ -225,6 +232,14 @@ def _get_argument_parser(program_name: str = "zyte-api") -> argparse.ArgumentPar
),
action="store_true",
)
p.add_argument(
"--trust-env",
help=(
"Enable environment-based network settings such as HTTP_PROXY and "
"HTTPS_PROXY for Zyte API requests."
),
action="store_true",
)
return p


Expand All @@ -234,7 +249,15 @@ def _main(program_name: str = "zyte-api") -> None:
args = p.parse_args()
logging.basicConfig(stream=sys.stderr, level=getattr(logging, args.loglevel))

queries = read_input(args.INPUT, args.intype)
if args.INPUT == "-":
with nullcontext(sys.stdin) as input_fp:
queries = read_input(input_fp, args.intype)
else:
try:
with Path(args.INPUT).open(encoding="utf8") as input_fp:
queries = read_input(input_fp, args.intype)
except OSError as e:
p.error(f"Cannot open input file {args.INPUT!r}: {e}")
if not queries:
print("No input queries found. Is the input file empty?", file=sys.stderr)
sys.exit(-1)
Expand All @@ -245,23 +268,28 @@ def _main(program_name: str = "zyte-api") -> None:
queries = queries[: args.limit]

logger.info(
f"Loaded {len(queries)} urls from {args.INPUT.name}; shuffled: {args.shuffle}"
f"Loaded {len(queries)} urls from {args.INPUT}; shuffled: {args.shuffle}"
)
logger.info(f"Running Zyte API (connections: {args.n_conn})")

loop = asyncio.get_event_loop()
coro = run(
queries,
out=args.output,
n_conn=args.n_conn,
api_url=args.api_url,
api_key=args.api_key,
eth_key=args.eth_key,
retry_errors=not args.dont_retry_errors,
store_errors=args.store_errors,
)
loop.run_until_complete(coro)
loop.close()
run_kwargs = {
"n_conn": args.n_conn,
"api_url": args.api_url,
"api_key": args.api_key,
"eth_key": args.eth_key,
"retry_errors": not args.dont_retry_errors,
"store_errors": args.store_errors,
"trust_env": args.trust_env,
}
if args.output is None or args.output == "-":
with nullcontext(sys.stdout) as out:
asyncio.run(run(queries, out=out, **run_kwargs))
else:
try:
with Path(args.output).open("w", encoding="utf8") as out:
asyncio.run(run(queries, out=out, **run_kwargs))
except OSError as e:
p.error(f"Cannot open output file {args.output!r}: {e}")


if __name__ == "__main__":
Expand Down
27 changes: 19 additions & 8 deletions zyte_api/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _post_func(
class _AsyncSession:
def __init__(self, client: AsyncZyteAPI, **session_kwargs: Any):
self._client: AsyncZyteAPI = client
session_kwargs.setdefault("trust_env", client.trust_env)
self._session: aiohttp.ClientSession = create_session(
client.n_conn, **session_kwargs
)
Expand Down Expand Up @@ -123,6 +124,7 @@ def __init__(
retrying: AsyncRetrying | None = None,
user_agent: str | None = None,
eth_key: str | None = None,
trust_env: bool = False,
):
if retrying is not None and not isinstance(retrying, AsyncRetrying):
raise ValueError(
Expand All @@ -134,6 +136,7 @@ def __init__(
self.agg_stats = AggStats()
self.retrying = retrying or zyte_api_retrying
self.user_agent = user_agent or USER_AGENT
self.trust_env = trust_env
self._semaphore = asyncio.Semaphore(n_conn)
self._auth: str | _x402Handler
self.auth: AuthInfo
Expand Down Expand Up @@ -190,6 +193,10 @@ async def get(
) -> dict[str, Any]:
"""Asynchronous equivalent to :meth:`ZyteAPI.get`."""
retrying = retrying or self.retrying
owned_session: aiohttp.ClientSession | None = None
if session is None:
owned_session = create_session(self.n_conn, trust_env=self.trust_env)
session = owned_session
post = _post_func(session)

url = self.api_url + endpoint
Expand Down Expand Up @@ -257,14 +264,18 @@ async def request() -> dict[str, Any]:
request = retrying.wraps(request)

try:
# Try to make a request
result = await request()
self.agg_stats.n_success += 1
except Exception:
self.agg_stats.n_fatal_errors += 1
raise

return result
try:
# Try to make a request
result = await request()
self.agg_stats.n_success += 1
except Exception:
self.agg_stats.n_fatal_errors += 1
raise

return result
finally:
if owned_session is not None:
await owned_session.close()

def iter(
self,
Expand Down
6 changes: 6 additions & 0 deletions zyte_api/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ class ZyteAPI:
*user_agent* is the user agent string reported to Zyte API. Defaults to
``python-zyte-api/<VERSION>``.

*trust_env* controls whether :mod:`aiohttp` honors environment-based
network settings (e.g. ``HTTP_PROXY`` and ``HTTPS_PROXY``). Defaults to
``False``.

.. tip:: To change the ``User-Agent`` header sent to a target website, use
:http:`request:customHttpRequestHeaders` instead.
"""
Expand All @@ -117,6 +121,7 @@ def __init__(
retrying: AsyncRetrying | None = None,
user_agent: str | None = None,
eth_key: str | None = None,
trust_env: bool = False,
):
self._async_client = AsyncZyteAPI(
api_key=api_key,
Expand All @@ -125,6 +130,7 @@ def __init__(
retrying=retrying,
user_agent=user_agent,
eth_key=eth_key,
trust_env=trust_env,
)

def get(
Expand Down
Loading