Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 60 additions & 47 deletions src/ldlite/_database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class Prefix:
def __init__(self, prefix: str):
self._schema: str | None = None
self.schema: str | None = None
sandt = prefix.split(".")
if len(sandt) > 2:
msg = f"Expected one or two identifiers but got {prefix}"
Expand All @@ -24,32 +24,43 @@ def __init__(self, prefix: str):
if len(sandt) == 1:
(self._prefix,) = sandt
else:
(self._schema, self._prefix) = sandt

@property
def schema_name(self) -> sql.Identifier | None:
return None if self._schema is None else sql.Identifier(self._schema)
(self.schema, self._prefix) = sandt

def _identifier(self, table: str) -> sql.Identifier:
if self._schema is None:
if self.schema is None:
return sql.Identifier(table)
return sql.Identifier(self._schema, table)
return sql.Identifier(self.schema, table)

@property
def load_history_key(self) -> str:
return (self._schema or "public") + "." + self._prefix
def schema_identifier(self) -> sql.Identifier | None:
return None if self.schema is None else sql.Identifier(self.schema)

@property
def raw_table_name(self) -> sql.Identifier:
def raw_table_identifier(self) -> sql.Identifier:
return self._identifier(self._prefix)

@property
def catalog_table_name(self) -> sql.Identifier:
return self._identifier(f"{self._prefix}__tcatalog")
def catalog_table_name(self) -> str:
return f"{self._prefix}__tcatalog"

@property
def catalog_table_identifier(self) -> sql.Identifier:
return self._identifier(self.catalog_table_name)

@property
def legacy_jtable_name(self) -> str:
return f"{self._prefix}_jtable"

@property
def legacy_jtable(self) -> sql.Identifier:
return self._identifier(f"{self._prefix}_jtable")
def legacy_jtable_identifier(self) -> sql.Identifier:
return self._identifier(self.legacy_jtable_name)

@property
def load_history_key(self) -> str:
if self.schema is None:
return self._prefix

return self.schema + "." + self._prefix


@dataclass(frozen=True)
Expand Down Expand Up @@ -102,8 +113,9 @@ def __init__(self, conn_factory: Callable[[], DB]):
);""")
conn.commit()

@property
@abstractmethod
def _rollback(self, conn: DB) -> None: ...
def _default_schema(self) -> str: ...

def drop_prefix(
self,
Expand Down Expand Up @@ -134,7 +146,7 @@ def _drop_raw_table(
with closing(conn.cursor()) as cur:
cur.execute(
sql.SQL("DROP TABLE IF EXISTS {table};")
.format(table=prefix.raw_table_name)
.format(table=prefix.raw_table_identifier)
.as_string(),
)

Expand All @@ -146,38 +158,39 @@ def drop_extracted_tables(
self._drop_extracted_tables(conn, prefix)
conn.commit()

@property
@abstractmethod
def _missing_table_error(self) -> type[Exception]: ...
def _drop_extracted_tables(
self,
conn: DB,
prefix: Prefix,
) -> None:
tables: list[Sequence[Sequence[Any]]] = []
with closing(conn.cursor()) as cur:
try:
cur.execute(
sql.SQL("SELECT table_name FROM {catalog};")
.format(catalog=prefix.catalog_table_name)
.as_string(),
)
except self._missing_table_error:
self._rollback(conn)
else:
tables.extend(cur.fetchall())

with closing(conn.cursor()) as cur:
try:
cur.execute(
sql.SQL("SELECT table_name FROM {catalog};")
.format(catalog=prefix.legacy_jtable)
.as_string(),
)
except self._missing_table_error:
self._rollback(conn)
else:
tables.extend(cur.fetchall())
cur.execute(
"""
SELECT table_name FROM information_schema.tables
WHERE table_schema = $1 and table_name IN ($2, $3);""",
(
prefix.schema or self._default_schema,
prefix.catalog_table_name,
prefix.legacy_jtable_name,
),
)
for (tname,) in cur.fetchall():
if tname == prefix.catalog_table_name:
cur.execute(
sql.SQL("SELECT table_name FROM {catalog};")
.format(catalog=prefix.catalog_table_identifier)
.as_string(),
)
tables.extend(cur.fetchall())

if tname == prefix.legacy_jtable_name:
cur.execute(
sql.SQL("SELECT table_name FROM {catalog};")
.format(catalog=prefix.legacy_jtable_identifier)
.as_string(),
)
tables.extend(cur.fetchall())

with closing(conn.cursor()) as cur:
for (et,) in tables:
Expand All @@ -188,12 +201,12 @@ def _drop_extracted_tables(
)
cur.execute(
sql.SQL("DROP TABLE IF EXISTS {catalog};")
.format(catalog=prefix.catalog_table_name)
.format(catalog=prefix.catalog_table_identifier)
.as_string(),
)
cur.execute(
sql.SQL("DROP TABLE IF EXISTS {catalog};")
.format(catalog=prefix.legacy_jtable)
.format(catalog=prefix.legacy_jtable_identifier)
.as_string(),
)

Expand All @@ -206,17 +219,17 @@ def _prepare_raw_table(
prefix: Prefix,
) -> None:
with closing(conn.cursor()) as cur:
if prefix.schema_name is not None:
if prefix.schema_identifier is not None:
cur.execute(
sql.SQL("CREATE SCHEMA IF NOT EXISTS {schema};")
.format(schema=prefix.schema_name)
.format(schema=prefix.schema_identifier)
.as_string(),
)
self._drop_raw_table(conn, prefix)
with closing(conn.cursor()) as cur:
cur.execute(
self._create_raw_table_sql.format(
table=prefix.raw_table_name,
table=prefix.raw_table_identifier,
).as_string(),
)

Expand Down
9 changes: 3 additions & 6 deletions src/ldlite/_database/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@


class DuckDbDatabase(TypedDatabase[duckdb.DuckDBPyConnection]):
def _rollback(self, conn: duckdb.DuckDBPyConnection) -> None:
pass

@property
def _missing_table_error(self) -> type[Exception]:
return duckdb.CatalogException
def _default_schema(self) -> str:
return "main"

@property
def _create_raw_table_sql(self) -> sql.SQL:
Expand All @@ -31,7 +28,7 @@ def ingest_records(
insert_sql = (
sql.SQL("INSERT INTO {table} VALUES(?, ?);")
.format(
table=prefix.raw_table_name,
table=prefix.raw_table_identifier,
)
.as_string()
)
Expand Down
9 changes: 3 additions & 6 deletions src/ldlite/_database/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@ def __init__(self, dsn: str):
# same sql between duckdb and postgres
super().__init__(lambda: psycopg.connect(dsn, cursor_factory=psycopg.RawCursor))

def _rollback(self, conn: psycopg.Connection) -> None:
conn.rollback()

@property
def _missing_table_error(self) -> type[Exception]:
return psycopg.errors.UndefinedTable
def _default_schema(self) -> str:
return "public"

@property
def _create_raw_table_sql(self) -> sql.SQL:
Expand All @@ -40,7 +37,7 @@ def ingest_records(
cur.copy(
sql.SQL(
"COPY {table} (__id, jsonb) FROM STDIN (FORMAT BINARY)",
).format(table=prefix.raw_table_name),
).format(table=prefix.raw_table_identifier),
) as copy,
):
# postgres jsonb is always version 1
Expand Down
28 changes: 26 additions & 2 deletions tests/test_cases/load_history_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,31 @@ def case_one_load(self, query: str | None) -> LoadHistoryCase:
},
queries={"prefix": [query]},
expected_loads={
"public.prefix": (query, 2),
"prefix": (query, 2),
},
)

def case_schema_load(self) -> LoadHistoryCase:
return LoadHistoryCase(
values={
"schema.prefix": [
{
"purchaseOrders": [
{
"id": "b096504a-3d54-4664-9bf5-1b872466fd66",
"value": "value",
},
{
"id": "b096504a-9999-4664-9bf5-1b872466fd66",
"value": "value-2",
},
],
},
],
},
queries={"schema.prefix": [None]},
expected_loads={
"schema.prefix": (None, 2),
},
)

Expand Down Expand Up @@ -69,6 +93,6 @@ def case_two_loads(self) -> LoadHistoryCase:
},
queries={"prefix": [None, "a query"]},
expected_loads={
"public.prefix": ("a query", 2),
"prefix": ("a query", 2),
},
)
1 change: 1 addition & 0 deletions tests/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_drop_tables(
dsn = f":memory:{tc.db}"
ld.connect_folio("https://doesnt.matter", "", "", "")
ld.connect_db(dsn)
ld.drop_tables(tc.drop)

for prefix in tc.values:
ld.query(table=prefix, path="/patched", keep_raw=tc.keep_raw)
Expand Down
1 change: 1 addition & 0 deletions tests/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_drop_tables(
dsn = pg_dsn(tc.db)
ld.connect_folio("https://doesnt.matter", "", "", "")
ld.connect_db_postgresql(dsn)
ld.drop_tables(tc.drop)

for prefix in tc.values:
ld.query(table=prefix, path="/patched", keep_raw=tc.keep_raw)
Expand Down
Loading