Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 192 additions & 7 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import binascii
import re
from collections.abc import Callable, Sequence
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar

import cloudpickle
Expand Down Expand Up @@ -1103,20 +1105,35 @@ def overlaps(
) -> list[list[int, 2]]:
"""
Get the overlaps between the nodes in `node_ids`.

When ``node_ids`` is provided, the query is split via
:meth:`_chunked_sa_read` to keep the number of bound parameters below
the backend's limit (notably SQLite's ``SQLITE_MAX_VARIABLE_NUMBER``).
Only the source side is constrained per-chunk; the target side is
filtered in Polars afterwards to avoid a quadratic blow-up of bound
parameters.
"""
if hasattr(node_ids, "tolist"):
node_ids = node_ids.tolist()

with Session(self._engine) as session:
query = session.query(self.Overlap.source_id, self.Overlap.target_id)
base_query = session.query(self.Overlap.source_id, self.Overlap.target_id)

if node_ids is not None:
query = query.filter(
self.Overlap.source_id.in_(node_ids),
self.Overlap.target_id.in_(node_ids),
)
if node_ids is None:
return [[source_id, target_id] for source_id, target_id in base_query.all()]

return [[source_id, target_id] for source_id, target_id in query.all()]
if len(node_ids) == 0:
return []

df = self._chunked_sa_read(
session,
lambda chunk: base_query.filter(self.Overlap.source_id.in_(chunk)),
node_ids,
self.Overlap,
)

df = df.filter(pl.col("target_id").is_in(node_ids))
return [[source_id, target_id] for source_id, target_id in df.iter_rows()]

def has_overlaps(self) -> bool:
"""
Expand Down Expand Up @@ -2055,6 +2072,174 @@ def dividing_nodes(self) -> list[int]:
with Session(self._engine) as session:
return [int(row[0]) for row in session.execute(stmt).all()]

@classmethod
def from_other(cls: type["SQLGraph"], other: "BaseGraph", **kwargs: Any) -> "SQLGraph":
    """
    Build an :class:`SQLGraph` holding the contents of ``other``.

    Two strategies exist:

    * **SQLite table dump** -- chosen when ``other`` is an
      :class:`SQLGraph` (or a :class:`GraphView` whose root is one), the
      source engine's dialect is ``sqlite``, the requested ``drivername``
      starts with ``"sqlite"``, both databases are on-disk files with
      different paths, the destination file is missing or empty, and
      ``overwrite`` is not requested. The copy is then delegated to
      :meth:`_sqlite_table_dump`, which works at the SQL level and thereby
      avoids the per-statement variable limit (issue #285) and building
      the whole graph in Python memory.
    * **Generic copy** -- everything else (cross-dialect copy,
      ``:memory:`` databases, non-SQL sources, ``overwrite=True``) is
      handed to the parent class implementation.

    Parameters
    ----------
    other : BaseGraph
        Graph to copy from.
    **kwargs : Any
        Constructor arguments of the destination graph. ``database``,
        ``drivername`` and ``overwrite`` are inspected here to choose the
        strategy; all of them are forwarded unchanged.

    Returns
    -------
    SQLGraph
        The newly created graph.
    """
    from tracksdata.graph._graph_view import GraphView

    # Locate the SQL-backed graph behind ``other`` (if any). For a view,
    # only the view's own nodes are to be copied.
    root: SQLGraph | None = None
    selected_ids: list[int] | None = None
    if isinstance(other, SQLGraph):
        root = other
    elif isinstance(other, GraphView) and isinstance(other._root, SQLGraph):
        root = other._root
        selected_ids = other.node_ids()

    target_db = kwargs.get("database")
    target_driver = kwargs.get("drivername", "")

    # Each stage is evaluated only if all previous ones held, so attribute
    # access and file-system probing happen in the same order as a single
    # short-circuiting ``and`` chain would perform them.
    fast_path = root is not None and root._engine.dialect.name == "sqlite"
    if fast_path:
        # Destination must be an on-disk SQLite database.
        fast_path = (
            isinstance(target_driver, str)
            and target_driver.startswith("sqlite")
            and isinstance(target_db, str)
            and target_db not in ("", ":memory:")
        )
    if fast_path:
        # Source must be an on-disk database distinct from the destination.
        origin_db = root._url.database
        fast_path = origin_db not in (None, "", ":memory:") and origin_db != target_db
    if fast_path:
        # The dump replays the source's CREATE statements against the
        # destination, so the destination file must start empty.
        target_file = Path(target_db)
        fast_path = not target_file.exists() or target_file.stat().st_size == 0
    if fast_path:
        # ``overwrite=True`` would have the dst constructor drop tables
        # we just populated; let the generic path handle it.
        fast_path = not kwargs.get("overwrite", False)

    if not fast_path:
        return super().from_other(other, **kwargs)

    return cls._sqlite_table_dump(
        other=other,
        source_root=root,
        source_node_ids=selected_ids,
        kwargs=kwargs,
    )

# Match the leading ``CREATE [UNIQUE] {TABLE|INDEX} [IF NOT EXISTS]``
# of a SQLite DDL statement so we can splice in an attached-database
# qualifier (``_td_dst.``) before the object name.
#
# The whole prefix (including leading whitespace and the trailing space
# before the object name) is captured as group 1, so a replacement such as
# ``\g<1>_td_dst.`` re-emits the prefix and appends the qualifier right in
# front of the object name. The pattern is anchored at the start of the
# string and is case-insensitive. Statements that do not begin with this
# prefix (anything other than plain ``CREATE [UNIQUE] TABLE|INDEX``) do not
# match and are therefore left unchanged by ``sub``.
_SQLITE_DDL_QUALIFIER = re.compile(
r"^(\s*CREATE\s+(?:UNIQUE\s+)?(?:TABLE|INDEX)\s+(?:IF\s+NOT\s+EXISTS\s+)?)",
re.IGNORECASE,
)

@classmethod
def _sqlite_table_dump(
    cls: type["SQLGraph"],
    *,
    other: "BaseGraph",
    source_root: "SQLGraph",
    source_node_ids: list[int] | None,
    kwargs: dict[str, Any],
) -> "SQLGraph":
    """
    Fast SQLite-to-SQLite copy via ``ATTACH DATABASE`` + raw DDL/DML.

    Rather than instantiate the destination upfront and ALTER it into
    shape, this dumps the source's schema and rows straight into the
    destination file at the SQL level and then opens the destination
    :class:`SQLGraph` from the populated file via ``cls(**kwargs)``.

    For filtered copies (``source_node_ids`` not ``None``) the selection
    is materialized in a temp table so the row filter joins instead of
    using an oversized ``IN (...)`` clause that would hit SQLite's
    bound-parameter limit.

    Parameters
    ----------
    other : BaseGraph
        The graph originally passed to :meth:`from_other`; only used when
        falling back to the generic copy.
    source_root : SQLGraph
        SQL-backed graph whose engine/database is read from.
    source_node_ids : list[int] | None
        Node ids to copy, or ``None`` to copy everything. Array-likes
        exposing ``tolist()`` are accepted as well.
    kwargs : dict[str, Any]
        Constructor arguments of the destination graph; must contain
        ``"database"`` (path of the destination file).

    Returns
    -------
    SQLGraph
        The destination graph opened on the populated file.

    Notes
    -----
    If the copy fails, the pending transaction is rolled back and the
    temporary selection table is dropped before the destination is
    detached, so the original exception propagates. Tables already created
    in the destination file by the DDL replay are not removed.
    """
    dst_database: str = kwargs["database"]
    dst_path = Path(dst_database)

    # The dump replays the source's CREATE statements against the
    # destination, so the destination must start empty. The eligibility
    # check in :meth:`from_other` already gates on this for the
    # well-known cases; fall back defensively for anything else.
    if dst_path.exists() and dst_path.stat().st_size > 0:
        return super().from_other(other, **kwargs)

    # ATTACH does not accept bound parameters in every SQLite build, so
    # escape the path safely via single-quote doubling.
    attach_path = dst_database.replace("'", "''")

    with source_root._engine.connect() as conn:
        conn.exec_driver_sql(f"ATTACH DATABASE '{attach_path}' AS _td_dst")
        try:
            # 1. Replicate the source schema by replaying its DDL against
            #    the attached destination. ``sqlite_master.sql`` is NULL for
            #    auto-generated objects (e.g. PK indexes), which we skip;
            #    tables are created before indexes. Internal ``sqlite_*``
            #    objects (e.g. ``sqlite_sequence``, ``sqlite_stat1``) are
            #    skipped too: SQLite maintains them itself and rejects
            #    explicit CREATE statements for reserved names.
            ddl_rows = conn.exec_driver_sql(
                "SELECT type, sql FROM main.sqlite_master "
                "WHERE sql IS NOT NULL AND type IN ('table', 'index') "
                "AND name NOT LIKE 'sqlite_%' "
                "ORDER BY CASE type WHEN 'table' THEN 0 ELSE 1 END"
            ).fetchall()
            for _type, ddl in ddl_rows:
                qualified = cls._SQLITE_DDL_QUALIFIER.sub(r"\g<1>_td_dst.", ddl, count=1)
                conn.exec_driver_sql(qualified)

            # 2. Copy rows. The Metadata table is included verbatim — its
            #    SQL-private schema entries describe the columns we just
            #    cloned and so are valid for the destination as-is.
            if source_node_ids is None:
                for table_name in ("Node", "Edge", "Overlap", "Metadata"):
                    conn.exec_driver_sql(f'INSERT INTO _td_dst."{table_name}" SELECT * FROM main."{table_name}"')
            else:
                # Convert array-likes *before* wrapping in ``list``: a list
                # never exposes ``tolist``, so checking afterwards is dead
                # code.
                if hasattr(source_node_ids, "tolist"):
                    node_ids = source_node_ids.tolist()
                else:
                    node_ids = list(source_node_ids)

                # Materialize the selection in a temp table so the row
                # filter joins instead of using an oversized IN(...) clause.
                conn.exec_driver_sql("CREATE TEMP TABLE _td_selected (node_id INTEGER PRIMARY KEY)")
                insert_stmt = sa.text("INSERT INTO _td_selected (node_id) VALUES (:node_id)")
                chunk_size = max(1, source_root._sql_chunk_size())
                for start in range(0, len(node_ids), chunk_size):
                    batch = node_ids[start : start + chunk_size]
                    conn.execute(
                        insert_stmt,
                        [{"node_id": int(nid)} for nid in batch],
                    )

                conn.exec_driver_sql(
                    'INSERT INTO _td_dst."Node" SELECT * FROM main."Node" '
                    "WHERE node_id IN (SELECT node_id FROM _td_selected)"
                )
                # Edges and overlaps are kept only when both endpoints are
                # part of the selection.
                conn.exec_driver_sql(
                    'INSERT INTO _td_dst."Edge" SELECT * FROM main."Edge" '
                    "WHERE source_id IN (SELECT node_id FROM _td_selected) "
                    "AND target_id IN (SELECT node_id FROM _td_selected)"
                )
                conn.exec_driver_sql(
                    'INSERT INTO _td_dst."Overlap" SELECT * FROM main."Overlap" '
                    "WHERE source_id IN (SELECT node_id FROM _td_selected) "
                    "AND target_id IN (SELECT node_id FROM _td_selected)"
                )
                conn.exec_driver_sql('INSERT INTO _td_dst."Metadata" SELECT * FROM main."Metadata"')
                conn.exec_driver_sql("DROP TABLE _td_selected")

            conn.commit()
        except BaseException:
            # A failed copy can leave a transaction open. SQLite will not
            # DETACH a database that still takes part in a transaction, and
            # that secondary error would mask the real one, so roll back
            # first. Also remove the temp table so it does not linger on a
            # connection that may be reused from the pool.
            conn.rollback()
            conn.exec_driver_sql("DROP TABLE IF EXISTS temp._td_selected")
            raise
        finally:
            conn.exec_driver_sql("DETACH DATABASE _td_dst")

    # 3. Open the destination from the now-populated file. The standard
    #    constructor reflects the schema, restores pickled column types,
    #    and recomputes ``_max_id_per_time``.
    return cls(**kwargs)

def __getstate__(self) -> dict:
data_dict = self.__dict__.copy()
for k in ["Base", "Node", "Edge", "Overlap", "Metadata", "_engine"]:
Expand Down
Loading
Loading