34 changes: 34 additions & 0 deletions docs/concepts/custom_columns.md
@@ -93,6 +93,40 @@ This gives you direct access to all `ModelFacade` capabilities: custom parsers,
| `generator_function` | Callable | Yes | Decorated function |
| `generation_strategy` | GenerationStrategy | No | `CELL_BY_CELL` or `FULL_COLUMN` |
| `generator_params` | BaseModel | No | Typed params passed to function |
| `allow_resize` | bool | No | Allow 1:N or N:1 generation. Requires `FULL_COLUMN` strategy |

### Resizing (1:N and N:1)

With the `full_column` strategy, you can produce more or fewer records than the input by setting `allow_resize=True`:

```python
@dd.custom_column_generator(
    required_columns=["topic"],
    side_effect_columns=["variation_id"],
)
def expand_topics(df: pd.DataFrame, params: None, models: dict) -> pd.DataFrame:
    rows = []
    for _, row in df.iterrows():
        for i in range(3):  # Generate 3 variations per input
            rows.append({
                "topic": row["topic"],
                "question": f"Question {i + 1} about {row['topic']}",
                "variation_id": i,
            })
    return pd.DataFrame(rows)

dd.CustomColumnConfig(
    name="question",
    generator_function=expand_topics,
    generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
    allow_resize=True,
)
```

Use cases:

- **Expansion (1:N)**: Generate multiple variations per input
- **Retraction (N:1)**: Filter, aggregate, or deduplicate records (see the sketch below)
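
For the retraction direction, here is a minimal sketch of a deduplicating generator, following the same decorator and config API as the expansion example above; the `doc_text` column and the `dedupe_status` name are illustrative, not part of the library:

```python
@dd.custom_column_generator(required_columns=["doc_text"])
def dedupe_docs(df: pd.DataFrame, params: None, models: dict) -> pd.DataFrame:
    # Drop exact duplicates, keeping the first occurrence (N:1 retraction)
    deduped = df.drop_duplicates(subset=["doc_text"]).copy()
    deduped["dedupe_status"] = "kept"
    return deduped

dd.CustomColumnConfig(
    name="dedupe_status",
    generator_function=dedupe_docs,
    generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
    allow_resize=True,
)
```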

## Multi-Turn Example

98 changes: 98 additions & 0 deletions example_allow_resize.py
@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Example: Using allow_resize for 1:N expansion and N:1 retraction."""

from __future__ import annotations

import data_designer.config as dd
from data_designer.interface import DataDesigner
from data_designer.lazy_heavy_imports import pd


@dd.custom_column_generator(required_columns=["topic"], side_effect_columns=["variation_id"])
def expand_to_questions(df: pd.DataFrame) -> pd.DataFrame:
    """Generate 3 questions per topic (1:N expansion)."""
    rows = []
    for _, row in df.iterrows():
        for i in range(3):
            rows.append(
                {
                    "topic": row["topic"],
                    "question": f"Question {i + 1} about {row['topic']}?",
                    "variation_id": i,
                }
            )
    return pd.DataFrame(rows)


@dd.custom_column_generator(required_columns=["topic", "score"])
def filter_high_scores(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only records with score > 0.5 (N:1 retraction)."""
    filtered = df[df["score"] > 0.5].copy()
    filtered["status"] = "passed"
    return filtered


def run_expansion_example() -> None:
    """3 topics -> 9 questions."""
    data_designer = DataDesigner()
    config_builder = dd.DataDesignerConfigBuilder()

    config_builder.add_column(
        dd.SamplerColumnConfig(
            name="topic",
            sampler_type=dd.SamplerType.CATEGORY,
            params=dd.CategorySamplerParams(values=["Python", "ML", "Data"]),
        )
    )
    config_builder.add_column(
        dd.CustomColumnConfig(
            name="question",
            generator_function=expand_to_questions,
            generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
            allow_resize=True,
        )
    )

    preview = data_designer.preview(config_builder=config_builder, num_records=3)
    print(f"Expansion: 3 -> {len(preview.dataset)} records")
    print(preview.dataset.to_string())


def run_retraction_example() -> None:
    """10 records -> ~5 (filtered)."""
    data_designer = DataDesigner()
    config_builder = dd.DataDesignerConfigBuilder()

    config_builder.add_column(
        dd.SamplerColumnConfig(
            name="topic",
            sampler_type=dd.SamplerType.CATEGORY,
            params=dd.CategorySamplerParams(values=["A", "B", "C", "D", "E"]),
        )
    )
    config_builder.add_column(
        dd.SamplerColumnConfig(
            name="score",
            sampler_type=dd.SamplerType.UNIFORM,
            params=dd.UniformSamplerParams(low=0.0, high=1.0),
        )
    )
    config_builder.add_column(
        dd.CustomColumnConfig(
            name="status",
            generator_function=filter_high_scores,
            generation_strategy=dd.GenerationStrategy.FULL_COLUMN,
            allow_resize=True,
        )
    )

    preview = data_designer.preview(config_builder=config_builder, num_records=10)
    print(f"Retraction: 10 -> {len(preview.dataset)} records")
    print(preview.dataset.to_string())


if __name__ == "__main__":
    run_expansion_example()
    run_retraction_example()
@@ -509,6 +509,14 @@ class CustomColumnConfig(SingleColumnConfig):
        default=None,
        description="Optional typed configuration object passed as second argument to generator function",
    )
    allow_resize: bool = Field(
        default=False,
        description=(
            "If True, allows the generator to produce a different number of records than the input. "
            "Use for 1:N (expansion) or N:1 (retraction) generation patterns. "
            "Only applicable when generation_strategy is 'full_column'."
        ),
    )

> **Contributor:** I'm wondering if we should elevate this to a property on the base column config (default is False), which you can override in custom columns and plugins.
>
> **Contributor (author):** Yeah, that was the initial solution; I ended up doing a mixin instead. I thought it was a bit opaque for plugins specifically, in that the developer has to discover a specific attribute/property 🤔 But it does make things simpler, I suppose?
>
> **Contributor:** It's a pattern we already use for custom emojis, and also for required_columns and side_effect_columns (though those have to be set).

    column_type: Literal["custom"] = "custom"

    @field_validator("generator_function")
@@ -560,3 +568,12 @@ def validate_generator_function(self) -> Self:
f"Expected a function decorated with @custom_column_generator."
)
return self

@model_validator(mode="after")
def validate_allow_resize_requires_full_column(self) -> Self:
if self.allow_resize and self.generation_strategy != GenerationStrategy.FULL_COLUMN:
raise InvalidConfigError(
f"🛑 `allow_resize=True` requires `generation_strategy='full_column'` for column '{self.name}'. "
f"Cell-by-cell strategy processes one row at a time and cannot change record count."
)
return self
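
The review thread above mentions a mixin-based alternative that is not part of this diff. A hypothetical sketch of what such an opt-in marker could look like (`AllowResizeMixin` and `PluginColumnConfig` are invented names for illustration):

```python
from pydantic import BaseModel, Field


class AllowResizeMixin(BaseModel):
    """Hypothetical mixin: opts a column config into resizable generation."""

    allow_resize: bool = Field(
        default=False,
        description="If True, the generator may return a different record count than its input.",
    )


class PluginColumnConfig(BaseModel):  # invented stand-in for a real plugin base class
    name: str


class ResizablePluginConfig(AllowResizeMixin, PluginColumnConfig):
    """A plugin config that opts in simply by inheriting the mixin."""
```

The engine-side lookup `getattr(generator.config, "allow_resize", False)` in `_run_full_column_generator` below would pick up either variant without a base-class change.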
@@ -193,3 +193,5 @@ def log_pre_generation(self) -> None:
logger.info(f" |-- model_aliases: {self.config.model_aliases}")
if self.config.generator_params:
logger.info(f" |-- generator_params: {self.config.generator_params}")
if self.config.allow_resize:
logger.info(f" |-- allow_resize: {self.config.allow_resize}")
@@ -209,8 +209,23 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
        self._fan_out_with_threads(generator, max_workers=max_workers)

    def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
        original_count = self.batch_manager.num_records_in_buffer
        df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
        allow_resize = getattr(generator.config, "allow_resize", False)
        new_count = len(df)

        if allow_resize and new_count != original_count:
            if new_count == 0:
                logger.warning(
                    f"⚠️ Column '{generator.config.name}' reduced batch to 0 records. This batch will be skipped."
                )
            else:
                logger.info(
                    f"📊 Column '{generator.config.name}' resized batch: {original_count} -> {new_count} records. "
                    f"Subsequent columns will operate on the new record count."
                )

        self.batch_manager.update_records(df.to_dict(orient="records"), allow_resize=allow_resize)

    def _run_model_health_check_if_needed(self) -> None:
        model_aliases: set[str] = set()
@@ -25,6 +25,7 @@ def __init__(self, artifact_storage: ArtifactStorage):
        self._current_batch_number = 0
        self._num_records_list: list[int] | None = None
        self._buffer_size: int | None = None
        self._actual_num_records: int = 0
        self.artifact_storage = artifact_storage

    @property
@@ -83,11 +84,13 @@ def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Path:
            raise DatasetBatchManagementError("🛑 All batches have been processed.")

        if self.write() is not None:
            self._actual_num_records += len(self._buffer)
            final_file_path = self.artifact_storage.move_partial_result_to_final_file_path(self._current_batch_number)

            self.artifact_storage.write_metadata(
                {
                    "target_num_records": sum(self.num_records_list),
                    "actual_num_records": self._actual_num_records,
                    "total_num_batches": self.num_batches,
                    "buffer_size": self._buffer_size,
                    "schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)},
@@ -141,6 +144,7 @@ def iter_current_batch(self) -> Iterator[tuple[int, dict]]:
    def reset(self, delete_files: bool = False) -> None:
        self._current_batch_number = 0
        self._buffer: list[dict] = []
        self._actual_num_records = 0
        if delete_files:
            for dir_path in [
                self.artifact_storage.final_dataset_path,
@@ -191,8 +195,16 @@ def update_record(self, index: int, record: dict) -> None:
            raise IndexError(f"🛑 Index {index} is out of bounds for buffer of size {len(self._buffer)}.")
        self._buffer[index] = record

    def update_records(self, records: list[dict], *, allow_resize: bool = False) -> None:
        """Update all records in the buffer.

        Args:
            records: New records to replace the buffer.
            allow_resize: If True, allows the number of records to differ from the current
                buffer size. Use for 1:N (expansion) or N:1 (retraction) generation patterns.
                Defaults to False for strict 1:1 mapping.
        """
        if not allow_resize and len(records) != len(self._buffer):
            raise DatasetBatchManagementError(
                f"🛑 Number of records to update ({len(records)}) must match "
                f"the number of records in the buffer ({len(self._buffer)})."
@@ -18,6 +18,7 @@

from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy
from data_designer.config.custom_column import custom_column_generator
from data_designer.config.errors import InvalidConfigError
from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator
from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError
from data_designer.engine.resources.resource_provider import ResourceProvider
@@ -113,6 +114,31 @@ def test_config_validation_non_callable() -> None:
        CustomColumnConfig(name="test", generator_function="not_a_function")


def test_config_validation_allow_resize_requires_full_column() -> None:
    """Test that allow_resize=True requires generation_strategy=FULL_COLUMN."""

    @custom_column_generator()
    def dummy_fn(row: dict) -> dict:
        return row

    with pytest.raises(InvalidConfigError, match="allow_resize=True.*requires.*full_column"):
        CustomColumnConfig(
            name="test",
            generator_function=dummy_fn,
            allow_resize=True,
            generation_strategy=GenerationStrategy.CELL_BY_CELL,
        )

    # Should work with FULL_COLUMN
    config = CustomColumnConfig(
        name="test",
        generator_function=dummy_fn,
        allow_resize=True,
        generation_strategy=GenerationStrategy.FULL_COLUMN,
    )
    assert config.allow_resize is True


# Cell-by-cell generation tests


@@ -173,6 +173,61 @@ def test_update_records_wrong_length(stub_batch_manager_with_data):
        stub_batch_manager_with_data.update_records(wrong_length_records)


def test_update_records_allow_resize_expansion(stub_batch_manager_with_data):
    """Test that allow_resize=True permits expanding the record count (1:N)."""
    records = [{"id": i, "name": f"test{i}"} for i in range(3)]
    stub_batch_manager_with_data.add_records(records)

    # Expand from 3 to 6 records
    expanded_records = [{"id": i, "name": f"expanded{i}"} for i in range(6)]
    stub_batch_manager_with_data.update_records(expanded_records, allow_resize=True)

    assert stub_batch_manager_with_data.num_records_in_buffer == 6
    assert stub_batch_manager_with_data._buffer == expanded_records


def test_update_records_allow_resize_retraction(stub_batch_manager_with_data):
    """Test that allow_resize=True permits reducing the record count (N:1)."""
    records = [{"id": i, "name": f"test{i}"} for i in range(3)]
    stub_batch_manager_with_data.add_records(records)

    # Retract from 3 to 1 record
    retracted_records = [{"id": 0, "name": "aggregated"}]
    stub_batch_manager_with_data.update_records(retracted_records, allow_resize=True)

    assert stub_batch_manager_with_data.num_records_in_buffer == 1
    assert stub_batch_manager_with_data._buffer == retracted_records


def test_update_records_allow_resize_to_empty(stub_batch_manager_with_data):
    """Test that allow_resize=True permits reducing to zero records."""
    records = [{"id": i, "name": f"test{i}"} for i in range(3)]
    stub_batch_manager_with_data.add_records(records)

    stub_batch_manager_with_data.update_records([], allow_resize=True)

    assert stub_batch_manager_with_data.num_records_in_buffer == 0
    assert stub_batch_manager_with_data.buffer_is_empty


def test_actual_num_records_tracks_expansion(stub_batch_manager_with_data):
    """Test that actual_num_records correctly tracks when buffer is resized."""
    # Add 3 records, then expand to 6
    records = [{"id": i} for i in range(3)]
    stub_batch_manager_with_data.add_records(records)
    expanded = [{"id": i} for i in range(6)]
    stub_batch_manager_with_data.update_records(expanded, allow_resize=True)

    # Finish batch and check metadata
    stub_batch_manager_with_data.finish_batch()

    with open(stub_batch_manager_with_data.artifact_storage.metadata_file_path) as f:
        metadata = json.load(f)

    assert metadata["target_num_records"] == 10  # original target
    assert metadata["actual_num_records"] == 6  # actual expanded count


# Test write method
def test_write_empty_buffer(stub_batch_manager_with_data):
    result = stub_batch_manager_with_data.write()
@@ -271,6 +326,7 @@ def test_finish_batch_metadata_content(stub_batch_manager_with_data):
        metadata = json.load(f)

    assert metadata["target_num_records"] == 10
    assert metadata["actual_num_records"] == 3  # actual records written in this batch
    assert metadata["total_num_batches"] == 4
    assert metadata["buffer_size"] == 3
    assert metadata["num_completed_batches"] == 1