Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,50 @@ def generate(self, data: dict) -> dict:

return data

async def agenerate(self, data: dict) -> dict:
    """Asynchronously generate this column's value for a single record.

    Renders the user and system prompts from the deserialized record, invokes
    the model's async generation path, and writes the parsed output back into
    *data* — along with optional trace and reasoning-content columns — before
    returning the mutated record.
    """
    record = deserialize_json_values(data)

    # Gather multi-modal context items only when any are configured.
    mm_context = None
    configured_contexts = self.config.multi_modal_context
    if configured_contexts is not None and len(configured_contexts) > 0:
        mm_context = [item for ctx in configured_contexts for item in ctx.get_contexts(record)]

    user_prompt = self.prompt_renderer.render(
        record=record,
        prompt_template=self.config.prompt,
        prompt_type=PromptType.USER_PROMPT,
    )
    system_prompt = self.prompt_renderer.render(
        record=record,
        prompt_template=self.config.system_prompt,
        prompt_type=PromptType.SYSTEM_PROMPT,
    )
    response, trace = await self.model.agenerate(
        prompt=user_prompt,
        system_prompt=system_prompt,
        parser=self.response_recipe.parse,
        multi_modal_context=mm_context,
        tool_alias=self.config.tool_alias,
        max_correction_steps=self.max_conversation_correction_steps,
        max_conversation_restarts=self.max_conversation_restarts,
        purpose=f"running generation for column '{self.config.name}'",
    )

    data[self.config.name] = self._process_serialized_output(self.response_recipe.serialize_output(response))

    # Optionally attach the conversation trace under a postfixed column.
    trace_type = self.config.with_trace
    trace_column = self.config.name + TRACE_COLUMN_POSTFIX
    if trace_type == TraceType.ALL_MESSAGES:
        data[trace_column] = [message.to_dict() for message in trace]
    elif trace_type == TraceType.LAST_MESSAGE:
        last_assistant = None
        for message in reversed(trace):
            if message.role == "assistant":
                last_assistant = message
                break
        data[trace_column] = [last_assistant.to_dict()] if last_assistant else []

    if self.config.extract_reasoning_content:
        data[self.config.name + REASONING_CONTENT_COLUMN_POSTFIX] = self._extract_reasoning_content(trace)

    return data

def _extract_reasoning_content(self, trace: list) -> str | None:
"""Extract reasoning_content from the final assistant message in the trace.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import functools
import logging
import os
import time
import uuid
from pathlib import Path
Expand All @@ -31,6 +32,7 @@
from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
from data_designer.engine.dataset_builders.utils.async_concurrency import AsyncConcurrentExecutor
from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
Expand All @@ -50,6 +52,11 @@

logger = logging.getLogger(__name__)

# Feature flag: opt into the asyncio-based concurrency engine.
# Accepts the same truthy spellings ("1", "true", "yes", case-insensitive) as
# the env-var check in data_designer.engine.models, so both gates flip
# together when the user sets the variable either way.
DATA_DESIGNER_ASYNC_ENGINE = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "").lower() in {"1", "true", "yes"}

if DATA_DESIGNER_ASYNC_ENGINE:
    logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async concurrency")

_CLIENT_VERSION: str = get_library_version()


Expand Down Expand Up @@ -199,7 +206,11 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
max_workers = self._resource_provider.run_config.non_inference_max_parallel_workers
if isinstance(generator, ColumnGeneratorWithModel):
max_workers = generator.inference_parameters.max_parallel_requests
self._fan_out_with_threads(generator, max_workers=max_workers)
if DATA_DESIGNER_ASYNC_ENGINE:
logger.info("⚡ Using async engine for concurrent execution")
self._fan_out_with_async(generator, max_workers=max_workers)
else:
self._fan_out_with_threads(generator, max_workers=max_workers)

def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
Expand All @@ -226,6 +237,41 @@ def _run_mcp_tool_check_if_needed(self) -> None:
raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.")
self._resource_provider.mcp_registry.run_health_check(tool_aliases)

def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
    """Generate one column cell-by-cell using the shared asyncio engine.

    Builds one coroutine per record in the current batch, hands them to an
    AsyncConcurrentExecutor bounded at *max_workers*, then drops any records
    flagged for removal during generation.

    Raises:
        DatasetGenerationError: if *generator* is not a cell-by-cell generator.
    """
    if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
        raise DatasetGenerationError(
            f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} "
            "generator so concurrency through async is not supported."
        )

    tracker = ProgressTracker(
        total_records=self.batch_manager.num_records_batch,
        label=f"{generator.config.column_type} column '{generator.config.name}'",
    )
    tracker.log_start(max_workers)

    run_config = self._resource_provider.run_config
    async_executor = AsyncConcurrentExecutor(
        max_workers=max_workers,
        column_name=generator.config.name,
        result_callback=self._make_result_callback(tracker),
        error_callback=self._make_error_callback(tracker),
        shutdown_error_rate=run_config.shutdown_error_rate,
        shutdown_error_window=run_config.shutdown_error_window,
        disable_early_shutdown=run_config.disable_early_shutdown,
    )

    # One (coroutine, context) pair per record; the context carries the row
    # index so callbacks can write results back to the right record.
    pending = [
        (generator.agenerate(record), {"index": row_index})
        for row_index, record in self.batch_manager.iter_current_batch()
    ]
    async_executor.run(pending)

    tracker.log_final()

    if self._records_to_drop:
        self.batch_manager.drop_records(self._records_to_drop)
        self._records_to_drop.clear()

def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
raise DatasetGenerationError(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import asyncio
import json
import logging
import threading
from collections.abc import Coroutine
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

from data_designer.engine.dataset_builders.utils.concurrency import (
CallbackWithContext,
ErrorCallbackWithContext,
ExecutorResults,
)
from data_designer.engine.errors import DataDesignerRuntimeError

logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass(frozen=True, slots=True)
class Success(Generic[T]):
    """Outcome of a work item that completed without raising."""

    index: int  # position of the work item in the submitted batch
    value: T  # value returned by the awaited coroutine


@dataclass(frozen=True, slots=True)
class Failure:
    """Outcome of a work item whose coroutine raised an exception."""

    index: int  # position of the work item in the submitted batch
    error: Exception  # the exception instance that was raised


# Discriminated union of a single task's outcome.
# NOTE(review): Success/Failure/TaskResult are not referenced elsewhere in this
# module — presumably intended as public API for callers; confirm usage.
TaskResult = Success[T] | Failure

# Process-wide singletons: one background event loop (and the thread hosting
# it) is lazily created and shared by all executors; access guarded by _lock.
_loop: asyncio.AbstractEventLoop | None = None
_thread: threading.Thread | None = None
_lock: threading.Lock = threading.Lock()


def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
    """Thread target: adopt *loop* as this thread's event loop and run it forever."""
    asyncio.set_event_loop(loop)
    loop.run_forever()  # blocks for the lifetime of the daemon thread


def _ensure_async_engine_loop() -> asyncio.AbstractEventLoop:
    """Get or create a persistent event loop for async engine work.

    A single event loop is shared across all AsyncConcurrentExecutor instances
    to avoid breaking libraries (like LiteLLM) that bind internal async state
    to a specific event loop.

    Thread-safe via the module lock. Liveness is judged by the worker thread
    rather than ``loop.is_running()``: the latter only turns True once the
    background thread has actually entered ``run_forever()``, so two quick
    successive calls could otherwise each spawn a loop and orphan the first.
    ``Thread.is_alive()`` is True immediately after ``start()`` returns,
    closing that window. A dead thread (loop stopped or crashed) still
    triggers recreation, preserving the original recovery behavior.
    """
    global _loop, _thread
    with _lock:
        if _loop is None or _thread is None or not _thread.is_alive():
            _loop = asyncio.new_event_loop()
            _thread = threading.Thread(target=_run_loop, args=(_loop,), daemon=True, name="AsyncEngine-EventLoop")
            _thread.start()
        return _loop


class AsyncConcurrentExecutor:
    """Async equivalent of ConcurrentThreadExecutor.

    Executes a batch of coroutines with bounded concurrency, error rate
    monitoring, and early shutdown semantics. Callers remain synchronous —
    the ``run()`` method submits work to a persistent background event loop.

    No locks are needed because asyncio tasks run cooperatively on a
    single thread — mutations to ``_results`` are always sequential.
    """

    def __init__(
        self,
        *,
        max_workers: int,
        column_name: str,
        result_callback: CallbackWithContext | None = None,
        error_callback: ErrorCallbackWithContext | None = None,
        shutdown_error_rate: float = 0.50,
        shutdown_error_window: int = 10,
        disable_early_shutdown: bool = False,
    ) -> None:
        """Configure the executor.

        Args:
            max_workers: Maximum number of coroutines awaited concurrently.
            column_name: Column label used in log messages.
            result_callback: Invoked with each successful result and its context.
            error_callback: Invoked with each raised exception and its context.
            shutdown_error_rate: Failure fraction that triggers early shutdown.
            shutdown_error_window: Number of recent tasks considered for the rate.
            disable_early_shutdown: If True, never shut down early on errors.
        """
        self._column_name = column_name
        self._max_workers = max_workers
        self._result_callback = result_callback
        self._error_callback = error_callback
        self._shutdown_error_rate = shutdown_error_rate
        self._shutdown_window_size = shutdown_error_window
        self._disable_early_shutdown = disable_early_shutdown
        self._results = ExecutorResults(failure_threshold=shutdown_error_rate)
        # Created fresh in _run_all() so they bind to the engine's event loop.
        self._semaphore: asyncio.Semaphore | None = None
        self._shutdown_event: asyncio.Event | None = None

    @property
    def results(self) -> ExecutorResults:
        # Aggregated success/failure counters for the last run.
        return self._results

    @property
    def max_workers(self) -> int:
        return self._max_workers

    @property
    def shutdown_error_rate(self) -> float:
        return self._shutdown_error_rate

    @property
    def shutdown_window_size(self) -> int:
        return self._shutdown_window_size

    def run(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None:
        """Execute all work items concurrently. Callers remain synchronous."""
        logger.debug(
            f"AsyncConcurrentExecutor: launching {len(work_items)} tasks "
            f"with max_workers={self._max_workers} for column '{self._column_name}'"
        )
        loop = _ensure_async_engine_loop()
        # Block the calling thread until the whole batch settles on the
        # background loop; exceptions from _run_all propagate here.
        future = asyncio.run_coroutine_threadsafe(self._run_all(work_items), loop)
        future.result()

    async def _run_all(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None:
        """Spawn one task per work item and wait for all of them to finish."""
        self._semaphore = asyncio.Semaphore(self._max_workers)
        self._shutdown_event = asyncio.Event()

        # _run_task swallows per-task exceptions, so the TaskGroup only
        # aggregates; early-shutdown errors are raised explicitly below.
        async with asyncio.TaskGroup() as tg:
            for i, (coro, context) in enumerate(work_items):
                tg.create_task(self._run_task(i, coro, context))

        if not self._disable_early_shutdown and self._results.early_shutdown:
            self._raise_task_error()

    async def _run_task(self, index: int, coro: Coroutine[Any, Any, Any], context: dict | None) -> None:
        """Await one work item, record its outcome, and honor early shutdown."""
        if self._shutdown_event.is_set():
            # Close the skipped coroutine: dropping it unawaited would leak
            # resources and emit "coroutine was never awaited" warnings.
            coro.close()
            return

        async with self._semaphore:
            if self._shutdown_event.is_set():
                coro.close()
                return

            try:
                result = await coro
                self._results.completed_count += 1
                self._results.success_count += 1
                if self._result_callback is not None:
                    self._result_callback(result, context=context)
            except Exception as err:
                self._results.completed_count += 1
                self._results.error_trap.handle_error(err)
                if not self._disable_early_shutdown and self._results.is_error_rate_exceeded(
                    self._shutdown_window_size
                ):
                    if not self._results.early_shutdown:
                        self._results.early_shutdown = True
                    # Idempotent: safe to set even if another task already did.
                    self._shutdown_event.set()
                if self._error_callback is not None:
                    self._error_callback(err, context=context)

    def _raise_task_error(self) -> None:
        """Raise a summary error after an early shutdown was triggered."""
        raise DataDesignerRuntimeError(
            "\n".join(
                [
                    " |-- Data generation was terminated early due to error rate exceeding threshold.",
                    f" |-- The summary of encountered errors is: \n{json.dumps(self._results.summary, indent=4)}",
                ]
            )
        )
Original file line number Diff line number Diff line change
@@ -1,2 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import os
from pathlib import Path

# Environment variable that gates the redirect of this package to models_v2.
_ASYNC_ENGINE_ENV_VAR = "DATA_DESIGNER_ASYNC_ENGINE"
# Values (compared case-insensitively) that count as "enabled".
_TRUTHY_ENV_VALUES = {"1", "true", "yes"}


def _is_async_engine_enabled() -> bool:
    """Report whether the async-engine environment variable is set to a truthy value."""
    raw_value = os.getenv(_ASYNC_ENGINE_ENV_VAR, "")
    return raw_value.lower() in _TRUTHY_ENV_VALUES


def _redirect_to_models_v2() -> None:
    """Repoint this package's submodule search path at the sibling ``models_v2`` tree.

    Rebinds ``__path__`` (and the spec's ``submodule_search_locations``) so
    that ``import data_designer.engine.models.<sub>`` resolves files from
    ``models_v2`` instead of this directory. Must run before any submodule of
    this package is imported; already-imported submodules are unaffected.
    """
    models_v2_path = Path(__file__).resolve().parent.parent / "models_v2"
    # Set DATA_DESIGNER_ASYNC_ENGINE before importing this package for it to take effect.
    global __path__
    __path__ = [str(models_v2_path)]
    # __spec__ can be None when the module is loaded outside the normal
    # import system; only the import machinery consults the spec.
    if __spec__ is not None:
        __spec__.submodule_search_locations = [str(models_v2_path)]


# Redirect only when this module is imported under its canonical package path
# (not executed directly or vendored elsewhere) AND the async-engine env var
# is set to a truthy value.
if __name__ == "data_designer.engine.models" and _is_async_engine_enabled():
    _redirect_to_models_v2()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
Loading
Loading