Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions docs/plugins/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,9 @@ Users set that alias from default model settings or from `DataDesignerConfigBuil

If your plugin uses multiple model aliases, inherit from `ColumnGeneratorWithModelRegistry` and resolve each alias explicitly with `self.get_model(...)`.

The config must include a primary `model_alias: str` field. Startup health checks read it directly from any column config whose generator inherits from `ColumnGeneratorWithModelRegistry`, including generators that inherit through `ColumnGeneratorWithModel`. A config for this pattern might also define `judge_model_alias`, `critic_model_alias`, or another task-specific alias.
The startup model health check pings every alias your column declares. By default, `SingleColumnConfig.get_model_aliases()` returns the primary `model_alias` field, which covers single-model plugins for free. A config for this pattern might also define `judge_model_alias`, `critic_model_alias`, or another task-specific alias. Override `get_model_aliases()` to return every alias the column depends on so a typo, missing API key, or unreachable endpoint surfaces at startup instead of at first generation.

Validate additional alias fields in `_validate()` or `_initialize()` with `get_model_config(...)` so missing aliases fail before generation starts. `get_model_config(alias)` only verifies that the alias is registered; it does not call the endpoint. Endpoint reachability is only exercised for the primary `model_alias` collected by the standard startup health check.

The matching config shows which alias gets the standard startup health check and which alias the plugin validates itself:
The matching config opts every alias into the standard startup health check by listing them all in `get_model_aliases()`:

```python
from __future__ import annotations
Expand All @@ -111,6 +109,9 @@ class PairwiseJudgeColumnConfig(SingleColumnConfig):
@property
def side_effect_columns(self) -> list[str]:
return []

def get_model_aliases(self) -> list[str]:
return [self.model_alias, self.judge_model_alias]
```

```python
Expand All @@ -135,10 +136,6 @@ class PairwiseJudgeColumnGenerator(ColumnGeneratorWithModelRegistry[PairwiseJudg
def get_generation_strategy() -> GenerationStrategy:
return GenerationStrategy.CELL_BY_CELL

def _validate(self) -> None:
self.get_model_config(self.config.model_alias)
self.get_model_config(self.config.judge_model_alias)

async def agenerate(self, data: dict) -> dict:
generator_model = self.get_model(self.config.model_alias)
judge_model = self.get_model(self.config.judge_model_alias)
Expand All @@ -163,6 +160,8 @@ class PairwiseJudgeColumnGenerator(ColumnGeneratorWithModelRegistry[PairwiseJudg
return data
```

If your config has no `model_alias` field at all (uncommon but valid), override `get_model_aliases()` to return whichever fields name your model dependencies β€” the default implementation reads `model_alias` via `getattr` and returns an empty list when it is absent, so it will not crash on configs without it.

## What the registry returns

`get_model(...)` returns a `ModelFacade`. Call the facade based on the modality your plugin needs:
Expand All @@ -181,7 +180,7 @@ Prefer implementing `agenerate(...)` for model-backed plugins. The base `generat

The model-aware bases mark the generator as LLM-bound, so the async scheduler treats the work like other model calls.

Plugin discovery treats column generator implementations that inherit from `ColumnGeneratorWithModelRegistry` as model-generated column types for startup model health checks. The standard health-check collection reads a primary `model_alias` field directly from the config. Additional alias fields should be registration-validated by the plugin implementation; their endpoints are not pinged by the standard startup health check.
Plugin discovery treats column generator implementations that inherit from `ColumnGeneratorWithModelRegistry` as model-generated column types for startup model health checks. The standard health-check collection calls `SingleColumnConfig.get_model_aliases()` on each column config and pings every alias it returns. The default implementation returns the column's primary `model_alias` (or an empty list for configs without one); configs with multiple model fields should override it so the startup check exercises every endpoint they depend on.

## Built-in patterns

Expand Down
19 changes: 19 additions & 0 deletions packages/data-designer-config/src/data_designer/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,25 @@ def side_effect_columns(self) -> list[str]:
indicates no side effect columns. Override in subclasses to specify side effects.
"""

def get_model_aliases(self) -> list[str]:
"""Return every model alias this column depends on.

The startup model health check uses this to decide which model endpoints to ping.
The default implementation returns the column's primary ``model_alias`` (if set),
which covers the built-in LLM, embedding, and image columns.

Override this method on configs that depend on more than one model β€” for example,
a plugin config with both a ``model_alias`` and a ``judge_model_alias`` should return
both so a typo or unreachable endpoint on the secondary alias surfaces at startup
rather than at first generation.

Returns:
List of model aliases this column depends on. Empty list indicates the column
does not call any model endpoints.
"""
alias: str | None = getattr(self, "model_alias", None)
return [alias] if alias else []


class ProcessorConfig(ConfigBase, ABC):
"""Abstract base class for all processor configuration types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,10 @@ def model_aliases(self) -> list[str]:
metadata = getattr(self.generator_function, "custom_column_metadata", {})
return metadata.get("model_aliases", [])

def get_model_aliases(self) -> list[str]:
"""Returns the decorator-declared aliases so the startup health check pings every endpoint."""
return self.model_aliases

@field_serializer("generator_function")
def serialize_generator_function(self, v: Any) -> str:
return getattr(v, "__name__", repr(v))
Expand Down
43 changes: 43 additions & 0 deletions packages/data-designer-config/tests/config/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,49 @@ def test_allow_resize_inherited_by_subclasses() -> None:
assert StubColumnConfig(name="test", allow_resize=True).allow_resize is True


def test_get_model_aliases_empty_when_no_model_alias_field() -> None:
"""Configs without a model_alias field return an empty list, not AttributeError."""
assert StubColumnConfig(name="test").get_model_aliases() == []


@pytest.mark.parametrize(
"config",
[
LLMTextColumnConfig(name="t", prompt=stub_prompt, model_alias=stub_model_alias),
LLMCodeColumnConfig(name="c", prompt=stub_prompt, code_lang=CodeLang.PYTHON, model_alias=stub_model_alias),
EmbeddingColumnConfig(name="e", target_column="text", model_alias=stub_model_alias),
ImageColumnConfig(name="i", prompt="Generate {{ x }}", model_alias=stub_model_alias),
],
ids=["llm-text", "llm-code", "embedding", "image"],
)
def test_get_model_aliases_returns_primary_alias_for_builtins(config: SingleColumnConfig) -> None:
"""Built-in model-backed configs return their primary model_alias by default."""
assert config.get_model_aliases() == [stub_model_alias]


def test_get_model_aliases_can_be_overridden_for_multi_model_plugins() -> None:
"""A plugin config with multiple model fields can override get_model_aliases()."""

class _PairwiseJudgeColumnConfig(SingleColumnConfig):
column_type: Literal["pairwise-judge-test"] = "pairwise-judge-test"
model_alias: str
judge_model_alias: str

@property
def required_columns(self) -> list[str]:
return []

@property
def side_effect_columns(self) -> list[str]:
return []

def get_model_aliases(self) -> list[str]:
return [self.model_alias, self.judge_model_alias]

config = _PairwiseJudgeColumnConfig(name="pj", model_alias="primary", judge_model_alias="judge")
assert config.get_model_aliases() == ["primary", "judge"]


@pytest.mark.parametrize(
("dtype", "raw_value", "expected_value", "expected_type"),
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pydantic import ValidationError

import data_designer.lazy_heavy_imports as lazy
from data_designer.config.column_configs import CustomColumnConfig
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType
from data_designer.config.config_builder import BuilderConfig
from data_designer.config.data_designer_config import DataDesignerConfig
Expand Down Expand Up @@ -1089,10 +1088,7 @@ def _merge_skipped_and_generated(
def _run_model_health_check_if_needed(self) -> None:
model_aliases: set[str] = set()
for config in self.single_column_configs:
if column_type_is_model_generated(config.column_type):
model_aliases.add(config.model_alias)
if isinstance(config, CustomColumnConfig) and config.model_aliases:
model_aliases.update(config.model_aliases)
model_aliases.update(config.get_model_aliases())

if not model_aliases:
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def decorated_generator(row: dict) -> dict:
assert config.required_columns == ["col1", "col2"]
assert config.side_effect_columns == ["extra"]
assert config.model_aliases == ["model-a"]
# get_model_aliases() opts the decorator-declared aliases into the startup health check
assert config.get_model_aliases() == ["model-a"]

# Serialization works
assert config.model_dump()["generator_function"] == "decorated_generator"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,61 @@ def test_dataset_builder_build_method_basic_flow(
assert result_path == stub_resource_provider.artifact_storage.final_dataset_path


def test_run_model_health_check_collects_aliases_from_get_model_aliases(
stub_resource_provider,
stub_model_configs,
) -> None:
"""The health check pings every alias returned by each config's get_model_aliases().

Regression test for #606: secondary aliases on multi-model plugin configs (returned via
get_model_aliases()) must be passed to run_health_check(), not just the primary
model_alias field.
"""
stub_resource_provider.model_registry.run_health_check = Mock()

@custom_column_generator(model_aliases=["custom-model-a", "custom-model-b"])
def gen_with_two_models(row: dict, generator_params, models) -> dict:
del generator_params, models
return row

config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
config_builder.add_column(
SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())
)
config_builder.add_column(LLMTextColumnConfig(name="builtin_llm_col", prompt="x", model_alias="builtin-model"))
config_builder.add_column(CustomColumnConfig(name="custom_col", generator_function=gen_with_two_models))

builder = DatasetBuilder(
data_designer_config=config_builder.build(),
resource_provider=stub_resource_provider,
)
builder._run_model_health_check_if_needed()

stub_resource_provider.model_registry.run_health_check.assert_called_once()
(called_aliases,), _ = stub_resource_provider.model_registry.run_health_check.call_args
assert set(called_aliases) == {"builtin-model", "custom-model-a", "custom-model-b"}


def test_run_model_health_check_skips_when_no_model_aliases(
stub_resource_provider,
stub_model_configs,
) -> None:
"""Configs with no model aliases (e.g. samplers only) skip the health check entirely."""
stub_resource_provider.model_registry.run_health_check = Mock()

config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
config_builder.add_column(
SamplerColumnConfig(name="seed_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())
)
builder = DatasetBuilder(
data_designer_config=config_builder.build(),
resource_provider=stub_resource_provider,
)
builder._run_model_health_check_if_needed()

stub_resource_provider.model_registry.run_health_check.assert_not_called()


@pytest.mark.parametrize(
"column_configs,expected_error",
[
Expand Down
Loading