Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 233 additions & 80 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import binascii
import re
import uuid
import weakref
from collections.abc import Callable, Sequence
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -58,6 +60,78 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None:
data[k] = v.item()


# Module-level (not methods) so they can be registered with ``weakref.finalize``
# without holding a bound reference to the owning object, which would prevent
# it from ever being collected.
def _drop_scratch_table(engine: sa.Engine, table: sa.Table) -> None:
    """Best-effort removal of the scratch ``table`` from ``engine``.

    Any exception raised while dropping is logged at debug level and
    suppressed, so this is safe to use as a cleanup callback (for example
    one that may run at interpreter shutdown).
    """
    try:
        table.drop(engine)
    except Exception as err:
        # Cleanup must never propagate an error to whoever triggered it.
        LOG.debug("Failed to drop scratch table %s: %s", table.name, err)


class _SQLIDSet:
    """A set of ids usable in SQL ``IN`` clauses without overflowing bind limits.

    Small sets compile to inline ``col.in_([...])``; larger sets are materialized
    into a per-instance scratch table on ``graph._engine`` and matched via
    ``col.in_(SELECT id FROM scratch)``. The scratch table is a regular table
    (not ``TEMPORARY``) so it is visible from any pool connection the filter
    later uses.

    The scratch table is dropped by :meth:`close` (also invoked when leaving a
    ``with`` block). As a safety net the instance additionally registers a
    :func:`weakref.finalize` callback, so the table is still dropped if the
    owner never gets to call :meth:`close` -- e.g. because the owner raised
    while it was being constructed. The drop runs at most once regardless of
    how many of these paths fire.

    ``occurrences`` is the maximum number of times the id set will be expanded
    in a single compiled statement (e.g. filtering both ``source_id`` and
    ``target_id`` of an edge table counts as 2). The scratch-table cutoff is
    divided by it so that ``len(ids) * occurrences`` stays safely under the
    backend's bound-variable limit.
    """

    def __init__(
        self,
        graph: "SQLGraph",
        ids: Sequence[int],
        *,
        occurrences: int = 1,
    ) -> None:
        """Capture ``ids`` and materialize a scratch table if they are too many.

        Parameters
        ----------
        graph
            Graph providing the engine, the chunk size and the scratch-table
            factory (``_engine``, ``_sql_chunk_size()`` and
            ``_create_id_scratch_table()``).
        ids
            The ids to match. Objects exposing ``tolist()`` are converted
            through it first.
        occurrences
            How many times the ids are expanded in one compiled statement.
            Values below 1 are treated as 1.
        """
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        self._ids: list[int] = list(ids)
        # Hold the engine, not the graph, so this set does not participate in
        # the graph -> SQLFilter -> _SQLIDSet -> graph reference cycle.
        # Otherwise the scratch table would only be dropped after Python's
        # cycle GC runs, delaying cleanup in long-running processes.
        self._engine = graph._engine
        self._scratch: sa.Table | None = None
        self._finalizer: weakref.finalize | None = None

        limit = max(1, graph._sql_chunk_size() // max(1, occurrences))
        if len(self._ids) > limit:
            scratch = graph._create_id_scratch_table(self._ids)
            self._scratch = scratch
            # Safety net: the callback and its arguments hold no reference to
            # ``self``, so they do not keep this instance alive, and the table
            # is dropped even if nobody calls ``close()``.
            self._finalizer = weakref.finalize(
                self,
                _drop_scratch_table,
                self._engine,
                scratch,
            )

    def __enter__(self) -> "_SQLIDSet":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    @property
    def uses_scratch_table(self) -> bool:
        """Whether the ids currently live in a (not yet dropped) scratch table."""
        return self._scratch is not None

    def in_clause(self, column: "sa.ColumnElement") -> "sa.ColumnElement[bool]":
        """Return a boolean expression matching ``column`` against the ids."""
        if self._scratch is None:
            return column.in_(self._ids)
        return column.in_(sa.select(self._scratch.c.id))

    def close(self) -> None:
        """Drop the scratch table, if any. Safe to call more than once."""
        finalizer, self._finalizer = self._finalizer, None
        self._scratch = None
        if finalizer is not None:
            # Calling a finalizer runs its callback only if it is still alive,
            # so the drop happens at most once even if GC raced with us.
            finalizer()


def _close_id_set(id_set: "_SQLIDSet") -> None:
    """Close ``id_set`` without ever raising.

    A failure inside ``id_set.close()`` is reported at debug level and then
    suppressed, which makes this function suitable as a finalizer callback.
    """
    try:
        id_set.close()
    except Exception as err:
        LOG.debug("Failed to close _SQLIDSet: %s", err)
    return None


def _filter_query(
query: sa.Select,
table: type[DeclarativeBase],
Expand Down Expand Up @@ -88,6 +162,17 @@ def _filter_query(


class SQLFilter(BaseFilter):
"""SQL-backed filter over an :class:`SQLGraph`.

When ``node_ids`` is larger than the backend's bound-variable budget
(after accounting for how many ``IN (...)`` clauses the list expands
into), the filter materializes the ids into a per-instance scratch
table on ``graph._engine`` and references it via subselects. The
scratch table is dropped when the filter is garbage-collected (via
:func:`weakref.finalize`), so callers don't need to close the filter
explicitly.
"""

def __init__(
self,
*attr_filters: AttrComparison,
Expand All @@ -101,25 +186,28 @@ def __init__(
self._node_attr_comps, self._edge_attr_comps = split_attr_comps(attr_filters)
self._include_targets = include_targets
self._include_sources = include_sources
self._id_set: _SQLIDSet | None = None

# creating initial query
self._node_query: sa.Select = sa.select(self._graph.Node)
self._edge_query: sa.Select = sa.select(self._graph.Edge)
node_filtered = False

if node_ids is not None:
if hasattr(node_ids, "tolist"):
node_ids = node_ids.tolist()

self._node_query = self._node_query.filter(self._graph.Node.node_id.in_(node_ids))
# The node_ids list is expanded in up to three IN(...) clauses
# below (once on Node, plus once each on Edge.target_id /
# Edge.source_id unless the corresponding ``include_*`` is set).
# Account for that so the inline/scratch cutoff stays below the
# backend's bound-variable limit for the compiled statement.
occurrences = 1 + int(not self._include_targets) + int(not self._include_sources)
id_set = _SQLIDSet(self._graph, node_ids, occurrences=occurrences)
self._id_set = id_set

self._node_query = self._node_query.filter(id_set.in_clause(self._graph.Node.node_id))
if not self._include_targets:
self._edge_query = self._edge_query.filter(
self._graph.Edge.target_id.in_(node_ids),
)
self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.target_id))
if not self._include_sources:
self._edge_query = self._edge_query.filter(
self._graph.Edge.source_id.in_(node_ids),
)
self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.source_id))
node_filtered = True

if self._node_attr_comps:
Expand Down Expand Up @@ -184,6 +272,16 @@ def __init__(

self._node_query = sa.union(*nodes_query)

# Drop the scratch table when this filter is collected. Only register a
# finalizer if one was actually allocated, so the common small-set case
# stays free of weakref bookkeeping.
if self._uses_scratch_table():
weakref.finalize(self, _close_id_set, self._id_set)

def _uses_scratch_table(self) -> bool:
    """Report whether this filter's id set is backed by a scratch table."""
    id_set = self._id_set
    if id_set is None:
        # No id set exists for this filter, so nothing was materialized.
        return False
    return id_set.uses_scratch_table

@cache_method
def node_ids(self) -> list[int]:
"""
Expand Down Expand Up @@ -1943,6 +2041,40 @@ def _chunked_sa_read(
chunks.append(data_df)
return pl.concat(chunks)

def _create_id_scratch_table(self, ids: Sequence[int]) -> sa.Table:
    """Create a uniquely-named helper table holding ``ids`` on ``self._engine``.

    Used to work around SQL bound-variable limits when filtering by large
    ``IN (...)`` lists: callers replace ``col.in_(ids)`` with
    ``col.in_(sa.select(table.c.id))``. The table is a regular table on
    the engine (not ``TEMPORARY``), so it is visible from any session or
    connection drawn from the same engine pool — that is what makes it
    usable across the multiple ``Session(engine)`` calls inside
    :class:`SQLFilter`.

    On success the caller owns the table's lifetime and must eventually call
    ``table.drop(self._engine)`` (or hand the table off to a finalizer
    that does so) to remove it. If populating the table fails, the table is
    dropped here before the error is re-raised, because the caller never
    receives a handle it could use to clean up.

    Parameters
    ----------
    ids
        Ids to store. Each value is converted with ``int()`` and duplicates
        are removed, since ``id`` is the table's primary key.

    Returns
    -------
    sa.Table
        The created table, with a single ``id`` column.
    """
    # Do everything that can fail without the database first, so that a bad
    # id or chunk size never leaves a table behind.
    unique_ids = list({int(v) for v in ids})
    chunk_size = max(1, self._sql_chunk_size())

    table = sa.Table(
        f"_tracksdata_ids_{uuid.uuid4().hex}",
        sa.MetaData(),
        sa.Column("id", sa.BigInteger, primary_key=True),
    )
    table.create(self._engine)

    try:
        with self._engine.begin() as conn:
            for start in range(0, len(unique_ids), chunk_size):
                conn.execute(
                    table.insert(),
                    [{"id": v} for v in unique_ids[start : start + chunk_size]],
                )
    except BaseException:
        # The table already exists as a regular (persistent) table; remove
        # it so a failed insert does not leak it into the user's database.
        _drop_scratch_table(self._engine, table)
        raise
    return table

def update_node_attrs(
self,
*,
Expand Down Expand Up @@ -2037,13 +2169,22 @@ def _get_degree(
with Session(self._engine) as session:
return int(session.execute(stmt).scalar())

stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col)
if node_ids is not None:
stmt = stmt.where(edge_key_col.in_(node_ids))
base_stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col)

degree: dict[int, int] = {}
with Session(self._engine) as session:
# get the number of edges for each using group by and count
degree = dict(session.execute(stmt).all())
if node_ids is None:
degree.update(session.execute(base_stmt).all())
else:
# Chunk the IN(...) so the bound-parameter count stays below
# the backend's limit (notably SQLite's
# ``SQLITE_MAX_VARIABLE_NUMBER``). Each chunk's group-by result
# is disjoint, so we can merge them with a simple dict update.
chunk_size = max(1, self._sql_chunk_size())
for i in range(0, len(node_ids), chunk_size):
chunk = node_ids[i : i + chunk_size]
stmt = base_stmt.where(edge_key_col.in_(chunk))
degree.update(session.execute(stmt).all())

if node_ids is None:
# this is necessary to make sure it's the same order as node_ids
Expand Down Expand Up @@ -2158,9 +2299,10 @@ def _sqlite_table_dump(
reflection path then rebuilds the in-memory state.

For filtered copies (``source_node_ids`` not ``None``) the selection
is materialized in a temp table so the row filter joins instead of
using an oversized ``IN (...)`` clause that would hit SQLite's
bound-parameter limit.
is materialized in a per-instance scratch table on the source engine
so the row filter joins instead of using an oversized ``IN (...)``
clause that would hit SQLite's bound-parameter limit. The scratch
table is dropped in the ``finally`` block before returning.
"""
dst_database: str = kwargs["database"]
dst_path = Path(dst_database)
Expand All @@ -2176,69 +2318,80 @@ def _sqlite_table_dump(
# escape the path safely via single-quote doubling.
attach_path = dst_database.replace("'", "''")

with source_root._engine.connect() as conn:
conn.exec_driver_sql(f"ATTACH DATABASE '{attach_path}' AS _td_dst")
try:
# 1. Replicate the source schema by replaying its DDL against
# the attached destination. ``sqlite_master.sql`` is NULL for
# auto-generated objects (e.g. PK indexes), which we skip;
# tables are created before indexes.
ddl_rows = conn.exec_driver_sql(
"SELECT type, sql FROM main.sqlite_master "
"WHERE sql IS NOT NULL AND type IN ('table', 'index') "
"ORDER BY CASE type WHEN 'table' THEN 0 ELSE 1 END"
).fetchall()
for _type, ddl in ddl_rows:
qualified = cls._SQLITE_DDL_QUALIFIER.sub(r"\g<1>_td_dst.", ddl, count=1)
conn.exec_driver_sql(qualified)

# 2. Copy rows. The Metadata table is included verbatim — its
# SQL-private schema entries describe the columns we just
# cloned and so are valid for the destination as-is.
if source_node_ids is None:
for table_name in ("Node", "Edge", "Overlap", "Metadata"):
conn.exec_driver_sql(f'INSERT INTO _td_dst."{table_name}" SELECT * FROM main."{table_name}"')
else:
node_ids = list(source_node_ids)
if hasattr(node_ids, "tolist"):
node_ids = node_ids.tolist()
# Materialize the selection in a temp table so the row
# filter joins instead of using an oversized IN(...) clause.
conn.exec_driver_sql("CREATE TEMP TABLE _td_selected (node_id INTEGER PRIMARY KEY)")
insert_stmt = sa.text("INSERT INTO _td_selected (node_id) VALUES (:node_id)")
chunk_size = max(1, source_root._sql_chunk_size())
for i in range(0, len(node_ids), chunk_size):
batch = node_ids[i : i + chunk_size]
conn.execute(
insert_stmt,
[{"node_id": int(nid)} for nid in batch],
if source_node_ids is None:
selected: sa.Table | None = None
else:
# Materialize the selection in a per-instance scratch table so the
# row filter joins instead of expanding into an oversized IN(...).
# The table lives on ``source_root._engine`` (visible from the
# ATTACH-ing connection) and is dropped in the outer ``finally``.
#
# We deliberately do not use a ``TEMPORARY`` table here even
# though this function holds a single connection. SQLAlchemy's
# ``Connection.close()`` only returns the underlying DB-API
# connection to the pool, it does not destroy it, so a TEMP table
# would survive into the next consumer of that same pooled SQLite
# connection. A regular table dropped explicitly avoids that.
selected = source_root._create_id_scratch_table(source_node_ids)

try:
with source_root._engine.connect() as conn:
conn.exec_driver_sql(f"ATTACH DATABASE '{attach_path}' AS _td_dst")
try:
# 1. Replicate the source schema by replaying its DDL against
# the attached destination. ``sqlite_master.sql`` is NULL for
# auto-generated objects (e.g. PK indexes), which we skip;
# tables are created before indexes. ``_tracksdata_ids_*``
# are internal scratch tables (this call's own ``selected``
# plus any from live ``SQLFilter``s on the same engine) and
# must not be copied into the persisted destination.
ddl_rows = conn.exec_driver_sql(
"SELECT type, sql FROM main.sqlite_master "
"WHERE sql IS NOT NULL AND type IN ('table', 'index') "
"AND name NOT GLOB '_tracksdata_ids_*' "
"ORDER BY CASE type WHEN 'table' THEN 0 ELSE 1 END"
).fetchall()
for _type, ddl in ddl_rows:
qualified = cls._SQLITE_DDL_QUALIFIER.sub(r"\g<1>_td_dst.", ddl, count=1)
conn.exec_driver_sql(qualified)

# 2. Copy rows. The Metadata table is included verbatim — its
# SQL-private schema entries describe the columns we just
# cloned and so are valid for the destination as-is.
if selected is None:
for table_name in ("Node", "Edge", "Overlap", "Metadata"):
conn.exec_driver_sql(
f'INSERT INTO _td_dst."{table_name}" SELECT * FROM main."{table_name}"'
)
else:
selected_subq = f'SELECT id FROM "{selected.name}"'

conn.exec_driver_sql(
f'INSERT INTO _td_dst."Node" SELECT * FROM main."Node" WHERE node_id IN ({selected_subq})'
)

conn.exec_driver_sql(
'INSERT INTO _td_dst."Node" SELECT * FROM main."Node" '
"WHERE node_id IN (SELECT node_id FROM _td_selected)"
)
conn.exec_driver_sql(
'INSERT INTO _td_dst."Edge" SELECT * FROM main."Edge" '
"WHERE source_id IN (SELECT node_id FROM _td_selected) "
"AND target_id IN (SELECT node_id FROM _td_selected)"
)
conn.exec_driver_sql(
'INSERT INTO _td_dst."Overlap" SELECT * FROM main."Overlap" '
"WHERE source_id IN (SELECT node_id FROM _td_selected) "
"AND target_id IN (SELECT node_id FROM _td_selected)"
)
conn.exec_driver_sql('INSERT INTO _td_dst."Metadata" SELECT * FROM main."Metadata"')
conn.exec_driver_sql("DROP TABLE _td_selected")

conn.commit()
finally:
conn.exec_driver_sql("DETACH DATABASE _td_dst")

# 3. Open the destination from the now-populated file. The standard
# constructor reflects the schema, restores pickled column types,
# and recomputes ``_max_id_per_time``.
return cls(**kwargs)
conn.exec_driver_sql(
f'INSERT INTO _td_dst."Edge" SELECT * FROM main."Edge" '
f"WHERE source_id IN ({selected_subq}) "
f"AND target_id IN ({selected_subq})"
)
conn.exec_driver_sql(
f'INSERT INTO _td_dst."Overlap" SELECT * FROM main."Overlap" '
f"WHERE source_id IN ({selected_subq}) "
f"AND target_id IN ({selected_subq})"
)
conn.exec_driver_sql('INSERT INTO _td_dst."Metadata" SELECT * FROM main."Metadata"')

conn.commit()
finally:
conn.exec_driver_sql("DETACH DATABASE _td_dst")

# 3. Open the destination from the now-populated file. The standard
# constructor reflects the schema, restores pickled column types,
# and recomputes ``_max_id_per_time``.
return cls(**kwargs)
finally:
if selected is not None:
_drop_scratch_table(source_root._engine, selected)

def __getstate__(self) -> dict:
data_dict = self.__dict__.copy()
Expand Down
Loading
Loading