Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions datafaker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _check_file_non_existence(file_path: Path) -> None:
"""Check that a given file does not exist. Exit with an error message if it does."""
if file_path.exists():
logger.error("%s should not already exist. Exiting...", file_path)
sys.exit(1)
raise Exit(1)


def load_metadata_config(
Expand Down Expand Up @@ -580,7 +580,7 @@ def convert_table_names_to_tables(
"%s is not the name of a table in the destination database", name
)
if failed_count:
sys.exit(1)
raise Exit(1)
return results


Expand Down Expand Up @@ -669,7 +669,7 @@ def dump_data(
"Must specify exactly one table if the output name is"
" specified, or specify an existing directory"
)
sys.exit(1)
raise Exit(1)
dst_dsn = get_destination_dsn()
schema_name = get_destination_schema()
config = read_config_file(config_file) if config_file is not None else {}
Expand Down Expand Up @@ -702,7 +702,7 @@ def validate_config(
validate(config, schema_config)
except ValidationError as e:
logger.error(e)
sys.exit(1)
raise Exit(1) from e
logger.debug("Config file is valid.")


Expand Down Expand Up @@ -798,7 +798,7 @@ def remove_tables(
except InternalError as exc:
logger.error("Failed to drop tables: %s", exc)
logger.error("Please try again using the --all option.")
sys.exit(1)
raise Exit(1) from exc
logger.debug("Tables dropped.")
else:
logger.info("Would remove tables if called with --yes.")
Expand Down
19 changes: 17 additions & 2 deletions datafaker/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pandas as pd
import snsql
import typer
import yaml
from black import FileMode, format_str
from jinja2 import Environment, FileSystemLoader, Template
Expand All @@ -31,6 +32,7 @@
create_db_engine,
download_table,
get_columns_assigned,
get_metadata,
get_property,
get_related_table_names,
get_row_generators,
Expand Down Expand Up @@ -606,7 +608,21 @@ def make_table_generators( # pylint: disable=too-many-locals
:return: A string that is a valid Python module, once written to file.
"""
row_generator_module_name: str = config.get("row_generators_module", None)
if row_generator_module_name and "-" in row_generator_module_name:
logger.error(
"Row generator name %s specified in %s should not contain a hyphen",
row_generator_module_name,
config_filename,
)
raise typer.Exit(1)
story_generator_module_name = config.get("story_generators_module", None)
if story_generator_module_name and "-" in story_generator_module_name:
logger.error(
"Story generator name %s specified in %s should not contain a hyphen",
story_generator_module_name,
config_filename,
)
raise typer.Exit(1)
object_instantiation: dict[str, dict] = config.get("object_instantiation", {})
tables_config = config.get("tables", {})

Expand Down Expand Up @@ -703,8 +719,7 @@ def make_tables_file(
"""Construct the YAML file representing the schema."""
engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name))

metadata = MetaData()
metadata.reflect(engine)
metadata = get_metadata(engine)
meta_dict = metadata_to_dict(metadata, schema_name, engine, parquet_dir)

if parquet_dir is not None:
Expand Down
4 changes: 2 additions & 2 deletions datafaker/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datafaker.settings import get_destination_dsn, get_destination_schema
from datafaker.utils import (
create_db_engine,
get_metadata,
get_sync_engine,
get_vocabulary_table_names,
logger,
Expand Down Expand Up @@ -67,6 +68,5 @@ def remove_db_tables(metadata: Optional[MetaData]) -> None:
)
)
if metadata is None:
metadata = MetaData()
metadata.reflect(dst_engine)
metadata = get_metadata(dst_engine)
metadata.drop_all(dst_engine)
43 changes: 36 additions & 7 deletions datafaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
from jsonschema.validators import validate
from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select
from sqlalchemy.engine.interfaces import DBAPIConnection
from sqlalchemy.exc import IntegrityError, ProgrammingError
from sqlalchemy.exc import (
IntegrityError,
NoSuchModuleError,
OperationalError,
ProgrammingError,
)
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.schema import (
Expand All @@ -43,6 +48,7 @@
MetaData,
Table,
)
from typer import Exit

# Define some types used repeatedly in the code base
MaybeAsyncEngine = Union[Engine, AsyncEngine]
Expand Down Expand Up @@ -110,7 +116,11 @@ def import_file(file_path: str) -> ModuleType:
if spec is None or spec.loader is None:
raise ImportError(f"No loadable module at {file_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
try:
spec.loader.exec_module(module)
except ModuleNotFoundError as e:
logger.error("Failed to load module at %s with error:", file_path)
logger.error(e)
return module


Expand Down Expand Up @@ -193,11 +203,19 @@ def create_db_engine(
**kwargs: Any,
) -> MaybeAsyncEngine:
"""Create a SQLAlchemy Engine."""
if use_asyncio:
async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://")
engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs)
else:
engine = create_engine(db_dsn, **kwargs)
try:
if use_asyncio:
async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://")
engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs)
else:
engine = create_engine(db_dsn, **kwargs)
except NoSuchModuleError as exc:
logger.error("Failed to connect to the database: %s", exc)
logger.error("Perhaps the dialect '%s' is invalid.", db_dsn.split(":")[0])
raise Exit(1) from exc
except ValueError as exc:
logger.error("DSN %s is malformed: %s", db_dsn, exc)
raise Exit(1) from exc

settings = {}
if schema_name is not None:
Expand Down Expand Up @@ -248,6 +266,17 @@ def create_db_engine_dst(
return create_db_engine(db_dsn, schema_name, use_asyncio)


def get_metadata(engine: Engine) -> MetaData:
    """
    Reflect the database reachable through ``engine`` into a new MetaData.

    :param engine: The engine handed to ``MetaData.reflect``.
    :return: The MetaData object after reflection.
    :raises Exit: With code 1 (after logging the cause) if reflection
        raises ``OperationalError``.
    """
    reflected = MetaData()
    try:
        reflected.reflect(engine)
    except OperationalError as connection_error:
        # Log the cause and leave through typer's Exit with a non-zero code.
        logger.error("Cannot connect to database: %s", connection_error)
        raise Exit(1) from connection_error
    return reflected


def _find_parquet_directories(parquet_dir: Path) -> list[str]:
"""Find all the directories under ``parquet_dir`` that contain parquet files."""
return [
Expand Down
76 changes: 76 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, Mapping

import yaml
from sqlalchemy import create_engine, inspect
from typer.testing import CliRunner, Result

Expand Down Expand Up @@ -606,6 +607,81 @@ def test_create_schema(self) -> None:
inspector = inspect(engine)
self.assertTrue(inspector.has_schema(env["dst_schema"]))

def test_story_incorrect_name(self) -> None:
    """
    Test we get a proper error message if the story generator module does not exist.

    Writes a config naming a story generator module called
    ``incorrect_module``, runs ``make-tables``, ``create-generators`` and
    ``create-tables`` (asserting the latter two succeed), then checks that
    ``create-data`` returns code 1.
    """
    config_file = "config_story_incorrect.yaml"
    config = {
        "story_generators_module": "incorrect_module",
    }
    with Path(config_file).open("w", encoding="utf-8") as fh:
        fh.write(yaml.dump(config))
    self.invoke(
        "make-tables",
        "--force",
    )
    completed_process = self.invoke(
        "create-generators",
        "--force",
        "--config-file",
        config_file,
    )
    self.assertSuccess(completed_process)
    # Keep the result of create-tables: without the assignment the
    # assertion below would only re-check the create-generators result.
    completed_process = self.invoke(
        "create-tables",
        "--config-file",
        config_file,
    )
    self.assertSuccess(completed_process)
    # NOTE(review): expected_error is presumably matched against the
    # command's error output by the test harness -- confirm in invoke().
    completed_process = self.invoke(
        "create-data",
        "--config-file",
        config_file,
        expected_error="No module named 'incorrect_module'",
    )
    self.assertReturnCode(completed_process, 1)

def test_story_hyphens_in_name(self) -> None:
    """
    Test that a hyphenated story generator module name is rejected.

    Writes a config whose ``story_generators_module`` contains a hyphen
    and checks that ``create-generators`` returns code 1.
    """
    config_name = "config_story_hyphens.yaml"
    hyphenated_config = {"story_generators_module": "story-generators"}
    Path(config_name).write_text(yaml.dump(hyphenated_config), encoding="utf-8")
    self.invoke("make-tables", "--force")
    # NOTE(review): expected_error is presumably matched against the
    # command's error output by the test harness -- confirm in invoke().
    outcome = self.invoke(
        "create-generators",
        "--force",
        "--config-file",
        config_name,
        expected_error="hyphen",
    )
    self.assertReturnCode(outcome, 1)

def test_row_hyphens_in_name(self) -> None:
    """
    Test that a hyphenated row generator module name is rejected.

    Writes a config whose ``row_generators_module`` contains a hyphen
    and checks that ``create-generators`` returns code 1.
    """
    config_name = "config_row_hyphens.yaml"
    hyphenated_config = {"row_generators_module": "row-generators"}
    Path(config_name).write_text(yaml.dump(hyphenated_config), encoding="utf-8")
    self.invoke("make-tables", "--force")
    # NOTE(review): expected_error is presumably matched against the
    # command's error output by the test harness -- confirm in invoke().
    outcome = self.invoke(
        "create-generators",
        "--force",
        "--config-file",
        config_name,
        expected_error="hyphen",
    )
    self.assertReturnCode(outcome, 1)


class DuckDbFunctionalTestCase(DBFunctionalTestCaseBase):
"""End-to-end tests for the DuckDB workflow."""
Expand Down
84 changes: 84 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,90 @@ def test_make_tables_with_force_enabled(
mock_make_tables.reset_mock()
mock_path.reset_mock()

@patch("datafaker.main.Path")
@patch("datafaker.settings.get_settings")
def test_incorrect_dialect_causes_nice_error_message(
    self,
    mock_get_settings: MagicMock,
    mock_path: MagicMock,
) -> None:
    """Test that make-tables ends in SystemExit when the source DSN dialect is wrong."""
    mock_get_settings.return_value = Settings(
        # "postgres" rather than "postgresql": sqlalchemy will fail to connect
        src_dsn="postgres://suser:spassword@shost:5432/sdbname",
        dst_dsn="postgresql://duser:dpassword@dhost:5432/ddbname",
        # To stop any local .env files influencing the test
        # The mypy ignore can be removed once we upgrade to pydantic 2.
        _env_file=None,  # type: ignore[call-arg]
    )
    # Every Path(...).exists() call made through datafaker.main reports True.
    mock_path.return_value.exists.return_value = True

    cli_args = [
        "make-tables",
        "--force",
        "--orm-file=tests/examples/example_orm.yaml",
    ]
    outcome = runner.invoke(app, cli_args)
    # Only the exception type is checked here, not the logged message.
    self.assertIs(type(outcome.exception), SystemExit)

@patch("datafaker.main.Path")
@patch("datafaker.settings.get_settings")
def test_invalid_host_causes_nice_error_message(
    self,
    mock_get_settings: MagicMock,
    mock_path: MagicMock,
) -> None:
    """Test that make-tables ends in SystemExit when the source DSN host is invalid."""
    mock_get_settings.return_value = Settings(
        # The dialect is valid here; the host name is the invalid part.
        # NOTE(review): presumably "invalid_host" cannot be connected to
        # in the test environment -- confirm.
        src_dsn="postgresql://suser:spassword@invalid_host:5432/sdbname",
        dst_dsn="postgresql://duser:dpassword@dhost:5432/ddbname",
        # To stop any local .env files influencing the test
        # The mypy ignore can be removed once we upgrade to pydantic 2.
        _env_file=None,  # type: ignore[call-arg]
    )
    # Every Path(...).exists() call made through datafaker.main reports True.
    mock_path.return_value.exists.return_value = True

    cli_args = [
        "make-tables",
        "--force",
        "--orm-file=tests/examples/example_orm.yaml",
    ]
    outcome = runner.invoke(app, cli_args)
    # Only the exception type is checked here, not the logged message.
    self.assertIs(type(outcome.exception), SystemExit)

@patch("datafaker.main.Path")
@patch("datafaker.settings.get_settings")
def test_incorrect_dsn_causes_nice_error_message(
    self,
    mock_get_settings: MagicMock,
    mock_path: MagicMock,
) -> None:
    """Test that make-tables ends in SystemExit when the source DSN is malformed."""
    mock_get_settings.return_value = Settings(
        # Malformed DSN: a ":" sits where the "@" before the host belongs.
        src_dsn="postgresql://suser:spassword:localhost:5432/sdbname",
        dst_dsn="postgresql://duser:dpassword@dhost:5432/ddbname",
        # To stop any local .env files influencing the test
        # The mypy ignore can be removed once we upgrade to pydantic 2.
        _env_file=None,  # type: ignore[call-arg]
    )
    # Every Path(...).exists() call made through datafaker.main reports True.
    mock_path.return_value.exists.return_value = True

    cli_args = [
        "make-tables",
        "--force",
        "--orm-file=tests/examples/example_orm.yaml",
    ]
    outcome = runner.invoke(app, cli_args)
    # Only the exception type is checked here, not the logged message.
    self.assertIs(type(outcome.exception), SystemExit)

def test_validate_config(self) -> None:
"""Test the validate-config sub-command."""
result = runner.invoke(
Expand Down
4 changes: 0 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@
)


class SysExit(Exception):
    """Exception raised to make a function exit the way sys.exit() would."""


@lru_cache(1)
def get_test_settings() -> settings.Settings:
"""Get a Settings object that ignores .env files and environment variables."""
Expand Down
Loading