Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 87 additions & 57 deletions mapillary_tools/history.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
from __future__ import annotations

import contextlib
import dbm
import json
import logging
import os
import sqlite3
import string
import threading
import time
import typing as T
from functools import wraps
from pathlib import Path

# dbm modules are dynamically imported, so here we explicitly import dbm.sqlite3 to make sure pyinstaller includes it.
# Otherwise you will see: ImportError: no dbm clone found; tried ['dbm.sqlite3', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb']
try:
import dbm.sqlite3 # type: ignore
except ImportError:
pass


from . import constants, types
from . import constants, store, types
from .serializer.description import DescriptionJSONSerializer

JSONDict = T.Dict[str, T.Union[str, int, float, None]]
Expand Down Expand Up @@ -85,103 +78,140 @@ def write_history(
fp.write(json.dumps(history))


def _retry_on_database_lock_error(fn):
"""
Decorator to retry a function if it raises a sqlite3.OperationalError with
"database is locked" in the message.
"""

@wraps(fn)
def wrapper(*args, **kwargs):
while True:
try:
return fn(*args, **kwargs)
except sqlite3.OperationalError as ex:
if "database is locked" in str(ex).lower():
LOG.warning(f"{str(ex)}")
LOG.info("Retrying in 1 second...")
time.sleep(1)
else:
raise ex

return wrapper


class PersistentCache:
_lock: contextlib.nullcontext | threading.Lock
_lock: threading.Lock

def __init__(self, file: str):
# SQLite3 backend supports concurrent access without a lock
if dbm.whichdb(file) == "dbm.sqlite3":
self._lock = contextlib.nullcontext()
else:
self._lock = threading.Lock()
self._file = file
self._lock = threading.Lock()

def get(self, key: str) -> str | None:
if not self._db_existed():
return None

s = time.perf_counter()

with self._lock:
with dbm.open(self._file, flag="c") as db:
value: bytes | None = db.get(key)
with store.KeyValueStore(self._file, flag="r") as db:
try:
raw_payload: bytes | None = db.get(key) # data retrieved from db[key]
except Exception as ex:
if self._table_not_found(ex):
return None
raise ex

if value is None:
if raw_payload is None:
return None

payload = self._decode(value)
data: JSONDict = self._decode(raw_payload) # JSON dict decoded from db[key]

if self._is_expired(payload):
if self._is_expired(data):
return None

file_handle = payload.get("file_handle")
cached_value = data.get("value") # value in the JSON dict decoded from db[key]

LOG.debug(
f"Found file handle for {key} in cache ({(time.perf_counter() - s) * 1000:.0f} ms)"
)

return T.cast(str, file_handle)
return T.cast(str, cached_value)

def set(self, key: str, file_handle: str, expires_in: int = 3600 * 24 * 2) -> None:
@_retry_on_database_lock_error
def set(self, key: str, value: str, expires_in: int = 3600 * 24 * 2) -> None:
s = time.perf_counter()

payload = {
data = {
"expires_at": time.time() + expires_in,
"file_handle": file_handle,
"value": value,
}

value: bytes = json.dumps(payload).encode("utf-8")
payload: bytes = json.dumps(data).encode("utf-8")

with self._lock:
with dbm.open(self._file, flag="c") as db:
db[key] = value
with store.KeyValueStore(self._file, flag="c") as db:
db[key] = payload

LOG.debug(
f"Cached file handle for {key} ({(time.perf_counter() - s) * 1000:.0f} ms)"
)

@_retry_on_database_lock_error
def clear_expired(self) -> list[str]:
s = time.perf_counter()

expired_keys: list[str] = []

with self._lock:
with dbm.open(self._file, flag="c") as db:
if hasattr(db, "items"):
items: T.Iterable[tuple[str | bytes, bytes]] = db.items()
else:
items = ((key, db[key]) for key in db.keys())
s = time.perf_counter()

for key, value in items:
payload = self._decode(value)
if self._is_expired(payload):
with self._lock:
with store.KeyValueStore(self._file, flag="c") as db:
for key, raw_payload in db.items():
data = self._decode(raw_payload)
if self._is_expired(data):
del db[key]
expired_keys.append(T.cast(str, key))

if expired_keys:
LOG.debug(
f"Cleared {len(expired_keys)} expired entries from the cache ({(time.perf_counter() - s) * 1000:.0f} ms)"
)
LOG.debug(
f"Cleared {len(expired_keys)} expired entries from the cache ({(time.perf_counter() - s) * 1000:.0f} ms)"
)

return expired_keys

def keys(self):
with self._lock:
with dbm.open(self._file, flag="c") as db:
return db.keys()
def keys(self) -> list[str]:
if not self._db_existed():
return []

def _is_expired(self, payload: JSONDict) -> bool:
expires_at = payload.get("expires_at")
try:
with store.KeyValueStore(self._file, flag="r") as db:
return [key.decode("utf-8") for key in db.keys()]
except Exception as ex:
if self._table_not_found(ex):
return []
raise ex

def _is_expired(self, data: JSONDict) -> bool:
expires_at = data.get("expires_at")
if isinstance(expires_at, (int, float)):
return expires_at is None or expires_at <= time.time()
return False

def _decode(self, value: bytes) -> JSONDict:
def _decode(self, raw_payload: bytes) -> JSONDict:
try:
payload = json.loads(value.decode("utf-8"))
data = json.loads(raw_payload.decode("utf-8"))
except json.JSONDecodeError as ex:
LOG.warning(f"Failed to decode cache value: {ex}")
return {}

if not isinstance(payload, dict):
LOG.warning(f"Invalid cache value format: {payload}")
if not isinstance(data, dict):
LOG.warning(f"Invalid cache value format: {raw_payload!r}")
return {}

return payload
return data

def _db_existed(self) -> bool:
return os.path.exists(self._file)

def _table_not_found(self, ex: Exception) -> bool:
if isinstance(ex, sqlite3.OperationalError):
if "no such table" in str(ex):
return True
return False
128 changes: 128 additions & 0 deletions mapillary_tools/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
This module provides a persistent key-value store based on SQLite.

This implementation is mostly copied from dbm.sqlite3 in the Python standard library,
but works for Python >= 3.9, whereas dbm.sqlite3 is only available since Python 3.13.

Source: https://github.com/python/cpython/blob/3.13/Lib/dbm/sqlite3.py
"""

import os
import sqlite3
import sys
from collections.abc import MutableMapping
from contextlib import closing, suppress
from pathlib import Path

BUILD_TABLE = """
CREATE TABLE IF NOT EXISTS Dict (
key BLOB UNIQUE NOT NULL,
value BLOB NOT NULL
)
"""
GET_SIZE = "SELECT COUNT (key) FROM Dict"
LOOKUP_KEY = "SELECT value FROM Dict WHERE key = CAST(? AS BLOB)"
STORE_KV = "REPLACE INTO Dict (key, value) VALUES (CAST(? AS BLOB), CAST(? AS BLOB))"
DELETE_KEY = "DELETE FROM Dict WHERE key = CAST(? AS BLOB)"
ITER_KEYS = "SELECT key FROM Dict"


def _normalize_uri(path):
path = Path(path)
uri = path.absolute().as_uri()
while "//" in uri:
uri = uri.replace("//", "/")
return uri


class KeyValueStore(MutableMapping):
    """A persistent mapping of bytes keys to bytes values backed by SQLite.

    Entries are stored as BLOBs in a single-table database file, so they
    survive across processes. Instances are usable as context managers; the
    underlying connection is closed on exit.
    """

    def __init__(self, path, /, *, flag="r", mode=0o666):
        """Open a key-value database and return the object.

        The 'path' parameter is the name of the database file.

        The optional 'flag' parameter can be one of the following:
        'r' (default): open an existing database for read only access
        'w': open an existing database for read/write access
        'c': create a database if it does not exist; open for read/write access
        'n': always create a new, empty database; open for read/write access

        The optional 'mode' parameter is the Unix file access mode of the database;
        only used when creating a new database. Default: 0o666.

        Raises ValueError for any other 'flag' value.
        """
        path = os.fsdecode(path)
        # Translate the dbm-style flag into a SQLite URI open mode.
        if flag == "r":
            flag = "ro"
        elif flag == "w":
            flag = "rw"
        elif flag == "c":
            flag = "rwc"
            # Ensure the file exists so later read-only opens don't fail.
            Path(path).touch(mode=mode, exist_ok=True)
        elif flag == "n":
            flag = "rwc"
            # Always start from a fresh, empty database file.
            Path(path).unlink(missing_ok=True)
            Path(path).touch(mode=mode)
        else:
            raise ValueError(f"Flag must be one of 'r', 'w', 'c', or 'n', not {flag!r}")

        # We use the URI format when opening the database.
        uri = _normalize_uri(path)
        uri = f"{uri}?mode={flag}"

        if sys.version_info >= (3, 12):
            # autocommit=True is the preferred way, but that parameter is only
            # available in Python 3.12 and newer.
            self._cx = sqlite3.connect(uri, autocommit=True, uri=True)
        else:
            self._cx = sqlite3.connect(uri, uri=True)

        # This is an optimization only; it's ok if it fails.
        with suppress(sqlite3.OperationalError):
            self._cx.execute("PRAGMA journal_mode = wal")

        if flag == "rwc":
            self._execute(BUILD_TABLE)

    def _execute(self, *args, **kwargs):
        # Run a statement and return its cursor wrapped in closing(), so
        # callers can use it as a context manager.
        if sys.version_info >= (3, 12):
            return closing(self._cx.execute(*args, **kwargs))
        else:
            # Use a context manager to commit the changes
            with self._cx:
                return closing(self._cx.execute(*args, **kwargs))

    def __len__(self):
        # Number of stored entries.
        with self._execute(GET_SIZE) as cu:
            row = cu.fetchone()
            return row[0]

    def __getitem__(self, key):
        # Raises KeyError if the key is absent, per the mapping protocol.
        with self._execute(LOOKUP_KEY, (key,)) as cu:
            row = cu.fetchone()
            if not row:
                raise KeyError(key)
            return row[0]

    def __setitem__(self, key, value):
        # REPLACE performs an upsert on the UNIQUE key column.
        self._execute(STORE_KV, (key, value))

    def __delitem__(self, key):
        # rowcount == 0 means nothing was deleted, i.e. the key did not exist.
        with self._execute(DELETE_KEY, (key,)) as cu:
            if not cu.rowcount:
                raise KeyError(key)

    def __iter__(self):
        # Yields raw bytes keys straight from the cursor.
        with self._execute(ITER_KEYS) as cu:
            for row in cu:
                yield row[0]

    def close(self):
        """Close the underlying SQLite connection."""
        self._cx.close()

    def keys(self):
        # Materialize into a list (rather than a lazy view) so the result
        # remains valid after the store is closed.
        return list(super().keys())

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
2 changes: 1 addition & 1 deletion mapillary_tools/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,7 +1311,7 @@ def _is_uuid(key: str) -> bool:


def _build_upload_cache_path(upload_options: UploadOptions) -> Path:
# Different python/CLI versions use different cache (dbm) formats.
# Different python/CLI versions use different cache formats.
# Separate them to avoid conflicts
py_version_parts = [str(part) for part in sys.version_info[:3]]
version = f"py_{'_'.join(py_version_parts)}_{VERSION}"
Expand Down
Loading
Loading