Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
474 changes: 474 additions & 0 deletions functions-python/helpers/tests/test_feed_http.py

Large diffs are not rendered by default.

319 changes: 287 additions & 32 deletions functions-python/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import logging
import os
import ssl
from datetime import date, datetime
import time
import urllib3.exceptions
from datetime import date, datetime, timezone
from logging import Logger
from typing import Optional

Expand Down Expand Up @@ -121,6 +123,283 @@ def get_hash_from_file(file_path, hash_algorithm="sha256", chunk_size=8192):
return hash_object.hexdigest()


def create_feed_ssl_context(trusted_certs: bool = False):
    """
    Build the SSL context used for GTFS feed HTTP requests.

    The context comes from urllib3's context factory, has the default
    certificates loaded, and has the legacy-server-connect option bit
    (0x4, ssl.OP_LEGACY_SERVER_CONNECT) switched on so that servers with DH
    key issues can still be reached.

    Args:
        trusted_certs: When True, hostname checking is switched off and the
            verify mode is set to ssl.CERT_NONE, i.e. the server certificate
            is NOT validated. Use only for known problematic feeds.

    Returns:
        The configured SSL context.
    """
    legacy_server_connect = 0x4  # ssl.OP_LEGACY_SERVER_CONNECT

    context = create_urllib3_context()
    context.load_default_certs()
    # Setting this option is the only way to make urllib3 work with legacy servers.
    # More information: https://github.com/urllib3/urllib3/issues/2653#issuecomment-1165418616
    context.options |= legacy_server_connect

    if not trusted_certs:
        return context

    # check_hostname must be cleared before verify_mode may become CERT_NONE.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context


def build_feed_request_params(
    url: str,
    feed_id: Optional[str] = None,
    authentication_type=0,
    api_key_parameter_name: Optional[str] = None,
    credentials: Optional[str] = None,
) -> tuple:
    """
    Build HTTP request headers and resolve the final URL for a feed request.

    Handles:
    - Per-feed header overrides via config DB (feed_download/http_headers)
    - Default mobile browser User-Agent + Referer fallback when no override exists
    - Auth type 1: API key appended as a URL query parameter
    - Auth type 2: API key injected as a request header

    Args:
        url: The feed URL as configured by the producer.
        feed_id: Feed identifier forwarded to the config lookup.
        authentication_type: 0/None (no auth), 1 (key in URL) or 2 (key in
            header). Converted with int(), so numeric strings are accepted.
        api_key_parameter_name: Query-parameter or header name carrying the key.
        credentials: The API key value.

    Returns:
        (headers, resolved_url) ready to pass to any HTTP method. The headers
        dict is always a fresh object owned by the caller.
    """
    from shared.common.config_reader import get_config_value

    configured_headers = get_config_value(
        namespace="feed_download", key="http_headers", feed_id=feed_id
    )
    if configured_headers is None:
        # The Referer is taken from the URL before any API key is appended,
        # so it never carries the credential.
        headers = {
            "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/126.0.0.0 Mobile Safari/537.36",
            "Referer": url,
        }
    else:
        # Work on a copy: the auth-type-2 branch below adds the credential to
        # the headers, and writing it into the object handed out by the config
        # reader could leak the key into later lookups if that object is shared.
        headers = dict(configured_headers)

    auth_type = int(authentication_type) if authentication_type is not None else 0
    has_api_key = bool(api_key_parameter_name and credentials)

    # authentication_type == 1 -> the credentials are passed in the url
    # Careful, some URLs may already contain a query string
    # (e.g. http://api.511.org/transit/datafeeds?operator_id=CE)
    if auth_type == 1 and has_api_key:
        separator = "&" if "?" in url else "?"
        url += f"{separator}{api_key_parameter_name}={credentials}"

    # authentication_type == 2 -> the credentials are passed in the header
    if auth_type == 2 and has_api_key:
        headers[api_key_parameter_name] = credentials

    return headers, url


# Normalised Content-Type values that are treated as a definite ZIP payload.
_ZIP_CONTENT_TYPES = frozenset(
    (
        "application/zip",
        "application/x-zip",
        "application/x-zip-compressed",
        "application/gtfs+zip",
    )
)
# ZIP local file header signature: the bytes "PK" followed by 0x03 0x04.
_ZIP_MAGIC = b"PK\x03\x04"


def _parse_content_type(raw: Optional[str]) -> Optional[str]:
    """Extract the bare MIME type from a raw Content-Type header value.

    Everything from the first ';' onwards (parameters such as charset) is
    dropped, surrounding whitespace is stripped and the result is lower-cased.
    Returns None when the header is missing or empty.
    """
    if raw:
        media_type, _, _parameters = raw.partition(";")
        return media_type.strip().lower()
    return None


def _is_zip_from_content_type(content_type: Optional[str]) -> Optional[bool]:
    """Decide whether a normalised Content-Type denotes a ZIP archive.

    Returns:
        True when the type is one of _ZIP_CONTENT_TYPES, False for any other
        concrete type, and None when no verdict is possible: either the type
        is missing or it is the generic application/octet-stream, in which
        case the caller should verify via magic bytes.
    """
    undetermined = content_type is None or content_type == "application/octet-stream"
    if undetermined:
        return None
    return content_type in _ZIP_CONTENT_TYPES


def _log_redirects(stable_id: str, producer_url: str, redirect_urls: list) -> None:
    """Emit one info log listing the redirect hops of a feed request.

    Duplicate URLs are collapsed (first occurrence wins, order preserved) and
    hops equal to producer_url are left out. Nothing is logged when no hop
    remains.
    """
    already_seen = set()
    hops = []
    for location in redirect_urls:
        if location in already_seen:
            continue
        already_seen.add(location)
        if location != producer_url:
            hops.append(location)

    if not hops:
        return
    logging.info("Feed %s (%s) redirected through: %s", stable_id, producer_url, hops)


def _execute_http_request(
    method: str,
    url: str,
    headers: Optional[dict],
    timeout_seconds: int,
    read_bytes: int = 0,
) -> tuple:
    """Execute a single HTTP request and return a result tuple.

    Args:
        method: HTTP method name, e.g. "HEAD" or "GET".
        url: Fully resolved request URL.
        headers: Request headers, or None.
        timeout_seconds: Applied to both the connect and the read timeout.
        read_bytes: When 0 the response body is preloaded by urllib3 and
            first_bytes is b''. Otherwise the body is streamed and only up to
            this many leading bytes are read.

    Returns:
        (status_code, latency_ms, resp_headers, first_bytes, error_type, error_message, redirect_urls)
        redirect_urls is a list of URLs the request was redirected through.
        On network/timeout errors, status_code/latency_ms/resp_headers are None,
        first_bytes is b'', and redirect_urls is []. error_type is then
        "Timeout", "ConnectionError", or the urllib3 exception class name.
    """
    preload = read_bytes == 0
    try:
        ctx = create_feed_ssl_context()
        retries = urllib3.Retry(redirect=10, connect=1, read=0, status=0)
        with urllib3.PoolManager(ssl_context=ctx) as http:
            start = time.monotonic()
            r = http.request(
                method,
                url,
                headers=headers,
                retries=retries,
                preload_content=preload,
                timeout=urllib3.Timeout(connect=timeout_seconds, read=timeout_seconds),
            )
            # Latency covers the request up to the response headers; when the
            # body is streamed, the partial read below is not included.
            latency_ms = int((time.monotonic() - start) * 1000)
            status_code = r.status
            resp_headers = r.headers
            if preload:
                first_bytes = b""
            else:
                try:
                    first_bytes = r.read(read_bytes)
                finally:
                    # Hand the connection back even when the partial read fails.
                    r.release_conn()
            redirect_urls = [
                h.redirect_location
                for h in (r.retries.history or [])
                if h.redirect_location
            ]
            return (
                status_code,
                latency_ms,
                resp_headers,
                first_bytes,
                None,
                None,
                redirect_urls,
            )
    except urllib3.exceptions.MaxRetryError as exc:
        # Once the retry budget is spent urllib3 wraps the underlying failure,
        # timeouts included, in MaxRetryError, so the real cause is exc.reason.
        # NewConnectionError is excluded explicitly because urllib3 1.x derives
        # it from ConnectTimeoutError although it signals a refused or failed
        # connection rather than a timeout.
        reason = getattr(exc, "reason", None)
        timed_out = isinstance(
            reason, urllib3.exceptions.TimeoutError
        ) and not isinstance(reason, urllib3.exceptions.NewConnectionError)
        error_type = "Timeout" if timed_out else "ConnectionError"
        return None, None, None, b"", error_type, str(exc), []
    except urllib3.exceptions.TimeoutError as exc:
        return None, None, None, b"", "Timeout", str(exc), []
    except urllib3.exceptions.HTTPError as exc:
        return None, None, None, b"", type(exc).__name__, str(exc), []


def perform_request(
    feed_id: str,
    stable_id: str,
    producer_url: str,
    authentication_type: str,
    api_key_parameter_name: Optional[str],
    credentials: Optional[str],
    timeout_seconds: int,
    fallback_to_get: bool = False,
):
    """Probe a feed URL and build its availability-check record.

    A HEAD request is sent first. It counts as successful when a response was
    received with a status code below 400. When it is not successful and
    fallback_to_get=True, a lightweight GET is issued that reads only 4 bytes
    so the ZIP magic signature (PK\\x03\\x04) can be checked; if no bytes were
    read, is_zip falls back to the Content-Type header. The stored
    request_type ("http_head" or "http_get") reflects which method produced
    the final result.

    Note: request_url is always the original producer_url (never the
    credential-bearing resolved URL) to avoid persisting secrets.

    Returns:
        A GtfsFeedAvailabilityCheck populated with the outcome of the last
        request that was made.
    """
    from shared.database_gen.sqlacodegen_models import GtfsFeedAvailabilityCheck

    def _content_type_of(response_headers):
        # Normalised Content-Type of a response, or None without headers.
        raw_value = response_headers.get("Content-Type") if response_headers else None
        return _parse_content_type(raw_value)

    checked_at = datetime.now(timezone.utc)
    request_headers, resolved_url = build_feed_request_params(
        producer_url,
        feed_id=feed_id,
        authentication_type=authentication_type,
        api_key_parameter_name=api_key_parameter_name,
        credentials=credentials,
    )

    # First attempt: HEAD, no body is read.
    head_outcome = _execute_http_request(
        "HEAD", resolved_url, request_headers, timeout_seconds
    )
    (
        status_code,
        latency_ms,
        response_headers,
        _,
        error_type,
        error_message,
        redirect_urls,
    ) = head_outcome
    request_type = "http_head"
    success = status_code is not None and status_code < 400
    content_type = _content_type_of(response_headers)
    is_zip = _is_zip_from_content_type(content_type)

    if error_type:
        logging.warning(
            "HEAD %s for feed %s (%s): %s",
            error_type,
            stable_id,
            producer_url,
            error_message,
        )
    _log_redirects(stable_id, producer_url, redirect_urls)

    if fallback_to_get and not success:
        logging.info(
            "HEAD failed for feed %s (%s) [status=%s error=%s], trying GET fallback",
            stable_id,
            producer_url,
            status_code,
            error_type,
        )
        # Second attempt: GET, reading only the 4 leading bytes of the body.
        get_outcome = _execute_http_request(
            "GET", resolved_url, request_headers, timeout_seconds, read_bytes=4
        )
        (
            status_code,
            latency_ms,
            response_headers,
            leading_bytes,
            error_type,
            error_message,
            redirect_urls,
        ) = get_outcome
        request_type = "http_get"
        success = status_code is not None and status_code < 400
        content_type = _content_type_of(response_headers)
        if leading_bytes:
            # Actual payload bytes take precedence over the declared type.
            is_zip = leading_bytes == _ZIP_MAGIC
        else:
            is_zip = _is_zip_from_content_type(content_type)
        if error_type:
            logging.warning(
                "GET fallback %s for feed %s (%s): %s",
                error_type,
                stable_id,
                producer_url,
                error_message,
            )
        _log_redirects(stable_id, producer_url, redirect_urls)

    return GtfsFeedAvailabilityCheck(
        feed_id=feed_id,
        checked_at=checked_at,
        request_url=producer_url,
        request_type=request_type,
        status_code=status_code,
        latency_ms=latency_ms,
        error_message=error_message,
        error_type=error_type,
        success=success,
        content_type=content_type,
        is_zip=is_zip,
    )


def download_and_get_hash(
url,
file_path,
Expand All @@ -133,46 +412,22 @@ def download_and_get_hash(
logger=None,
trusted_certs=False, # If True, disables SSL verification
):
from shared.common.config_reader import get_config_value

"""
Downloads the content of a URL and stores it in a file and returns the hash of the file
"""
logger = logger or logging.getLogger(__name__)
try:
hash_object = hashlib.new(hash_algorithm)

# This the only way to make urllib3 work with legacy servers
# More information: https://github.com/urllib3/urllib3/issues/2653#issuecomment-1165418616
ctx = create_urllib3_context()
ctx.load_default_certs()
ctx.options |= 0x4 # ssl.OP_LEGACY_SERVER_CONNECT
ctx = create_feed_ssl_context(trusted_certs=trusted_certs)

headers = get_config_value(
namespace="feed_download", key="http_headers", feed_id=feed_id
headers, url = build_feed_request_params(
url,
feed_id=feed_id,
authentication_type=authentication_type,
api_key_parameter_name=api_key_parameter_name,
credentials=credentials,
)
if headers is None:
headers = {
"User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/126.0.0.0 Mobile Safari/537.36",
"Referer": url,
}

# authentication_type == 1 -> the credentials are passed in the url
# Careful, some URLs may already contain a query string
# (e.g. http://api.511.org/transit/datafeeds?operator_id=CE)
if authentication_type == 1 and api_key_parameter_name and credentials:
separator = "&" if "?" in url else "?"
url += f"{separator}{api_key_parameter_name}={credentials}"

# authentication_type == 2 -> the credentials are passed in the header
if authentication_type == 2 and api_key_parameter_name and credentials:
headers[api_key_parameter_name] = credentials

if trusted_certs:
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE

with urllib3.PoolManager(ssl_context=ctx) as http:
with http.request(
Expand Down
Loading
Loading