24 changes: 21 additions & 3 deletions src/ldlite/__init__.py
@@ -35,6 +35,7 @@
"""

import sys
from datetime import datetime, timezone
from typing import TYPE_CHECKING, NoReturn, cast

import duckdb
@@ -43,7 +44,7 @@
from tqdm import tqdm

from ._csv import to_csv
from ._database import Database, Prefix
from ._database import Database, LoadHistory, Prefix
from ._folio import FolioClient
from ._jsonx import Attr, transform_json
from ._select import select
@@ -150,7 +151,7 @@ def connect_db_postgresql(self, dsn: str) -> psycopg.Connection:
self.dbtype = DBType.POSTGRES
db = psycopg.connect(dsn)
self.db = cast("dbapi.DBAPIConnection", db)
self._db = PostgresDatabase(lambda: psycopg.connect(dsn))
self._db = PostgresDatabase(dsn)

ret_db = psycopg.connect(dsn)
ret_db.rollback()
@@ -296,6 +297,7 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
if self.db is None or self._db is None:
self._check_db()
return []
start = datetime.now(timezone.utc)
prefix = Prefix(table)
if not self._quiet:
print("ldlite: querying: " + path, file=sys.stderr)
@@ -333,13 +335,15 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
),
),
)
download = datetime.now(timezone.utc)
scan = datetime.now(timezone.utc)

self._db.drop_extracted_tables(prefix)
newtables = [table]
newattrs = {}
if json_depth > 0:
autocommit(self.db, self.dbtype, False)
jsontables, jsonattrs = transform_json(
(jsontables, jsonattrs, scan) = transform_json(
self.db,
self.dbtype,
table,
@@ -357,6 +361,7 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
self._db.drop_raw_table(prefix)

finally:
transformed = datetime.now(timezone.utc)
autocommit(self.db, self.dbtype, True)
# Create indexes on id columns (for postgres)
if self.dbtype == DBType.POSTGRES:
@@ -398,6 +403,19 @@ def close(self) -> None: ...
cur.close()
pbar.update(1)
pbar.close()
index = datetime.now(timezone.utc)
self._db.record_history(
LoadHistory(
prefix,
query if query and isinstance(query, str) else None,
start,
download,
scan,
transformed,
index,
processed,
),
)
# Return table names
if not self._quiet:
print("ldlite: created tables: " + ", ".join(newtables), file=sys.stderr)
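The new checkpoints in query() are taken after each stage completes (download, JSON scan, transformation, index creation), so the differences between consecutive timestamps give per-stage durations. A minimal sketch of that arithmetic against the LoadHistory record defined below; the helper itself is illustrative and not part of this change:

    from datetime import timedelta

    from ldlite._database import LoadHistory

    def stage_durations(history: LoadHistory) -> dict[str, timedelta]:
        # Each timestamp marks the end of its stage, so subtracting the
        # previous checkpoint yields that stage's elapsed time.
        return {
            "download": history.download - history.start,
            "scan": history.scan - history.download,
            "transform": history.transform - history.scan,
            "index": history.index - history.transform,
        }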
75 changes: 71 additions & 4 deletions src/ldlite/_database/__init__.py
@@ -1,6 +1,9 @@
import datetime
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterator, Sequence
from contextlib import closing
from dataclasses import dataclass
from datetime import timezone
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast

from psycopg import sql
@@ -27,22 +30,38 @@ def __init__(self, prefix: str):
def schema_name(self) -> sql.Identifier | None:
return None if self._schema is None else sql.Identifier(self._schema)

def identifier(self, table: str) -> sql.Identifier:
def _identifier(self, table: str) -> sql.Identifier:
if self._schema is None:
return sql.Identifier(table)
return sql.Identifier(self._schema, table)

@property
def load_history_key(self) -> str:
return (self._schema or "public") + "." + self._prefix

@property
def raw_table_name(self) -> sql.Identifier:
return self.identifier(self._prefix)
return self._identifier(self._prefix)

@property
def catalog_table_name(self) -> sql.Identifier:
return self.identifier(f"{self._prefix}__tcatalog")
return self._identifier(f"{self._prefix}__tcatalog")

@property
def legacy_jtable(self) -> sql.Identifier:
return self.identifier(f"{self._prefix}_jtable")
return self._identifier(f"{self._prefix}_jtable")


@dataclass(frozen=True)
class LoadHistory:
table_name: Prefix
query: str | None
start: datetime.datetime
download: datetime.datetime
scan: datetime.datetime
transform: datetime.datetime
index: datetime.datetime
total: int


class Database(ABC):
@@ -58,13 +77,30 @@ def drop_extracted_tables(self, prefix: Prefix) -> None: ...
@abstractmethod
def ingest_records(self, prefix: Prefix, records: Iterator[bytes]) -> int: ...

@abstractmethod
def record_history(self, history: LoadHistory) -> None: ...


DB = TypeVar("DB", bound="duckdb.DuckDBPyConnection | psycopg.Connection")


class TypedDatabase(Database, Generic[DB]):
def __init__(self, conn_factory: Callable[[], DB]):
self._conn_factory = conn_factory
with closing(self._conn_factory()) as conn, conn.cursor() as cur:
cur.execute('CREATE SCHEMA IF NOT EXISTS "ldlite_system";')
cur.execute("""
CREATE TABLE IF NOT EXISTS "ldlite_system"."load_history" (
"table_name" TEXT UNIQUE
,"query" TEXT
,"start_utc" TIMESTAMP
,"download_complete_utc" TIMESTAMP
,"scan_complete_utc" TIMESTAMP
,"transformation_complete_utc" TIMESTAMP
,"index_complete_utc" TIMESTAMP
,"row_count" INTEGER
);""")
conn.commit()

@abstractmethod
def _rollback(self, conn: DB) -> None: ...
@@ -76,6 +112,10 @@ def drop_prefix(
with closing(self._conn_factory()) as conn:
self._drop_extracted_tables(conn, prefix)
self._drop_raw_table(conn, prefix)
conn.execute(
'DELETE FROM "ldlite_system"."load_history" WHERE "table_name" = $1',
(prefix.load_history_key,),
)
conn.commit()

def drop_raw_table(
@@ -179,3 +219,30 @@ def _prepare_raw_table(
table=prefix.raw_table_name,
).as_string(),
)

def record_history(self, history: LoadHistory) -> None:
with closing(self._conn_factory()) as conn, conn.cursor() as cur:
cur.execute(
"""
INSERT INTO "ldlite_system"."load_history" VALUES($1,$2,$3,$4,$5,$6,$7,$8)
ON CONFLICT ("table_name") DO UPDATE SET
"query" = EXCLUDED."query"
,"start_utc" = EXCLUDED."start_utc"
,"download_complete_utc" = EXCLUDED."download_complete_utc"
,"scan_complete_utc" = EXCLUDED."scan_complete_utc"
,"transformation_complete_utc" = EXCLUDED."transformation_complete_utc"
,"index_complete_utc" = EXCLUDED."index_complete_utc"
,"row_count" = EXCLUDED."row_count"
""",
(
history.table_name.load_history_key,
history.query,
history.start.astimezone(timezone.utc),
history.download.astimezone(timezone.utc),
history.scan.astimezone(timezone.utc),
history.transform.astimezone(timezone.utc),
history.index.astimezone(timezone.utc),
history.total,
),
)
conn.commit()
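Each load upserts one row per prefix into ldlite_system.load_history, keyed on table_name, so the table always reflects the most recent load of that prefix. A rough sketch of inspecting it from DuckDB; the database file name is illustrative:

    import duckdb

    con = duckdb.connect("ldlite.db")
    rows = con.execute(
        """
        SELECT "table_name", "query", "row_count",
               "index_complete_utc" - "start_utc" AS total_elapsed
        FROM "ldlite_system"."load_history"
        ORDER BY "start_utc"
        """
    ).fetchall()
    for row in rows:
        print(row)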
5 changes: 5 additions & 0 deletions src/ldlite/_database/postgres.py
@@ -8,6 +8,11 @@


class PostgresDatabase(TypedDatabase[psycopg.Connection]):
def __init__(self, dsn: str):
# RawCursor lets us use $1, $2, etc. so the
# same SQL works for both duckdb and postgres
super().__init__(lambda: psycopg.connect(dsn, cursor_factory=psycopg.RawCursor))

def _rollback(self, conn: psycopg.Connection) -> None:
conn.rollback()

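psycopg's default cursor expects %s placeholders, while RawCursor passes PostgreSQL's native $1, $2 parameters straight through, matching DuckDB's numbered-parameter syntax. A small illustration of the shared placeholder style; the connection string is a placeholder, not part of this change:

    import duckdb
    import psycopg

    SQL = "SELECT $1::INTEGER + $2::INTEGER"

    # DuckDB accepts numbered $n parameters natively.
    print(duckdb.execute(SQL, (1, 2)).fetchone())

    # With the default cursor this statement would fail to bind; RawCursor
    # forwards the $n placeholders to the server unchanged.
    with psycopg.connect("dbname=ldlite", cursor_factory=psycopg.RawCursor) as conn:
        print(conn.execute(SQL, (1, 2)).fetchone())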
10 changes: 6 additions & 4 deletions src/ldlite/_jsonx.py
@@ -1,5 +1,6 @@
import json
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Literal, Union

import duckdb
@@ -407,7 +408,7 @@ def transform_json( # noqa: C901, PLR0912, PLR0913, PLR0915
total: int,
quiet: bool,
max_depth: int,
) -> tuple[list[str], dict[str, dict[str, Attr]]]:
) -> tuple[list[str], dict[str, dict[str, Attr]], datetime]:
# Scan all fields for JSON data
# First get a list of the string attributes
str_attrs: list[str] = []
@@ -420,7 +421,7 @@ def transform_json( # noqa: C901, PLR0912, PLR0913, PLR0915
cur.close()
# Scan data for JSON objects
if len(str_attrs) == 0:
return [], {}
return [], {}, datetime.now(timezone.utc)
json_attrs: list[str] = []
json_attrs_set: set[str] = set()
newattrs: dict[str, dict[str, Attr]] = {}
@@ -511,14 +512,15 @@ def transform_json( # noqa: C901, PLR0912, PLR0913, PLR0915
finally:
cur.close()
db.commit()
scan = datetime.now(timezone.utc)
# Set all row IDs to 1
row_ids = {}
for t in newattrs:
row_ids[t] = 1
# Run transformation
# Select only JSON columns
if len(json_attrs) == 0:
return [], {}
return [], {}, scan
cur = server_cursor(db, dbtype)
try:
cur.execute(
@@ -608,4 +610,4 @@ def transform_json( # noqa: C901, PLR0912, PLR0913, PLR0915
finally:
cur.close()
db.commit()
return sorted([*list(newattrs.keys()), tcatalog]), newattrs
return sorted([*list(newattrs.keys()), tcatalog]), newattrs, scan
31 changes: 19 additions & 12 deletions tests/test_cases/base.py
@@ -1,7 +1,7 @@
import json
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import MagicMock
from uuid import uuid4

@@ -11,7 +11,7 @@

@dataclass(frozen=True)
class EndToEndTestCase:
values: dict[str, list[dict[str, Any]]]
values: dict[str, list[dict[str, Any]] | list[list[dict[str, Any]]]]

@cached_property
def db(self) -> str:
@@ -33,20 +33,27 @@ def patch_request_get(
httpx_post_mock.return_value.cookies.__getitem__.return_value = "token"

side_effects = []
for values in self.values.values():
key = next(iter(values[0].keys()))
for vsource in self.values.values():
list_values = (
[cast("list[dict[str, Any]]", vsource)]
if isinstance(vsource[0], dict)
else cast("list[list[dict[str, Any]]]", vsource)
)

key = next(iter(list_values[0][0].keys()))
total_mock = MagicMock()
total_mock.text = f'{{"{key}": [{{"id": ""}}], "totalRecords": 100000}}'

value_mocks = []
for v in values:
value_mock = MagicMock()
value_mock.text = json.dumps(v)
value_mocks.append(value_mock)
for values in list_values:
value_mocks = []
for v in values:
value_mock = MagicMock()
value_mock.text = json.dumps(v)
value_mocks.append(value_mock)

end_mock = MagicMock()
end_mock.text = f'{{"{key}": [] }}'
end_mock = MagicMock()
end_mock.text = f'{{"{key}": [] }}'

side_effects.extend([total_mock, *value_mocks, end_mock])
side_effects.extend([total_mock, *value_mocks, end_mock])

client_get_mock.side_effect = side_effects
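EndToEndTestCase.values now accepts either a single load (a flat list of response pages) or several sequential loads of the same prefix (a list of such lists); the isinstance check above distinguishes the two. A quick sketch of the two accepted shapes, with illustrative data:

    single_load = {
        "prefix": [  # one load: a list of response pages
            {"purchaseOrders": [{"id": "1"}]},
        ],
    }

    two_loads = {
        "prefix": [  # two loads: each entry is itself a list of pages
            [{"purchaseOrders": [{"id": "1"}]}],
            [
                {"purchaseOrders": [{"id": "1"}]},
                {"purchaseOrders": [{"id": "2"}]},
            ],
        ],
    }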
74 changes: 74 additions & 0 deletions tests/test_cases/load_history_cases.py
@@ -0,0 +1,74 @@
from dataclasses import dataclass

from pytest_cases import parametrize

from .base import EndToEndTestCase


@dataclass(frozen=True)
class LoadHistoryCase(EndToEndTestCase):
queries: dict[str, list[str | None | dict[str, str]]]
expected_loads: dict[str, tuple[str | None, int]]


class LoadHistoryTestCases:
@parametrize(query=[None, "poline.id=*A"])
def case_one_load(self, query: str | None) -> LoadHistoryCase:
return LoadHistoryCase(
values={
"prefix": [
{
"purchaseOrders": [
{
"id": "b096504a-3d54-4664-9bf5-1b872466fd66",
"value": "value",
},
{
"id": "b096504a-9999-4664-9bf5-1b872466fd66",
"value": "value-2",
},
],
},
],
},
queries={"prefix": [query]},
expected_loads={
"public.prefix": (query, 2),
},
)

def case_two_loads(self) -> LoadHistoryCase:
return LoadHistoryCase(
values={
"prefix": [
[
{
"purchaseOrders": [
{
"id": "b096504a-3d54-4664-9bf5-1b872466fd66",
"value": "value",
},
],
},
],
[
{
"purchaseOrders": [
{
"id": "b096504a-3d54-4664-9bf5-1b872466fd66",
"value": "value",
},
{
"id": "b096504a-9999-4664-9bf5-1b872466fd66",
"value": "value-2",
},
],
},
],
],
},
queries={"prefix": [None, "a query"]},
expected_loads={
"public.prefix": ("a query", 2),
},
)