Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion subsetter/config_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated

from subsetter.common import DatabaseConfig, SQLKnownOperator, SQLLiteralType
Expand Down Expand Up @@ -30,6 +30,19 @@ class ExtraFKConfig(ForbidBaseModel):
dst_table: str
dst_columns: List[str]

@model_validator(mode="after")
def check_columns_match(self):
col_count = len(self.src_columns)
if not col_count:
raise ValueError("src_columns cannot be empty")
if len(self.dst_columns) != col_count:
raise ValueError("src_columns and dst_columns must be the same length")
if len(set(self.src_columns)) != col_count:
raise ValueError("each column in src_columns must be unique")
if len(set(self.dst_columns)) != col_count:
raise ValueError("each column in src_columns must be unique")
return self

class ColumnConstraint(ForbidBaseModel):
column: str
operator: SQLKnownOperator
Expand Down
42 changes: 38 additions & 4 deletions subsetter/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,56 @@ def _solve_order(self) -> List[str]:

def _add_extra_fks(self) -> None:
"""Add in additional foreign keys requested."""
for extra_fk in self.config.extra_fks:
for index, extra_fk in enumerate(self.config.extra_fks):
src_schema, src_table_name = parse_table_name(extra_fk.src_table)
dst_schema, dst_table_name = parse_table_name(extra_fk.dst_table)
table = self.meta.tables.get((src_schema, src_table_name))
if table is None:
LOGGER.warning(
"Found no source table %s.%s referenced in add_extra_fks",
"Found no source table %s.%s referenced in extra_fks[%d]",
src_schema,
src_table_name,
index,
)
continue
if (dst_schema, dst_table_name) not in self.meta.tables:

src_missing_cols = {
col
for col in extra_fk.src_columns
if col not in table.table_obj.columns
}
if src_missing_cols:
LOGGER.warning(
"Columns %s do not exist in %s.%s referenced in extra_fks[%d]",
src_missing_cols,
src_schema,
src_table_name,
index,
)
continue

dst_table = self.meta.tables.get((dst_schema, dst_table_name))
if dst_table is None:
LOGGER.warning(
"Found no destination table %s.%s referenced in add_extra_fks[%d]",
dst_schema,
dst_table_name,
index,
)
continue

dst_missing_cols = {
col
for col in extra_fk.dst_columns
if col not in dst_table.table_obj.columns
}
if dst_missing_cols:
LOGGER.warning(
"Found no destination table %s.%s referenced in add_extra_fks",
"Columns %s do not exist in %s.%s referenced in extra_fks[%d]",
dst_missing_cols,
dst_schema,
dst_table_name,
index,
)
continue

Expand Down