Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tensorrt_llm/_torch/disaggregation/transceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def __init__(
self._page_table = self._transfer_worker.page_table
self._is_v2_manager = isinstance(kv_cache_manager, KVCacheManagerV2)

# Sticky role markers: each flips to True once any session of that kind opens. Used to
# short-circuit the per-iteration tp_allgather when this transceiver never sends/receives.
self._ever_had_send_session: bool = False
self._ever_had_recv_session: bool = False

def _broadcast_instance_name(self) -> str:
if self._dist.rank == 0:
name = str(uuid.uuid4())
Expand Down Expand Up @@ -290,6 +295,7 @@ def _apply_aux(self, session, req: LlmRequest):

@nvtx_range("KvCacheTransceiverV2.respond_and_send_async")
def respond_and_send_async(self, req: LlmRequest):
self._ever_had_send_session = True
rid = get_unique_rid(req)
assert rid is not None
if rid not in self._send_sessions:
Expand Down Expand Up @@ -345,6 +351,7 @@ def request_and_receive_sync(self, req: LlmRequest):

@nvtx_range("KvCacheTransceiverV2.request_and_receive_async")
def request_and_receive_async(self, req: LlmRequest):
self._ever_had_recv_session = True
rid = get_unique_rid(req)
if rid in self._recv_sessions:
logger.warning(
Expand All @@ -360,6 +367,9 @@ def request_and_receive_async(self, req: LlmRequest):
def check_context_transfer_status(
self, at_least_request_num: Optional[int], mark_complete: bool = False
):
# Skip the tp_allgather in _ctx_consensus when this transceiver never sends (pure GEN role).
if not self._ever_had_send_session:
return [], []
block_all = at_least_request_num is None
wait_num = at_least_request_num if not block_all else 0

Expand Down Expand Up @@ -397,6 +407,9 @@ def check_context_transfer_status(
return completed, failed

def check_gen_transfer_status(self, at_least_request_num: Optional[int]):
# Skip the allgather in _gen_consensus when this transceiver never receives (pure CTX role).
if not self._ever_had_recv_session:
return [], []
block_all = at_least_request_num is None
wait_num = at_least_request_num if not block_all else 0

Expand Down
Loading