97 changes: 39 additions & 58 deletions src/ldlite/__init__.py
@@ -35,7 +35,6 @@
"""

import sys
from itertools import count
from typing import TYPE_CHECKING, NoReturn, cast

import duckdb
@@ -44,19 +43,20 @@
from tqdm import tqdm

from ._csv import to_csv
from ._database import Prefix
from ._database import Database, Prefix
from ._folio import FolioClient
from ._jsonx import Attr, transform_json
from ._select import select
from ._sqlx import (
DBType,
DBTypeDatabase,
as_postgres,
autocommit,
sqlid,
)

if TYPE_CHECKING:
from collections.abc import Iterator

from _typeshed import dbapi
from httpx_folio.query import QueryType

Expand All @@ -77,7 +77,7 @@ def __init__(self) -> None:
self._quiet = False
self.dbtype: DBType = DBType.UNDEFINED
self.db: dbapi.DBAPIConnection | None = None
self._db: DBTypeDatabase | None = None
self._db: Database | None = None
self._folio: FolioClient | None = None
self.page_size = 1000
self._okapi_timeout = 60
@@ -124,14 +124,13 @@ def _connect_db_duckdb(
db = ld.connect_db_duckdb(filename='ldlite.db')

"""
from ._database.duckdb import DuckDbDatabase # noqa: PLC0415

self.dbtype = DBType.DUCKDB
fn = filename if filename is not None else ":memory:"
db = duckdb.connect(database=fn)
self.db = cast("dbapi.DBAPIConnection", db.cursor())
self._db = DBTypeDatabase(
DBType.DUCKDB,
lambda: cast("dbapi.DBAPIConnection", db.cursor()),
)
self._db = DuckDbDatabase(lambda: db.cursor())

return db.cursor()

@@ -146,13 +145,12 @@ def connect_db_postgresql(self, dsn: str) -> psycopg.Connection:
db = ld.connect_db_postgresql(dsn='dbname=ld host=localhost user=ldlite')

"""
from ._database.postgres import PostgresDatabase # noqa: PLC0415

self.dbtype = DBType.POSTGRES
db = psycopg.connect(dsn)
self.db = cast("dbapi.DBAPIConnection", db)
self._db = DBTypeDatabase(
DBType.POSTGRES,
lambda: cast("dbapi.DBAPIConnection", psycopg.connect(dsn)),
)
self._db = PostgresDatabase(lambda: psycopg.connect(dsn))

ret_db = psycopg.connect(dsn)
ret_db.rollback()
@@ -200,9 +198,6 @@ def drop_tables(self, table: str) -> None:
if self.db is None or self._db is None:
self._check_db()
return
schema_table = table.strip().split(".")
if len(schema_table) != 1 and len(schema_table) != 2:
raise ValueError("invalid table name: " + table)
prefix = Prefix(table)
self._db.drop_prefix(prefix)

@@ -293,9 +288,6 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
"use json_depth=0 to disable JSON transformation"
)
raise ValueError(msg)
schema_table = table.split(".")
if len(schema_table) != 1 and len(schema_table) != 2:
raise ValueError("invalid table name: " + table)
if json_depth is None or json_depth < 0 or json_depth > 4:
raise ValueError("invalid value for json_depth: " + str(json_depth))
if self._folio is None:
@@ -308,57 +300,39 @@
if not self._quiet:
print("ldlite: querying: " + path, file=sys.stderr)
try:
# First get total number of records
records = self._folio.iterate_records(
(total_records, records) = self._folio.iterate_records(
path,
self._okapi_timeout,
self._okapi_max_retries,
self.page_size,
query=cast("QueryType", query),
)
(total_records, _) = next(records)
total = min(total_records, limit or total_records)
if limit is not None:
total_records = min(total_records, limit)
records = (x for _, x in zip(range(limit), records, strict=False))
if self._verbose:
print("ldlite: estimated row count: " + str(total), file=sys.stderr)

class PbarNoop:
def update(self, _: int) -> None: ...
def close(self) -> None: ...

p_count = count(1)
processed = 0
pbar: tqdm | PbarNoop # type:ignore[type-arg]
if not self._quiet:
pbar = tqdm(
desc="reading",
total=total,
leave=False,
mininterval=3,
smoothing=0,
colour="#A9A9A9",
bar_format="{desc} {bar}{postfix}",
print(
"ldlite: estimated row count: " + str(total_records),
file=sys.stderr,
)
else:
pbar = PbarNoop()

def on_processed() -> bool:
pbar.update(1)
nonlocal processed
processed = next(p_count)
return True

def on_processed_limit() -> bool:
pbar.update(1)
nonlocal processed, limit
processed = next(p_count)
return limit is None or processed < limit

self._db.ingest_records(
processed = self._db.ingest_records(
prefix,
on_processed_limit if limit is not None else on_processed,
records,
cast(
"Iterator[bytes]",
tqdm(
records,
desc="downloading",
total=total_records,
leave=False,
mininterval=5,
disable=self._quiet,
unit=table.split(".")[-1],
unit_scale=True,
delay=5,
),
),
)
pbar.close()

self._db.drop_extracted_tables(prefix)
newtables = [table]
@@ -386,6 +360,13 @@ def on_processed_limit() -> bool:
autocommit(self.db, self.dbtype, True)
# Create indexes on id columns (for postgres)
if self.dbtype == DBType.POSTGRES:

class PbarNoop:
def update(self, _: int) -> None: ...
def close(self) -> None: ...

pbar: tqdm | PbarNoop = PbarNoop() # type:ignore[type-arg]

indexable_attrs = [
(t, a)
for t, attrs in newattrs.items()
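Note: taken together, the __init__.py changes move ingestion off the LDLite class and onto a backend object built from a connection factory. A minimal sketch of the new flow, assuming the classes behave as shown in the diffs below; the table name and record bytes here are made up, not part of the PR:

import duckdb

from ldlite._database import Prefix
from ldlite._database.duckdb import DuckDbDatabase

# Mirrors connect_db_duckdb: the backend gets a cursor factory, not a live cursor.
db = duckdb.connect(database=":memory:")
backend = DuckDbDatabase(lambda: db.cursor())

# Mirrors query(): records arrive as an iterator of raw JSON bytes, and
# ingest_records now returns the processed count itself, replacing the
# on_processed callbacks and the manual itertools.count bookkeeping.
prefix = Prefix("g")  # "schema.table" or a bare table name
records = iter([b'{"id": "1"}', b'{"id": "2"}'])
processed = backend.ingest_records(prefix, records)
assert processed == 2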
60 changes: 27 additions & 33 deletions src/ldlite/_database.py → src/ldlite/_database/__init__.py
@@ -6,15 +6,18 @@
from psycopg import sql

if TYPE_CHECKING:
from _typeshed import dbapi

DB = TypeVar("DB", bound="dbapi.DBAPIConnection")
import duckdb
import psycopg


class Prefix:
def __init__(self, table: str):
def __init__(self, prefix: str):
self._schema: str | None = None
sandt = table.split(".")
sandt = prefix.split(".")
if len(sandt) > 2:
msg = f"Expected one or two identifiers but got {prefix}"
raise ValueError(msg)

if len(sandt) == 1:
(self._prefix,) = sandt
else:
@@ -42,7 +45,24 @@ def legacy_jtable(self) -> sql.Identifier:
return self.identifier(f"{self._prefix}_jtable")


class Database(ABC, Generic[DB]):
class Database(ABC):
@abstractmethod
def drop_prefix(self, prefix: Prefix) -> None: ...

@abstractmethod
def drop_raw_table(self, prefix: Prefix) -> None: ...

@abstractmethod
def drop_extracted_tables(self, prefix: Prefix) -> None: ...

@abstractmethod
def ingest_records(self, prefix: Prefix, records: Iterator[bytes]) -> int: ...


DB = TypeVar("DB", bound="duckdb.DuckDBPyConnection | psycopg.Connection")


class TypedDatabase(Database, Generic[DB]):
def __init__(self, conn_factory: Callable[[], DB]):
self._conn_factory = conn_factory

@@ -88,7 +108,7 @@ def drop_extracted_tables(

@property
@abstractmethod
def _missing_table_error(self) -> tuple[type[Exception], ...]: ...
def _missing_table_error(self) -> type[Exception]: ...
def _drop_extracted_tables(
self,
conn: DB,
@@ -137,9 +157,6 @@ def _drop_extracted_tables(
.as_string(),
)

@property
@abstractmethod
def _truncate_raw_table_sql(self) -> sql.SQL: ...
@property
@abstractmethod
def _create_raw_table_sql(self) -> sql.SQL: ...
@@ -162,26 +179,3 @@ def _prepare_raw_table(
table=prefix.raw_table_name,
).as_string(),
)

@property
@abstractmethod
def _insert_record_sql(self) -> sql.SQL: ...
def ingest_records(
self,
prefix: Prefix,
on_processed: Callable[[], bool],
records: Iterator[tuple[int, bytes]],
) -> None:
with closing(self._conn_factory()) as conn:
self._prepare_raw_table(conn, prefix)

insert_sql = self._insert_record_sql.format(
table=prefix.raw_table_name,
).as_string()
with closing(conn.cursor()) as cur:
for pkey, r in records:
cur.execute(insert_sql, (pkey, r.decode()))
if not on_processed():
break

conn.commit()
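Note: the split into an untyped Database ABC and a generic TypedDatabase keeps __init__.py free of backend imports while the shared drop/prepare logic stays in one place. A hedged sketch of what a third backend would have to supply, judging only by the abstract members in this diff — SqliteDatabase and the sqlite3 wiring are illustrative, not part of the PR, and a real backend would also have to widen the DB TypeVar bound above, which currently only admits duckdb and psycopg connections:

import sqlite3
from collections.abc import Iterator
from itertools import count

from psycopg import sql

from ldlite._database import Prefix, TypedDatabase


class SqliteDatabase(TypedDatabase[sqlite3.Connection]):  # hypothetical
    def _rollback(self, conn: sqlite3.Connection) -> None:
        conn.rollback()

    @property
    def _missing_table_error(self) -> type[Exception]:
        return sqlite3.OperationalError

    @property
    def _create_raw_table_sql(self) -> sql.SQL:
        return sql.SQL(
            "CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb text);",
        )

    def ingest_records(self, prefix: Prefix, records: Iterator[bytes]) -> int:
        pkey = count(1)
        with self._conn_factory() as conn:
            self._prepare_raw_table(conn, prefix)
            insert = (
                sql.SQL("INSERT INTO {table} VALUES(?, ?);")
                .format(table=prefix.raw_table_name)
                .as_string()
            )
            for r in records:
                conn.execute(insert, (next(pkey), r.decode()))
            conn.commit()
        return next(pkey) - 1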
44 changes: 44 additions & 0 deletions src/ldlite/_database/duckdb.py
@@ -0,0 +1,44 @@
from collections.abc import Iterator
from itertools import count

import duckdb
from psycopg import sql

from . import Prefix, TypedDatabase


class DuckDbDatabase(TypedDatabase[duckdb.DuckDBPyConnection]):
def _rollback(self, conn: duckdb.DuckDBPyConnection) -> None:
pass

@property
def _missing_table_error(self) -> type[Exception]:
return duckdb.CatalogException

@property
def _create_raw_table_sql(self) -> sql.SQL:
return sql.SQL("CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb text);")

def ingest_records(
self,
prefix: Prefix,
records: Iterator[bytes],
) -> int:
pkey = count(1)
with self._conn_factory() as conn:
self._prepare_raw_table(conn, prefix)

insert_sql = (
sql.SQL("INSERT INTO {table} VALUES(?, ?);")
.format(
table=prefix.raw_table_name,
)
.as_string()
)
# duckdb bulk inserts perform better inside a single transaction
with conn.begin() as tx, tx.cursor() as cur:
for r in records:
cur.execute(insert_sql, (next(pkey), r.decode()))
tx.commit()

return next(pkey) - 1
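Note on the return value: count(1) always holds the key of the *next* row, so after N inserts next(pkey) is N + 1. A quick standalone check of the idiom:

from itertools import count

pkey = count(1)
for r in (b"{}", b"{}", b"{}"):
    row_id = next(pkey)  # 1, 2, 3
assert next(pkey) - 1 == 3  # three records were ingested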
51 changes: 51 additions & 0 deletions src/ldlite/_database/postgres.py
@@ -0,0 +1,51 @@
from collections.abc import Iterator
from itertools import count

import psycopg
from psycopg import sql

from . import Prefix, TypedDatabase


class PostgresDatabase(TypedDatabase[psycopg.Connection]):
def _rollback(self, conn: psycopg.Connection) -> None:
conn.rollback()

@property
def _missing_table_error(self) -> type[Exception]:
return psycopg.errors.UndefinedTable

@property
def _create_raw_table_sql(self) -> sql.SQL:
return sql.SQL(
"CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb jsonb);",
)

def ingest_records(
self,
prefix: Prefix,
records: Iterator[bytes],
) -> int:
pkey = count(1)
with self._conn_factory() as conn:
self._prepare_raw_table(conn, prefix)

with (
conn.cursor() as cur,
cur.copy(
sql.SQL(
"COPY {table} (__id, jsonb) FROM STDIN (FORMAT BINARY)",
).format(table=prefix.raw_table_name),
) as copy,
):
# postgres jsonb is always version 1
# and it always goes in front
jver = (1).to_bytes(1, "big")
for r in records:
rb = bytearray()
rb.extend(jver)
rb.extend(r)
copy.write_row((next(pkey).to_bytes(4, "big"), rb))

conn.commit()
return next(pkey) - 1
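Note on the version byte: with COPY ... FORMAT BINARY, bytes handed to psycopg's write_row are passed through as the field's binary representation, so both fields are built by hand and the jsonb field must lead with its format version. A sketch of one row's fields, with made-up values:

record = b'{"id": "a"}'
row_id = (7).to_bytes(4, "big")            # __id as a big-endian int4
jsonb = bytearray((1).to_bytes(1, "big"))  # jsonb binary format version, always 1
jsonb.extend(record)                       # the JSON text follows the version byte
assert bytes(jsonb) == b'\x01{"id": "a"}'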