Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 45 additions & 285 deletions mapillary_tools/api_v4.py

Large diffs are not rendered by default.

47 changes: 23 additions & 24 deletions mapillary_tools/authenticate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import requests

from . import api_v4, config, constants, exceptions
from . import api_v4, config, constants, exceptions, http


LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,11 +77,11 @@ def authenticate(
# TODO: print more user information
if profile_name in all_user_items:
LOG.info(
'Profile "%s" updated: %s', profile_name, api_v4._sanitize(user_items)
'Profile "%s" updated: %s', profile_name, http._sanitize(user_items)
)
else:
LOG.info(
'Profile "%s" created: %s', profile_name, api_v4._sanitize(user_items)
'Profile "%s" created: %s', profile_name, http._sanitize(user_items)
)


Expand Down Expand Up @@ -134,9 +134,8 @@ def fetch_user_items(
)

if organization_key is not None:
resp = api_v4.fetch_organization(
user_items["user_upload_token"], organization_key
)
with api_v4.create_user_session(user_items["user_upload_token"]) as session:
resp = api_v4.fetch_organization(session, organization_key)
data = api_v4.jsonify_response(resp)
LOG.info(
f"Uploading to organization: {data.get('name')} (ID: {data.get('id')})"
Expand Down Expand Up @@ -173,16 +172,15 @@ def _verify_user_auth(user_items: config.UserItem) -> config.UserItem:
if constants._AUTH_VERIFICATION_DISABLED:
return user_items

try:
resp = api_v4.fetch_user_or_me(
user_access_token=user_items["user_upload_token"]
)
except requests.HTTPError as ex:
if api_v4.is_auth_error(ex.response):
message = api_v4.extract_auth_error_message(ex.response)
raise exceptions.MapillaryUploadUnauthorizedError(message)
else:
raise ex
with api_v4.create_user_session(user_items["user_upload_token"]) as session:
try:
resp = api_v4.fetch_user_or_me(session)
except requests.HTTPError as ex:
if api_v4.is_auth_error(ex.response):
message = api_v4.extract_auth_error_message(ex.response)
raise exceptions.MapillaryUploadUnauthorizedError(message)
else:
raise ex

data = api_v4.jsonify_response(resp)

Expand Down Expand Up @@ -276,16 +274,17 @@ def _prompt_login(
if user_password:
break

try:
resp = api_v4.get_upload_token(user_email, user_password)
except requests.HTTPError as ex:
if not _enabled:
raise ex
with api_v4.create_client_session() as session:
try:
resp = api_v4.get_upload_token(session, user_email, user_password)
except requests.HTTPError as ex:
if not _enabled:
raise ex

if _is_login_retryable(ex):
return _prompt_login()
if _is_login_retryable(ex):
return _prompt_login()

raise ex
raise ex

data = api_v4.jsonify_response(resp)

Expand Down
7 changes: 7 additions & 0 deletions mapillary_tools/commands/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def add_common_upload_options(group):
default=None,
required=False,
)
group.add_argument(
"--num_upload_workers",
help="Number of concurrent upload workers for uploading images. [default: %(default)s]",
default=constants.MAX_IMAGE_UPLOAD_WORKERS,
type=int,
required=False,
)
group.add_argument(
"--reupload",
help="Re-upload data that has already been uploaded.",
Expand Down
8 changes: 5 additions & 3 deletions mapillary_tools/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,18 @@ def _parse_scaled_integers(
# The minimal upload speed is used to calculate the read timeout to avoid upload hanging:
# timeout = upload_size / MIN_UPLOAD_SPEED
MIN_UPLOAD_SPEED: int | None = _parse_filesize(
os.getenv(_ENV_PREFIX + "MIN_UPLOAD_SPEED", "50K") # 50 KiB/s
os.getenv(_ENV_PREFIX + "MIN_UPLOAD_SPEED", "50K") # 50 KiB/s
)
# Maximum number of parallel workers for uploading images within a single sequence.
# NOTE: Sequences themselves are uploaded sequentially, not in parallel.
MAX_IMAGE_UPLOAD_WORKERS: int = int(
os.getenv(_ENV_PREFIX + "MAX_IMAGE_UPLOAD_WORKERS", 64)
os.getenv(_ENV_PREFIX + "MAX_IMAGE_UPLOAD_WORKERS", 4)
)
# The chunk size in MB (see chunked transfer encoding https://en.wikipedia.org/wiki/Chunked_transfer_encoding)
# for uploading data to MLY upload service.
# Changing this size does not change the number of requests nor affect upload performance,
# but it affects the responsiveness of the upload progress bar
UPLOAD_CHUNK_SIZE_MB: float = float(os.getenv(_ENV_PREFIX + "UPLOAD_CHUNK_SIZE_MB", 1))
UPLOAD_CHUNK_SIZE_MB: float = float(os.getenv(_ENV_PREFIX + "UPLOAD_CHUNK_SIZE_MB", 2))
MAX_UPLOAD_RETRIES: int = int(os.getenv(_ENV_PREFIX + "MAX_UPLOAD_RETRIES", 200))
MAPILLARY__ENABLE_UPLOAD_HISTORY_FOR_DRY_RUN: bool = _yes_or_no(
os.getenv("MAPILLARY__ENABLE_UPLOAD_HISTORY_FOR_DRY_RUN", "NO")
Expand Down
211 changes: 211 additions & 0 deletions mapillary_tools/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from __future__ import annotations

import logging

import ssl
import sys
import typing as T
from json import dumps

if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override

import requests
from requests.adapters import HTTPAdapter


LOG = logging.getLogger(__name__)


class HTTPSystemCertsAdapter(HTTPAdapter):
    """
    Transport adapter that verifies TLS against the operating system's
    certificate store rather than the bundle shipped with certifi.

    Based on https://pypi.org/project/pip-system-certs/, which applies the
    same idea system-wide.
    """

    def init_poolmanager(self, *args, **kwargs):
        # Build a default SSL context and populate it from the OS trust store,
        # then hand it to urllib3's pool manager.
        context = ssl.create_default_context()
        context.load_default_certs()
        kwargs["ssl_context"] = context

        super().init_poolmanager(*args, **kwargs)

    def cert_verify(self, *args, **kwargs):
        super().cert_verify(*args, **kwargs)

        # requests points ca_certs at certifi's bundle by default; clearing it
        # forces urllib3 to fall back on the ssl_context configured above.
        # The connection may arrive positionally or as a keyword argument.
        conn = kwargs["conn"] if "conn" in kwargs else args[0]
        conn.ca_certs = None


class Session(requests.Session):
    """requests.Session with redacted DEBUG logging of requests/responses and
    a one-time, process-wide fallback to the system certificate store when
    TLS certificate verification fails.
    """

    # NOTE: This is a global flag that affects all Session instances
    USE_SYSTEM_CERTS: T.ClassVar[bool] = False
    # Instance variables
    disable_logging_request: bool = False
    disable_logging_response: bool = False
    # Avoid mounting twice
    _mounted: bool = False

    @override
    def request(self, method: str | bytes, url: str | bytes, *args, **kwargs):
        """Send the request; on the first certificate-verification failure,
        switch all Sessions to system certificates and retry once."""
        self._log_debug_request(method, url, *args, **kwargs)

        if Session.USE_SYSTEM_CERTS:
            if not self._mounted:
                self.mount("https://", HTTPSystemCertsAdapter())
                self._mounted = True
            resp = super().request(method, url, *args, **kwargs)
        else:
            try:
                resp = super().request(method, url, *args, **kwargs)
            except requests.exceptions.SSLError as ex:
                # Only certificate-verification failures trigger the fallback;
                # every other SSL error propagates unchanged.
                if "SSLCertVerificationError" not in str(ex):
                    raise ex
                Session.USE_SYSTEM_CERTS = True
                # Example: HTTPSConnectionPool(host='graph.mapillary.com', port=443): Max retries exceeded with url: /login (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1018)')))
                LOG.warning(
                    "SSL error occurred, falling back to system SSL certificates: %s",
                    ex,
                )
                # Re-enter request() so the USE_SYSTEM_CERTS branch above
                # mounts the adapter; that branch does not catch SSLError,
                # so at most one retry happens.
                return self.request(method, url, *args, **kwargs)

        self._log_debug_response(resp)

        return resp

    def _log_debug_request(
        self, method: str | bytes, url: str | bytes, *args, **kwargs
    ):
        """Log a sanitized one-line summary of an outgoing request.

        FIX: accept *args — request() forwards its positional arguments here
        (e.g. session.request("GET", url, params)), which previously raised
        TypeError because this method took keyword arguments only. Positional
        arguments are ignored for logging purposes.
        """
        if self.disable_logging_request:
            return

        # NOTE(review): this skips logging when the ROOT logger is at DEBUG or
        # finer — presumably to avoid duplicating urllib3's own debug output.
        # LOG.debug below then only emits when this module's logger is enabled
        # more verbosely than the root logger; confirm that is the intent.
        if logging.getLogger().getEffectiveLevel() <= logging.DEBUG:
            return

        if isinstance(method, str) and isinstance(url, str):
            msg = f"HTTP {method} {url}"
        else:
            msg = f"HTTP {method!r} {url!r}"

        if Session.USE_SYSTEM_CERTS:
            msg += " (w/sys_certs)"

        json = kwargs.get("json")
        if json is not None:
            # Redact credentials and truncate before logging the JSON body
            t = _truncate(dumps(_sanitize(json)))
            msg += f" JSON={t}"

        params = kwargs.get("params")
        if params is not None:
            msg += f" PARAMS={_sanitize(params)}"

        headers = kwargs.get("headers")
        if headers is not None:
            msg += f" HEADERS={_sanitize(headers)}"

        timeout = kwargs.get("timeout")
        if timeout is not None:
            msg += f" TIMEOUT={timeout}"

        # Keep the log entry on a single line
        msg = msg.replace("\n", "\\n")

        LOG.debug(msg)

    def _log_debug_response(self, resp: requests.Response):
        """Log status, reason, elapsed time, and a truncated body of a response."""
        if self.disable_logging_response:
            return

        # NOTE(review): same root-logger guard as _log_debug_request — see note there
        if logging.getLogger().getEffectiveLevel() <= logging.DEBUG:
            return

        elapsed = resp.elapsed.total_seconds() * 1000  # Convert to milliseconds
        msg = f"HTTP {resp.status_code} {resp.reason} ({elapsed:.0f} ms): {str(_truncate_response_content(resp))}"

        LOG.debug(msg)


def readable_http_error(ex: requests.HTTPError) -> str:
    """Return a one-line, human-readable summary of the error's HTTP response."""
    resp = ex.response
    return readable_http_response(resp)


def readable_http_response(resp: requests.Response) -> str:
    """Summarize a response as 'METHOD url => status reason: truncated body'."""
    request_part = f"{resp.request.method} {resp.url}"
    status_part = f"{resp.status_code} {resp.reason}"
    body_part = str(_truncate_response_content(resp))
    return f"{request_part} => {status_part}: {body_part}"


@T.overload
def _truncate(s: bytes, limit: int = 256) -> bytes | str: ...


@T.overload
def _truncate(s: str, limit: int = 256) -> str: ...


def _truncate(s, limit=256):
if limit < len(s):
if isinstance(s, bytes):
try:
s = s.decode("utf-8")
except UnicodeDecodeError:
pass
remaining = len(s) - limit
if isinstance(s, bytes):
return s[:limit] + f"...({remaining} bytes truncated)".encode("utf-8")
else:
return str(s[:limit]) + f"...({remaining} chars truncated)"
else:
return s


def _sanitize(headers: T.Mapping[T.Any, T.Any]) -> T.Mapping[T.Any, T.Any]:
    """Return a copy of *headers* safe for logging.

    Credential-bearing keys are replaced with "[REDACTED]"; str/bytes values
    are truncated; all other values pass through untouched.
    """
    redacted_keys = {
        "authorization",
        "cookie",
        "x-fb-access-token",
        "access-token",
        "access_token",
        "password",
        "user_upload_token",
    }

    sanitized: dict = {}
    for key, value in headers.items():
        if key.lower() in redacted_keys:
            sanitized[key] = "[REDACTED]"
        elif isinstance(value, (str, bytes)):
            sanitized[key] = T.cast(T.Any, _truncate(value))
        else:
            sanitized[key] = value

    return sanitized


def _truncate_response_content(resp: requests.Response) -> str | bytes:
    """Return a truncated, sanitized, single-line rendering of a response body.

    JSON dict bodies are redacted via _sanitize before truncation; non-dict
    JSON is stringified; non-JSON bodies are truncated raw (or "" if there is
    no content). Newlines are escaped so the result fits one log line.
    """
    try:
        payload = resp.json()
    except requests.JSONDecodeError:
        # Not JSON — fall back to the raw body, if any
        body = _truncate(resp.content) if resp.content is not None else ""
    else:
        if isinstance(payload, dict):
            body = _truncate(dumps(_sanitize(payload)))
        else:
            body = _truncate(str(payload))

    if isinstance(body, bytes):
        return body.replace(b"\n", b"\\n")
    return body.replace("\n", "\\n")
Loading
Loading