1 change: 0 additions & 1 deletion tensorrt_llm/__init__.py
@@ -18,7 +18,6 @@
 # Disable UCC to WAR allgather issue before NGC PyTorch 25.12 upgrade.
 os.environ["OMPI_MCA_coll_ucc_enable"] = "0"

-
 def _add_trt_llm_dll_directory():
     import platform
     on_windows = platform.system() == "Windows"
227 changes: 197 additions & 30 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -57,6 +57,11 @@
                          SampleStateTensors)
 from .scheduler import (RequestScheduler, ScheduledRequests,
                         SerializableSchedulerOutput)
+# New scheduler architecture (conditionally used)
+from .request_queue import RequestQueue
+from .scheduler.unified import UnifiedSPMDScheduler
+from .scheduler.types import SchedulerConfig
+from .scheduler.local_scheduler import PyCapacityScheduler, PyMicroBatchScheduler

 # Environment variable to specify iteration ranges for profiling start/stop.
 # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
@@ -293,18 +298,28 @@ def on_detected():

         # request fetcher initialization
         self._set_global_steady_clock_offset()
-        self.executor_request_queue = ExecutorRequestQueue(
-            dist=self.dist,
-            enable_attention_dp=self.enable_attention_dp,
-            max_batch_size=max_batch_size,
-            max_beam_width=self.max_beam_width,
-            max_num_active_requests=self.max_num_active_requests,
-            enable_iter_perf_stats=self.enable_iter_perf_stats,
-            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
-            hang_detector=self.hang_detector,
-        )
-        self.executor_request_queue.set_exclude_last_generation_logits(
-            self.disable_overlap_scheduler, self.dist.pp_size)
+
+        # Check if using unified scheduler (new architecture)
+        self._use_unified_scheduler = os.environ.get('TRTLLM_USE_UNIFIED_SCHEDULER', '0') == '1'
+
+        if self._use_unified_scheduler:
+            # New architecture: use RequestQueue + UnifiedSPMDScheduler
+            self._init_unified_scheduler(max_batch_size)
+        else:
+            # Original architecture: use ExecutorRequestQueue
+            self.executor_request_queue = ExecutorRequestQueue(
+                dist=self.dist,
+                enable_attention_dp=self.enable_attention_dp,
+                max_batch_size=max_batch_size,
+                max_beam_width=self.max_beam_width,
+                max_num_active_requests=self.max_num_active_requests,
+                enable_iter_perf_stats=self.enable_iter_perf_stats,
+                batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+                hang_detector=self.hang_detector,
+            )
+            self.executor_request_queue.set_exclude_last_generation_logits(
+                self.disable_overlap_scheduler, self.dist.pp_size)
+            self.unified_scheduler = None
         self.control_request_barrier = threading.Event()
         self.control_action_done = threading.Event()
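
Note: the new path is strictly opt-in. A minimal usage sketch (illustrative, not part of the diff) for turning it on before the executor is constructed:

    # Any value other than '1', or leaving the variable unset, keeps the
    # original ExecutorRequestQueue path.
    import os
    os.environ['TRTLLM_USE_UNIFIED_SCHEDULER'] = '1'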

@@ -345,6 +360,62 @@ def on_detected():
         if start_worker:
             self.start_worker()

+    def _init_unified_scheduler(self, max_batch_size: int):
+        """Initialize the unified scheduler (new architecture)."""
+        from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy
+        from .request_fetcher import RequestFetcher, FetcherConfig
+        from .request_utils import RequestBroadcaster
+
+        # Create the new RequestQueue (replaces ExecutorRequestQueue)
+        self.request_queue = RequestQueue(
+            initial_request_id=max_batch_size,
+            enable_iter_perf_stats=self.enable_iter_perf_stats,
+            rank=self.dist.rank,
+        )
+        # Compatibility: some code paths still reference executor_request_queue
+        self.executor_request_queue = self.request_queue
+
+        # Create fetcher and broadcaster
+        fetcher_config = FetcherConfig(
+            max_batch_size=max_batch_size,
+            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+        )
+        self.request_fetcher = RequestFetcher(self.request_queue, fetcher_config)
+        self.request_broadcaster = RequestBroadcaster(self.dist)
+
+        # Create the scheduler config
+        scheduler_config = SchedulerConfig(
+            max_batch_size=max_batch_size,
+            max_num_tokens=self.model_engine.llm_args.max_num_tokens,
+            max_num_active_requests=self.max_num_active_requests,
+            enable_attention_dp=self.enable_attention_dp,
+        )
+
+        # Create the capacity and micro-batch schedulers
+        capacity_scheduler = PyCapacityScheduler(
+            capacity_scheduler_policy=getattr(
+                self.model_engine.llm_args, 'capacity_scheduler_policy',
+                CapacitySchedulerPolicy.MAX_UTILIZATION
+            ),
+            max_num_batched_tokens=self.model_engine.llm_args.max_num_tokens,
+            max_batch_size=max_batch_size,
+        )
+        micro_batch_scheduler = PyMicroBatchScheduler()
+
+        # Create the unified scheduler. It does no fetching itself; the
+        # executor drives fetching through its RequestFetcher.
+        self.unified_scheduler = UnifiedSPMDScheduler(
+            dist=self.dist,
+            config=scheduler_config,
+            capacity_scheduler=capacity_scheduler,
+            micro_batch_scheduler=micro_batch_scheduler,
+            check_disagg_transfer_cb=self._check_disagg_gen_transfer_status if self.kv_cache_transceiver else None,
+            prepare_disagg_gen_init_cb=self._prepare_disagg_gen_init if self.kv_cache_transceiver else None,
+            add_dummy_request_cb=self._add_dummy_request if self.enable_attention_dp else None,
+            validate_request_cb=self._respond_if_invalid,
+        )
+
+        logger.info("Using unified scheduler (TRTLLM_USE_UNIFIED_SCHEDULER=1)")
+
     def _maybe_init_kv_connector_manager(self):
         if self.kv_connector_manager is not None:
             if self.kv_cache_transceiver is not None:
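
Note: every hook passed to UnifiedSPMDScheduler above is either a bound method or None, so disaggregated-serving and attention-DP behavior is switched off by construction rather than re-checked inside the scheduler. A hedged sketch of how a callee can consume such optional hooks (the scheduler's internals are not part of this diff; the names below are illustrative):

    from typing import Callable, Optional

    def validate_or_accept(request,
                           validate_cb: Optional[Callable[[object], bool]]) -> bool:
        # Treat a missing hook as "nothing to reject"; otherwise the callback
        # returns True when it has already responded to an invalid request.
        return False if validate_cb is None else validate_cb(request)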
@@ -564,8 +635,12 @@ def set_gather_responses(self, gather_all_responses):

     @property
     def should_stop_processing(self):
-        return self.is_shutdown and len(self.active_requests) == 0 and \
-            self.executor_request_queue.get_waiting_queue_size() == 0
+        if self._use_unified_scheduler:
+            return self.is_shutdown and len(self.active_requests) == 0 and \
+                self.unified_scheduler.get_waiting_queue_size() == 0
+        else:
+            return self.is_shutdown and len(self.active_requests) == 0 and \
+                self.executor_request_queue.get_waiting_queue_size() == 0

     @contextmanager
     def _profiler(self):
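
Note: the two branches above differ only in which queue they consult. A possible refactor, sketched here for clarity (not in the diff), would make that explicit:

    @property
    def should_stop_processing(self):
        queue = (self.unified_scheduler if self._use_unified_scheduler
                 else self.executor_request_queue)
        return (self.is_shutdown and len(self.active_requests) == 0
                and queue.get_waiting_queue_size() == 0)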
@@ -1342,6 +1417,92 @@ def _prepare_and_schedule_batch(self):
                 f'{len(scheduled_batch.generation_requests)} generation requests')
         return scheduled_batch, iter_stats

+    def _prepare_and_schedule_batch_v2(self):
+        """New scheduling path using UnifiedSPMDScheduler.
+
+        Uses RequestFetcher + RequestBroadcaster + request_utils from
+        pyexecutor.
+        """
+        from .request_utils import (
+            collect_py_objects,
+            attach_py_objects,
+        )
+
+        # Calculate the idle state
+        if self.enable_attention_dp:
+            num_active_tokens = sum(
+                req.py_orig_prompt_len for req in self.active_requests)
+            responses_list = self.dist.tp_allgather(
+                [len(self.active_requests), num_active_tokens])
+            total_num_active = sum(r[0] for r in responses_list)
+        else:
+            total_num_active = len(self.active_requests)
+
+        idle = total_num_active == 0 and \
+            self.unified_scheduler.get_waiting_queue_size() == 0
+
+        # Fetch from the queue (rank 0 only)
+        new_requests = []
+        if self.dist.rank == 0:
+            new_requests = self.request_fetcher.fetch(idle)
+
+        # Broadcast to all ranks
+        if self.dist.rank == 0:
+            py_request_objects = collect_py_objects(new_requests)
+        else:
+            py_request_objects = None
+
+        new_requests, py_request_objects = self.request_broadcaster.broadcast(
+            new_requests, py_request_objects)
+
+        # Attach Python objects (on non-rank-0 ranks)
+        if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp or
+                                   self.dist.cp_size > 1) and self.dist.rank > 0:
+            attach_py_objects(new_requests, py_request_objects)
+
+        if self.should_stop_processing:
+            return None, None
+
+        # Prepare iter_stats
+        iter_stats = None
+        if self.enable_iter_perf_stats:
+            iter_stats = self._get_init_iter_stats(len(new_requests), 0)
+
+        # Call the scheduler with the validated requests
+        result = self.unified_scheduler.prepare_and_schedule_batch(
+            active_requests=self.active_requests,
+            inflight_request_ids=self.inflight_request_ids,
+            new_requests=new_requests,
+            iter_stats=iter_stats,
+            drafter=self.drafter,
+            max_batch_size=self.max_batch_size,
+            max_num_tokens=self.model_engine.llm_args.max_num_tokens,
+        )
+
+        # Update the shutdown state from the scheduler
+        self.is_shutdown = self.unified_scheduler.is_shutdown
+        if self.should_stop_processing:
+            return None, None
+
+        # A result of None also means shutdown
+        if result is None:
+            return None, None
+
+        # Update executor state from the result
+        scheduled_batch = result.scheduled_batch
+
+        # Update drafter-related state
+        if self.drafter is not None:
+            self.use_spec_decode = result.use_spec_decode
+            self.max_total_draft_tokens = result.max_draft_tokens
+            self.model_engine.enable_spec_decode = result.use_spec_decode
+
+        self.num_scheduled_requests = scheduled_batch.batch_size
+        logger.debug(
+            f'has {len(self.active_requests)} active_requests, '
+            f'scheduled {len(scheduled_batch.context_requests)} context requests and '
+            f'{len(scheduled_batch.generation_requests)} generation requests')
+
+        return scheduled_batch, result.iter_stats
+
     def _kv_connector_start_batch(self, scheduled_batch):
         if self.kv_connector_manager:
             self.kv_connector_manager.take_scheduled_requests_pending_load(
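
Note: `_prepare_and_schedule_batch_v2` maintains one SPMD invariant: only rank 0 reads the request queue, but every rank must leave the function holding an identical request list, otherwise the ranks desynchronize on the next collective. A stripped-down sketch of that invariant (`fetch` and `broadcast` are the methods used above; everything else is illustrative):

    def fetch_and_replicate(dist, fetcher, broadcaster, idle):
        # Only rank 0 touches the queue...
        new_requests = fetcher.fetch(idle) if dist.rank == 0 else []
        # ...but the broadcast is collective: every rank must call it, and
        # afterwards every rank holds the same request list. (Passing None
        # for the Python-object payload is a simplification of the real call.)
        new_requests, _ = broadcaster.broadcast(new_requests, None)
        return new_requests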
@@ -1382,7 +1543,10 @@ def _executor_loop(self):
             if self.enable_iter_perf_stats:
                 iter_start_time = time.time()

-            scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
+            if self._use_unified_scheduler:
+                scheduled_batch, iter_stats = self._prepare_and_schedule_batch_v2()
+            else:
+                scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
             self._handle_control_request()

             if scheduled_batch is None:
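
Note: `_executor_loop_overlap` below gains the identical four-line dispatch. A follow-up refactor could hoist it into a single helper, sketched here (not in the diff):

    def _schedule_next_batch(self):
        if self._use_unified_scheduler:
            return self._prepare_and_schedule_batch_v2()
        return self._prepare_and_schedule_batch()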
@@ -1587,7 +1751,10 @@ def _executor_loop_overlap(self):
             if self.enable_iter_perf_stats:
                 iter_start_time = time.time()

-            scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
+            if self._use_unified_scheduler:
+                scheduled_batch, iter_stats = self._prepare_and_schedule_batch_v2()
+            else:
+                scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
             self._handle_control_request()

             if scheduled_batch is None:
@@ -1597,7 +1764,8 @@
                 # to ensure consistent batch sizes for accurate performance measurement.
                 if not self.is_warmup and not can_forward:
                     if self.enable_attention_dp:
-                        local_can_forward = self.executor_request_queue.num_fetch_requests + \
+                        num_fetch = (self.unified_scheduler.num_fetch_requests
+                                     if self._use_unified_scheduler else
+                                     self.executor_request_queue.num_fetch_requests)
+                        local_can_forward = num_fetch + \
                             len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size
                         all_can_forward = self.dist.tp_allgather(
                             local_can_forward)
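
Note: the benchmark gate is an all-rank consensus: each rank computes `local_can_forward`, `tp_allgather` collects one value per TP rank, and the batch should proceed only when every rank agrees. The diff truncates before the decision, but the pattern presumably reduces with `all(...)`, roughly (an assumption, not shown in this diff):

    all_can_forward = self.dist.tp_allgather(local_can_forward)
    if all(all_can_forward):
        ...  # every rank has enough queued work; run the batch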
@@ -1926,20 +2094,19 @@ def _validate_request(self, request: LlmRequest):
                     self.model_engine.model.lm_head.num_embeddings):
             raise ValueError("Token ID out of range")

-    def _fetch_and_activate_new_requests(self) -> List[LlmRequest]:
+    def _respond_if_invalid(self, request) -> bool:
+        """Immediately fail invalid request.

-        def _respond_if_invalid(request: LlmRequest) -> bool:
-            """Immediately fail invalid request.
+        Return True if invalid request was encountered and handled.
+        """
+        try:
+            self._validate_request(request)
+            return False
+        except Exception as e:
+            self._handle_errors(str(e), requests=[request])
+            return True

-            Return True if invalid request was encountered and
-            handled.
-            """
-            try:
-                self._validate_request(request)
-                return False
-            except Exception as e:
-                self._handle_errors(str(e), requests=[request])
-                return True
+    def _fetch_and_activate_new_requests(self) -> List[LlmRequest]:

         new_requests_cur_rank = self.executor_request_queue.fetch_new_requests(
             self.active_requests)
@@ -1949,7 +2116,7 @@ def _respond_if_invalid(request: LlmRequest) -> bool:

         validated_requests = [
             request for request in new_requests_cur_rank
-            if not _respond_if_invalid(request)
+            if not self._respond_if_invalid(request)
         ]

         self.active_requests.extend(validated_requests)