34 changes: 33 additions & 1 deletion subsetter/common.py
@@ -1,6 +1,6 @@
import logging
import os
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union

import sqlalchemy as sa
from pydantic import BaseModel
@@ -217,3 +217,35 @@ def _set_session_sqls(dbapi_connection, _):
cursor.close()

return engine


def pydantic_search(root: Any) -> Iterable[BaseModel]:
"""
A generator that yields all sub-models found underneath the passed root object (including the
root object itself). Searches model fields as well as lists and dicts found in those
fields.
"""
vis = set()
stack = []

def _push(key: Any, value: Any):
if isinstance(value, (BaseModel, list, dict)):
if id(value) not in vis:
vis.add(id(value))
stack.append(value)

_push(None, root)
while stack:
data = stack.pop()
if isinstance(data, BaseModel):
yield data
for field, _ in data.model_fields.items():
_push(field, getattr(data, field))

if isinstance(data, list):
for idx, elem in enumerate(data):
_push(idx, elem)

if isinstance(data, dict):
for key, elem in data.items():
_push(key, elem)
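
A quick usage sketch of the new helper; the Leaf/Root models here are hypothetical, invented purely for illustration:

# Minimal sketch: pydantic_search walks nested models, lists, and dicts,
# yielding every BaseModel it encounters, including the root itself.
from typing import Dict, List

from pydantic import BaseModel

from subsetter.common import pydantic_search


class Leaf(BaseModel):
    value: int


class Root(BaseModel):
    items: List[Leaf]
    lookup: Dict[str, Leaf]


root = Root(items=[Leaf(value=1)], lookup={"a": Leaf(value=2)})
assert len(list(pydantic_search(root))) == 3  # Root plus both Leaf instances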
55 changes: 55 additions & 0 deletions subsetter/plan_model.py
@@ -232,12 +232,21 @@ def simplify(self) -> "SQLWhereClause":
]


class SQLLeftJoin(BaseModel):
right: SQLTableIdentifier
left_columns: List[str]
right_columns: List[str]
half_unique: bool = True


class SQLStatementSelect(BaseModel):
type_: Literal["select"] = Field(..., alias="type")
columns: Optional[List[str]] = None
from_: SQLTableIdentifier = Field(..., alias="from")
where: Optional[SQLWhereClause] = None
limit: Optional[int] = None
joins: Optional[List[SQLLeftJoin]] = None
joins_outer: bool = False

model_config = ConfigDict(populate_by_name=True)

@@ -250,6 +259,48 @@ def build(self, context: SQLBuildContext):
else:
stmt = sa.select(table_obj)

if self.joins:
joined_cols: List[sa.ColumnExpression] = []
joined: sa.FromClause = table_obj
exists_constraints: List[sa.ColumnExpressionArgument] = []
for join in self.joins: # pylint: disable=not-an-iterable
right = join.right.build(context).alias()

if join.half_unique:
joined = joined.join(
right,
onclause=sa.and_(
*(
table_obj.c[lft_col] == right.c[rht_col]
for lft_col, rht_col in zip(
join.left_columns, join.right_columns
)
)
),
isouter=self.joins_outer,
)
joined_cols.extend(
right.c[rht_col] for rht_col in join.right_columns
)
else:
exists_constraints.append(
sa.exists().where(
*(
table_obj.c[lft_col] == right.c[rht_col]
for lft_col, rht_col in zip(
join.left_columns, join.right_columns
)
)
)
)

stmt = stmt.select_from(joined).distinct()
if self.joins_outer:
exists_constraints.extend(col.is_not(None) for col in joined_cols)
stmt = stmt.where(sa.or_(*exists_constraints))
elif exists_constraints:
stmt = stmt.where(sa.and_(*exists_constraints))

if self.where:
stmt = stmt.where(self.where.build(context, table_obj))

@@ -273,6 +324,10 @@ def simplify(self) -> "SQLStatementSelect":
kwargs["columns"] = self.columns
if self.limit is not None:
kwargs["limit"] = self.limit
if self.joins:
kwargs["joins"] = self.joins
kwargs["joins_outer"] = self.joins_outer

return SQLStatementSelect(**kwargs) # type: ignore


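To make the two code paths in build() concrete, here is a standalone SQLAlchemy sketch of the SQL shapes they produce; the orders/users tables are hypothetical and not part of this PR:

import sqlalchemy as sa

metadata = sa.MetaData()
users = sa.Table("users", metadata, sa.Column("id", sa.Integer, primary_key=True))
orders = sa.Table(
    "orders",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("user_id", sa.Integer),
)

right = sa.select(users).alias()

# half_unique path: a plain join plus DISTINCT; with joins_outer the
# NULL-extended rows are filtered back out via IS NOT NULL, as in build().
join_stmt = (
    sa.select(orders)
    .select_from(
        orders.join(right, onclause=orders.c.user_id == right.c.id, isouter=True)
    )
    .distinct()
    .where(right.c.id.is_not(None))
)

# non-half_unique path: a correlated EXISTS, which can never duplicate rows.
exists_stmt = sa.select(orders).where(sa.exists().where(orders.c.user_id == right.c.id))

print(join_stmt)
print(exists_stmt)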
87 changes: 45 additions & 42 deletions subsetter/planner.py
@@ -1,10 +1,13 @@
import logging
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, Iterable, List, Optional, Set, Tuple

import sqlalchemy as sa

from subsetter.common import DatabaseConfig, parse_table_name
from subsetter.config_model import PlannerConfig
from subsetter.metadata import DatabaseMetadata, ForeignKey, TableMetadata
from subsetter.plan_model import (
SQLLeftJoin,
SQLStatementSelect,
SQLStatementUnion,
SQLTableIdentifier,
@@ -13,7 +16,6 @@
SQLWhereClauseAnd,
SQLWhereClauseIn,
SQLWhereClauseOperator,
SQLWhereClauseOr,
SQLWhereClauseRandom,
SQLWhereClauseSQL,
SubsetPlan,
@@ -267,8 +269,6 @@ def _plan_table(
processed: Set[Tuple[str, str]],
target: Optional[PlannerConfig.TargetConfig] = None,
) -> SQLTableQuery:
fk_constraints: List[SQLWhereClause] = []

foreign_keys = sorted(
fk
for fk in table.foreign_keys
@@ -311,33 +311,34 @@
[f"{fk.dst_schema}.{fk.dst_table}" for fk in rev_foreign_keys],
)

fk_constraints = [
SQLWhereClauseIn(
type_="in",
columns=list(fk.columns),
values=SQLStatementSelect(
type_="select",
columns=list(fk.dst_columns),
from_=SQLTableIdentifier(
def _is_distinct(table_obj: sa.Table, cols: Iterable[str]) -> bool:
cols_st = set(cols)
for constraint in table_obj.constraints:
if isinstance(
constraint, (sa.PrimaryKeyConstraint, sa.UniqueConstraint)
):
constraint_cols = set(col.name for col in constraint.columns)
if constraint_cols <= cols_st:
return True
return False

fk_joins = []
for fk in foreign_keys or rev_foreign_keys:
dst_table = self.meta.tables[(fk.dst_schema, fk.dst_table)]
half_unique = _is_distinct(table.table_obj, fk.columns) or _is_distinct(
dst_table.table_obj, fk.dst_columns
)
fk_joins.append(
SQLLeftJoin(
right=SQLTableIdentifier(
table_schema=fk.dst_schema,
table_name=fk.dst_table,
sampled=True,
),
),
)
for fk in foreign_keys or rev_foreign_keys
]

fk_constraint: SQLWhereClause
if foreign_keys:
fk_constraint = SQLWhereClauseAnd(
type_="and",
conditions=fk_constraints,
)
else:
fk_constraint = SQLWhereClauseOr(
type_="or",
conditions=fk_constraints,
left_columns=list(fk.columns),
right_columns=list(fk.dst_columns),
half_unique=half_unique,
)
)

conf_constraints = self.config.table_constraints.get(
Expand Down Expand Up @@ -365,23 +366,25 @@ def _plan_table(
)
)

statements: List[SQLStatementSelect] = []

# Calculate initial foreign-key / config constraint statement
statements: List[SQLStatementSelect] = [
SQLStatementSelect(
type_="select",
from_=SQLTableIdentifier(
table_schema=table.schema,
table_name=table.name,
),
where=SQLWhereClauseAnd(
type_="and",
conditions=[
*conf_constraints_sql,
fk_constraint,
],
),
if foreign_keys or rev_foreign_keys:
statements.append(
SQLStatementSelect(
type_="select",
from_=SQLTableIdentifier(
table_schema=table.schema,
table_name=table.name,
),
joins=fk_joins,
joins_outer=not foreign_keys,
where=SQLWhereClauseAnd(
type_="and",
conditions=conf_constraints_sql,
),
)
)
]

# If targeted also calculate target constraint statement
if target:
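Concretely, for a hypothetical child table orders with a foreign key orders.user_id -> users.id, _plan_table() now emits a join-based statement shaped like this (all names invented for illustration):

from subsetter.plan_model import SQLLeftJoin, SQLStatementSelect, SQLTableIdentifier

# Hypothetical example of the statement built for a table with one forward
# foreign key to an already-sampled parent.
fragment = SQLStatementSelect(
    type_="select",
    from_=SQLTableIdentifier(table_schema="public", table_name="orders"),
    joins=[
        SQLLeftJoin(
            right=SQLTableIdentifier(
                table_schema="public", table_name="users", sampled=True
            ),
            left_columns=["user_id"],
            right_columns=["id"],
            # users.id is unique, so each orders row matches at most one
            # sampled users row and a plain join cannot fan out rows.
            half_unique=True,
        )
    ],
    # Reverse foreign keys use joins_outer=True, so that matching any one of
    # the joins is enough to keep the row.
    joins_outer=False,
)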
57 changes: 42 additions & 15 deletions subsetter/sampler.py
@@ -1,4 +1,5 @@
import abc
import collections
import functools
import json
import logging
@@ -12,7 +13,7 @@
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import ClauseElement, Executable

from subsetter.common import DatabaseConfig, parse_table_name
from subsetter.common import DatabaseConfig, parse_table_name, pydantic_search
from subsetter.config_model import (
ConflictStrategy,
DatabaseOutputConfig,
@@ -21,7 +22,7 @@
)
from subsetter.filters import FilterOmit, FilterView, FilterViewChain
from subsetter.metadata import DatabaseMetadata
from subsetter.plan_model import SQLTableIdentifier
from subsetter.plan_model import SQLLeftJoin, SQLTableIdentifier
from subsetter.planner import SubsetPlan
from subsetter.solver import toposort

@@ -69,7 +70,7 @@ def create(
select: sa.Select,
*,
name: str = "",
primary_key: Tuple[str, ...] = (),
indexes: Iterable[Tuple[str, ...]] = (),
) -> Tuple[sa.Table, int]:
"""
Create a temporary table on the passed connection generated by the passed
@@ -82,9 +83,8 @@
schema: The schema to create the temporary table within. For some dialects
temporary tables always exist in their own schema and this parameter
will be ignored.
primary_key: If set will mark the set of columns passed as primary keys in
the temporary table. This tuple should match a subset of the
column names in the select query.
indexes: Creates an index on each tuple of columns listed. This is useful
if future queries are likely to reference these columns.

Returns a tuple containing the generated table object and the number of rows that
were inserted in the table.
@@ -106,10 +106,7 @@
metadata,
schema=temp_schema,
prefixes=["TEMPORARY"],
*(
sa.Column(col.name, col.type, primary_key=col.name in primary_key)
for col in select.selected_columns
),
*(sa.Column(col.name, col.type) for col in select.selected_columns),
)
try:
metadata.create_all(conn)
@@ -122,11 +119,29 @@
if "--read-only" not in str(exc):
raise

for idx, index_cols in enumerate(indexes):
# For some dialects/data types we may not be able to construct an index. We just do our
# best here instead of hard failing.
try:
sa.Index(
f"{temp_name}_idx_{idx}",
*(table_obj.columns[col_name] for col_name in index_cols),
).create(bind=conn)
except sa.exc.OperationalError:
LOGGER.warning(
"Failed to create index %s on temporary table %s",
index_cols,
temp_name,
exc_info=True,
)

# Copy data into the temporary table
result = conn.execute(
table_obj.insert().from_select(list(table_obj.columns), select)
stmt = table_obj.insert().from_select(list(table_obj.columns), select)
LOGGER.debug(
" Using statement %s",
str(stmt.compile(dialect=conn.engine.dialect)).replace("\n", " "),
)
result = conn.execute(table_obj.select())
result = conn.execute(stmt)

return table_obj, result.rowcount

@@ -832,6 +847,18 @@ def _materialize_tables(
conn: sa.Connection,
plan: SubsetPlan,
) -> None:
# Figure out what sets of columns are going to be queried for our materialized tables.
joined_columns = collections.defaultdict(set)
for data in pydantic_search(plan):
if not isinstance(data, SQLLeftJoin):
continue
table_id = data.right
if not table_id.sampled:
continue
joined_columns[(table_id.table_schema, table_id.table_name)].add(
tuple(data.right_columns)
)

materialization_order = self._materialization_order(meta, plan)
for schema, table_name, ref_count in materialization_order:
table = meta.tables[(schema, table_name)]
@@ -864,7 +891,7 @@ def _materialize_tables(
schema,
table_q,
name=table_name,
primary_key=table.primary_key,
indexes=joined_columns[(schema, table_name)],
)
)
self.cached_table_sizes[(schema, table_name)] = rowcount
@@ -887,7 +914,7 @@ def _materialize_tables(
schema,
meta.temp_tables[(schema, table_name, 0)].select(),
name=table_name,
primary_key=table.primary_key,
indexes=joined_columns[(schema, table_name)],
)
)
LOGGER.info(
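Tying the pieces together, this sketch mirrors how _materialize_tables collects the column tuples it passes to create(..., indexes=...); a single statement (hypothetical names again) stands in for a full plan:

import collections

from subsetter.common import pydantic_search
from subsetter.plan_model import SQLLeftJoin, SQLStatementSelect, SQLTableIdentifier

# One statement stands in for a whole SubsetPlan here; pydantic_search
# recurses through any pydantic model, so the traversal is identical.
stmt = SQLStatementSelect(
    type_="select",
    from_=SQLTableIdentifier(table_schema="public", table_name="orders"),
    joins=[
        SQLLeftJoin(
            right=SQLTableIdentifier(
                table_schema="public", table_name="users", sampled=True
            ),
            left_columns=["user_id"],
            right_columns=["id"],
        )
    ],
)

# Collect, per sampled table, every tuple of columns that some join will
# reference; each tuple then becomes an index on the temporary table.
joined_columns = collections.defaultdict(set)
for node in pydantic_search(stmt):
    if isinstance(node, SQLLeftJoin) and node.right.sampled:
        joined_columns[(node.right.table_schema, node.right.table_name)].add(
            tuple(node.right_columns)
        )

assert joined_columns[("public", "users")] == {("id",)}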