Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
35da209
Automatic pre-commit fixes
Oct 3, 2025
bb14ced
Fixed a test
Oct 3, 2025
cda6164
base and create mypy fixes
Oct 3, 2025
e63c8bc
Some mypy fixes to generator.py
Oct 3, 2025
f28d7b9
Fixed variances in tests
Oct 5, 2025
20ff9ed
Mypy fixes in generators.py
Oct 6, 2025
49954de
mypy fixes in interactive.py
Oct 6, 2025
3e01c57
More mypy fixes in interactive.py
Oct 6, 2025
ea02cf3
mypy clean: dump, generators, interactive, providers
Oct 7, 2025
63f781e
Mypy fixed dump, interactive, main, serialize_metadata
Oct 7, 2025
3a5527b
mypy clean in datafaker dir
Oct 7, 2025
3fffbad
pre-commit rewrites
Oct 7, 2025
55acf14
test_dump is mypy clean
Oct 7, 2025
c3709b8
Some mypy cleaning of tests directory
Oct 7, 2025
113c4d2
Much more cleaning. mypy clean
Oct 8, 2025
10b02c5
precommit cleanup, NullPartitionedGrouped fix
Oct 8, 2025
42fb24a
Many, many cleanups.
Oct 9, 2025
b86e106
More cleaning
Oct 9, 2025
e1dec20
Lots of pylint cleaning
Oct 10, 2025
2894044
Precommit clean!
Oct 13, 2025
7728e6a
Merge remote-tracking branch 'safehr/main' into precommit
Oct 13, 2025
2a4982f
Pre-commit cleaned.
Oct 15, 2025
05ea378
Add running tests to pre-commit.yml
Oct 15, 2025
91036ce
Github actions starting PostgreSQL
Oct 15, 2025
69e0933
Fixed test_unique_constraint_fails
Oct 16, 2025
820700e
Cleaned up
Oct 16, 2025
433bb19
Fixed tests
Oct 16, 2025
d1b07dc
cleaned
Oct 16, 2025
79990a1
Move real test runner to tests.yml, overwriting bad test runner
Oct 17, 2025
9d0ec47
Added poetry initialisation to test runner
Oct 17, 2025
83438bb
More test fixes
Oct 17, 2025
2306b12
Another attempt to get tests.yml working
Oct 17, 2025
3c1c9aa
Initial attempt at a static version of df.py
Oct 22, 2025
47a0767
Merge branch 'main' of github:SAFEHR-data/datafaker into remove-gener…
Mar 17, 2026
6cde9fd
A few updates.
Mar 17, 2026
00df596
First test for create-data without intermediate file
Mar 19, 2026
ba9bf59
test_workflow_minimal_args passes
Mar 20, 2026
edda835
All tests pass.
Mar 23, 2026
c1c2858
A few pre-commit fixes
Mar 23, 2026
f914804
Cleaned pre-commit checks
Mar 24, 2026
176ed63
Merge branch 'main' into remove-generator-file
Mar 24, 2026
675082f
Version bump to 0.3.0
Mar 24, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 1 addition & 20 deletions datafaker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import gzip
import os
import random
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from io import TextIOWrapper
Expand All @@ -12,7 +11,7 @@
import yaml
from sqlalchemy import Connection, insert
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.schema import MetaData, Table
from sqlalchemy.schema import Table

from datafaker.utils import (
MAKE_VOCAB_PROGRESS_REPORT_EVERY,
Expand All @@ -22,24 +21,6 @@
)


class TableGenerator(ABC):
    """Abstract base class for table generator classes."""

    # How many rows a caller should request from this generator on each pass.
    num_rows_per_pass: int = 1

    @abstractmethod
    def __call__(self, dst_db_conn: Connection, metadata: MetaData) -> dict[str, Any]:
        """Return, as a dictionary, a new row for the table that we are generating.

        :param dst_db_conn: A database connection to the database to which the
            data is being written. Most generators won't use it, but some do,
            and thus it's required by the interface.
        :param metadata: The database metadata supplied by the caller
            alongside the connection; required by the interface even when a
            generator does not need it.
        :return: A dictionary with column names as strings for keys, and the
            values being the values for the new row.
        """


@dataclass
class FileUploader:
"""For uploading data files."""
Expand Down
189 changes: 131 additions & 58 deletions datafaker/create.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
"""Functions and classes to create and populate the target database."""
import pathlib
from collections import Counter
from types import ModuleType
from pathlib import Path
from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple

import typer
import yaml
from sqlalchemy import Connection, insert, inspect
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateColumn, CreateSchema, CreateTable, MetaData, Table

from datafaker.base import FileUploader, TableGenerator
from datafaker.settings import get_destination_dsn, get_destination_schema
from datafaker.base import FileUploader
from datafaker.make import FunctionCall, StoryGeneratorInfo, get_generation_info
from datafaker.populate import (
TableGenerator,
call_function,
get_symbols,
get_table_generator_dict,
)
from datafaker.settings import get_destination_dsn, get_destination_schema, get_settings
from datafaker.utils import (
create_db_engine_dst,
get_property,
get_sync_engine,
get_vocabulary_table_names,
logger,
Expand Down Expand Up @@ -92,7 +101,7 @@ def create_db_vocab(
metadata: MetaData,
meta_dict: dict[str, Any],
config: Mapping,
base_path: pathlib.Path = pathlib.Path("."),
base_path: Path = Path("."),
) -> list[str]:
"""
Load vocabulary tables from files.
Expand All @@ -102,6 +111,10 @@ def create_db_vocab(
:param config: The configuration from --config-file
:return: List of table names loaded.
"""
settings = get_settings()
dst_dsn: str = settings.dst_dsn or ""
assert dst_dsn != "", "Missing DST_DSN setting."

dst_engine = get_sync_engine(
create_db_engine_dst(
get_destination_dsn(),
Expand Down Expand Up @@ -136,14 +149,28 @@ def create_db_vocab(

def create_db_data(
sorted_tables: Sequence[Table],
df_module: ModuleType,
config: Mapping[str, Any],
src_stats_filename: Path | None,
num_passes: int,
metadata: MetaData,
) -> RowCounts:
"""Connect to a database and populate it with data."""
if src_stats_filename:
try:
with src_stats_filename.open(encoding="utf-8") as fh:
src_stats = yaml.load(fh, yaml.SafeLoader)
except FileNotFoundError as exc:
logger.error(
"No source stats file '%s', this should be the output of the 'make-stats' command",
src_stats_filename,
)
raise typer.Exit(1) from exc
else:
src_stats = None
return create_db_data_into(
sorted_tables,
df_module,
config,
src_stats,
num_passes,
get_destination_dsn(),
get_destination_schema(),
Expand All @@ -154,7 +181,8 @@ def create_db_data(
# pylint: disable=too-many-arguments too-many-positional-arguments
def create_db_data_into(
sorted_tables: Sequence[Table],
df_module: ModuleType,
config: Mapping[str, Any],
src_stats: dict[str, dict[str, Any]] | None,
num_passes: int,
db_dsn: str,
schema_name: str | None,
Expand All @@ -165,62 +193,113 @@ def create_db_data_into(

:param sorted_tables: The table names to populate, sorted so that foreign
keys' targets are populated before the foreign keys themselves.
:param table_generator_dict: A mapping of table names to the generators
used to make data for them.
:param story_generator_list: A list of story generators to be run after the
table generators on each pass.
:param config: The data from the ``config.yaml`` file.
:param src_stats: The data from the ``src-stats.yaml`` file.
:param num_passes: Number of passes to perform.
:param db_dsn: Connection string for the destination database.
:param schema_name: Destination schema name.
:param metadata: Destination database metadata.
"""
dst_engine = get_sync_engine(create_db_engine_dst(db_dsn, schema_name=schema_name))

gen_info = get_generation_info(metadata, config)
context = get_symbols(
gen_info.row_generator_module_name,
gen_info.story_generator_module_name,
get_property(config, "object_instantiation", {}),
src_stats,
metadata,
)
row_counts: Counter[str] = Counter()
with dst_engine.connect() as dst_conn:
context["dst_db_conn"] = dst_conn
for _ in range(num_passes):
row_counts += populate(
dst_conn,
sorted_tables,
df_module.table_generator_dict,
df_module.story_generator_list,
metadata,
get_table_generator_dict(
dst_conn,
gen_info.tables,
gen_info.max_unique_constraint_tries,
context,
),
gen_info.story_generators,
context,
)
dst_engine.dispose()
return row_counts


def empty_story_generator() -> (
    Generator[tuple[str, dict[str, Any]], dict[str, Any], None]
):
    """Build a story generator that finishes without producing any rows."""
    # Delegating to an empty, explicitly typed sequence keeps this a generator
    # function while guaranteeing it is exhausted on the first request.
    no_rows: tuple[tuple[str, dict[str, Any]], ...] = ()
    yield from no_rows


# pylint: disable=too-many-instance-attributes
class StoryIterator:
"""Iterates through all the rows produced by all the stories."""

def __init__(
self,
stories: Iterable[tuple[str, Story]],
stories: Iterable[StoryGeneratorInfo],
table_dict: Mapping[str, Table],
table_generator_dict: Mapping[str, TableGenerator],
dst_conn: Connection,
context: Mapping,
):
"""Initialise a Story Iterator."""
self._stories: Iterator[tuple[str, Story]] = iter(stories)
self._story_infos: Iterator[StoryGeneratorInfo] = iter(stories)
self._table_dict: Mapping[str, Table] = table_dict
self._table_generator_dict: Mapping[str, TableGenerator] = table_generator_dict
self._dst_conn: Connection = dst_conn
self._table_name: str | None
self._table_name: str | None = None
self._final_values: dict[str, Any] | None = None
# Number of times the current story should be run
self._story_counts = 1
self._story_function_call: FunctionCall
self._context = context
self._story = empty_story_generator()
self._provided_values: dict[str, Any]
self.next()

def _get_next_story(self) -> bool:
"""
Iterate to the next ``_story_infos``.

:return: False if there are no more.
"""
try:
name, self._story = next(self._stories)
logger.info("Generating data for story '%s'", name)
self._table_name, self._provided_values = next(self._story)
sgi = next(self._story_infos)
self._story_counts = sgi.num_stories_per_pass
self._story_function_call = sgi.function_call
logger.info(
"Generating data for story '%s'", sgi.function_call.function_name
)
self._story = call_function(sgi.function_call, self._context)
self._final_values = None
except StopIteration:
self._table_name = None
return False
return True

def _get_values(self) -> None:
    """Pull the next ``(table name, provided values)`` pair from the current story.

    When final values from the previous row are stored they are sent back
    into the story generator; otherwise the generator is simply advanced.
    ``StopIteration`` from an exhausted story propagates to the caller.
    """
    story = self._story
    feedback = self._final_values
    if feedback is not None:
        # Hand the previous row's final values back to the story.
        row = story.send(feedback)
    else:
        row = next(story)
    self._table_name, self._provided_values = row

def is_ended(self) -> bool:
"""
Check if we have another row to process.

If so, insert() can be called.
"""
return self._table_name is None
return self._story_counts == -1

def has_table(self, table_name: str) -> bool:
"""Check if we have a row for table ``table_name``."""
Expand All @@ -235,7 +314,7 @@ def table_name(self) -> str | None:
"""
return self._table_name

def insert(self, metadata: MetaData) -> None:
def insert(self) -> None:
"""
Put the row in the table.

Expand All @@ -247,7 +326,7 @@ def insert(self, metadata: MetaData) -> None:
table = self._table_dict[self._table_name]
if table.name in self._table_generator_dict:
table_generator = self._table_generator_dict[table.name]
default_values = table_generator(self._dst_conn, metadata)
default_values = table_generator(self._dst_conn)
else:
default_values = {}
insert_values = {**default_values, **self._provided_values}
Expand All @@ -271,54 +350,48 @@ def next(self) -> None:
"""Advance to the next row."""
while True:
try:
if self._final_values is None:
self._table_name, self._provided_values = next(self._story)
return
self._table_name, self._provided_values = self._story.send(
self._final_values
)
self._get_values()
return
except StopIteration:
try:
name, self._story = next(self._stories)
logger.info("Generating data for story '%s'", name)
self._final_values = None
except StopIteration:
self._table_name = None
self._final_values = None
self._story_counts -= 1
if 0 < self._story_counts:
# Reinitialize the same story again
self._story = call_function(
self._story_function_call, self._context
)
elif not self._get_next_story():
self._story_counts = -1
return


def populate(
dst_conn: Connection,
tables: Sequence[Table],
table_generator_dict: Mapping[str, TableGenerator],
story_generator_list: Sequence[Mapping[str, Any]],
metadata: MetaData,
story_generator_infos: Sequence[StoryGeneratorInfo],
context: Mapping,
) -> RowCounts:
"""Populate a database schema with synthetic data."""
row_counts: Counter[str] = Counter()
table_dict = {table.name: table for table in tables}
# Generate stories
# Each story generator returns a python generator (an unfortunate naming clash with
# what we call generators). Iterating over it yields individual rows for the
# database. First, collect all of the python generators into a single list.
stories: list[tuple[str, Story]] = sum(
[
[
(sg["name"], sg["function"](dst_conn))
for _ in range(sg["num_stories_per_pass"])
]
for sg in story_generator_list
],
[],
# database.
story_iterator = StoryIterator(
story_generator_infos,
table_dict,
table_generator_dict,
dst_conn,
context,
)
story_iterator = StoryIterator(stories, table_dict, table_generator_dict, dst_conn)

# Generate individual rows, table by table.
for table in tables:
# Do we have a story row to enter into this table?
if story_iterator.has_table(table.name):
story_iterator.insert(metadata)
story_iterator.insert()
row_counts[table.name] = row_counts.get(table.name, 0) + 1
story_iterator.next()
if table.name not in table_generator_dict:
Expand All @@ -329,20 +402,20 @@ def populate(
continue
logger.debug("Generating data for table '%s'", table.name)
# Run all the inserts for one table in a transaction
try:
with dst_conn.begin():
with dst_conn.begin():
try:
for _ in range(table_generator.num_rows_per_pass):
stmt = insert(table).values(table_generator(dst_conn, metadata))
stmt = insert(table).values(table_generator(dst_conn))
dst_conn.execute(stmt)
row_counts[table.name] = row_counts.get(table.name, 0) + 1
dst_conn.commit()
except:
dst_conn.rollback()
raise
except:
dst_conn.rollback()
raise

# Insert any remaining stories
while not story_iterator.is_ended():
story_iterator.insert(metadata)
story_iterator.insert()
t = story_iterator.table_name()
if t is None:
raise AssertionError(
Expand Down
2 changes: 1 addition & 1 deletion datafaker/generators/partitioned.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def get_named_tables(self) -> Mapping[str, str]:

def __init__(self, config: Mapping[str, Any]) -> None:
"""Initialize the null partitioned generator factory."""
tables = get_property(config, "tables", dict, {})
tables: dict[str, Any] = get_property(config, "tables", {})
self._named_tables = {
table_name: table_conf["name_column"]
for table_name, table_conf in tables.items()
Expand Down
Loading
Loading