Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,15 @@ def _validate_subcategory_parents(self) -> Self:
if col.column_type != "sampler" or col.sampler_type != SamplerType.SUBCATEGORY:
continue
parent = by_name.get(col.params.category)
if parent is not None and parent.column_type != "sampler":
if parent is not None and (parent.column_type != "sampler" or parent.sampler_type != SamplerType.CATEGORY):
if parent.column_type == "sampler":
parent_sampler_type = getattr(parent.sampler_type, "value", parent.sampler_type)
parent_type = f"sampler column with sampler_type='{parent_sampler_type}'"
else:
parent_type = f"'{parent.column_type}' column"
raise ValueError(
f"Subcategory column '{col.name}' has parent '{parent.name}', which is a "
f"'{parent.column_type}' column. Subcategory parents must be sampler columns "
f"Subcategory column '{col.name}' has parent '{parent.name}', which is a {parent_type}. "
f"Subcategory parents must be sampler columns "
f"with sampler_type='category'."
)
return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,31 @@ def test_subcategory_parent_as_category_sampler_is_valid() -> None:
assert len(config.columns) == 2


def test_subcategory_parent_must_be_a_category_sampler() -> None:
with pytest.raises(ValueError, match=r"sampler_type='uniform'.*sampler_type='category'"):
DataDesignerConfig.model_validate(
{
"columns": [
{
"name": "package_type",
"column_type": "sampler",
"sampler_type": "uniform",
"params": {"low": 0, "high": 1},
},
{
"name": "ski_category",
"column_type": "sampler",
"sampler_type": "subcategory",
"params": {
"category": "package_type",
"values": {"basic": ["a"], "premium": ["b"]},
},
},
]
}
)


def test_subcategory_parent_missing_defers_to_schema_validator() -> None:
config = DataDesignerConfig.model_validate(
{
Expand Down
Loading