Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tensorrt_llm/_torch/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .gpu_memory_backend import GMSBackend, GPUMemoryBackend

# Public names re-exported by this package; both come from the
# ``gpu_memory_backend`` module imported above.
__all__ = ["GPUMemoryBackend", "GMSBackend"]
581 changes: 581 additions & 0 deletions tensorrt_llm/_torch/memory/gpu_memory_backend.py

Large diffs are not rendered by default.

48 changes: 37 additions & 11 deletions tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,33 @@ def model_name(self) -> Optional[str]:
def query_timeout_s(self) -> Optional[int]:
    """Return the value stored in ``self._query_timeout_s``.

    NOTE(review): the ``_s`` suffix suggests the unit is seconds and the
    ``Optional`` annotation suggests ``None`` means "not configured" --
    confirm both against the code that assigns the attribute.
    """
    return self._query_timeout_s

@property
def p2p_succeeded(self) -> bool:
    """Whether the last load_weights() call used P2P transfer.

    ``True`` means weights are already in model parameter buffers
    and the standard weight-mapping pipeline should be skipped
    for those parameters.

    Read-only view of the same flag that :meth:`is_weights_preloaded`
    returns; kept so existing attribute-style callers keep working.
    """
    return self._p2p_succeeded


def is_weights_preloaded(self) -> bool:
    """Whether the last :meth:`load_weights` call wired weights directly into the model.

    Reports the result of the most recent ``load_weights()`` invocation
    on this loader instance. ``ModelLoader`` consults this signal to
    decide whether to run the standard weight-mapping pipeline:

    - ``True``: MX P2P transfer succeeded; weights already live in
      model parameter buffers via direct writes from the upstream
      ``MxLiveWeightLoader``. The mapping pipeline is skipped for
      all parameters covered by P2P.
    - ``False``: either P2P was never attempted (no MX server URL,
      no model reference, library missing) or it failed and we
      fell back to disk; weights still need to flow through
      ``model.load_weights(...)`` via the standard mapper.

    Note this is a per-loader-instance flag, not a global one. The
    flag is reset to ``False`` at the start of each ``load_weights``
    call, so the value is only meaningful immediately after a
    successful call.

    Returns:
        ``True`` iff the last ``load_weights`` populated the model
        via P2P; ``False`` before any call and on any fallback path.
    """
    return self._p2p_succeeded

def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
"""Load weights, preferring MX P2P transfer when available.

Expand Down Expand Up @@ -404,7 +417,20 @@ def publish_as_source(
def post_load_publish(
    self, model, *, checkpoint_dir: str, weights_preloaded: bool = False
) -> None:
    """Publish locally loaded weights as an MX source when appropriate.

    Only workers that loaded their weights locally publish; workers that
    received them over MX P2P do not.

    Args:
        model: The loaded model whose parameters should be published for
            future MX P2P receivers.
        checkpoint_dir: Checkpoint directory used as a fallback model
            identity when no explicit MX model name is configured.
        weights_preloaded: Whether this worker already received weights
            through MX P2P. When true, this worker is an MX receiver and
            should not republish the same weights as a source.

    Returns:
        None.
    """
    # Receivers must not re-advertise weights they just pulled from a peer.
    if weights_preloaded:
        return
    self.publish_as_source(model, checkpoint_dir=checkpoint_dir)
Expand Down
89 changes: 84 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,17 +1558,96 @@ def _set_up_spec_metadata(
max_seq_len=self.max_seq_len)
return self.spec_metadata

def cleanup(self) -> None:
    """Release resources owned by this model engine.

    Tears down, in order:

    1. The optional ``ModelLoader`` (which in turn releases any
       GMS client; see :meth:`ModelLoader.cleanup`).
    2. The model module reference.
    3. CUDA Graph captures (via :meth:`_release_cuda_graphs`).
    4. Input processors.
    5. Userbuffers (``ub.ub_deallocate`` per buffer); on per-buffer
       failure the unfreed buffers are kept attached so a deterministic
       retry doesn't double-free already-released ones, and the
       collected errors are re-raised after the loop.

    Idempotency:
        Subsequent calls are no-ops (guarded by ``_cleanup_done``).
        The flag is set only at the end, so a partial cleanup that
        raises mid-way will be retried on the next call.

    Raises:
        RuntimeError: If one or more userbuffer deallocations fail
            (chained from the first error). All other steps are
            best-effort and either succeed or leak silently with
            their errors logged at warning level by callees.

    Called from:
        - :meth:`PyExecutor.shutdown` (deterministic teardown).
        - :meth:`__del__` (best-effort fallback during garbage
          collection / interpreter shutdown).
    """
    # getattr with a default: the flag does not exist on engines whose
    # construction failed before it could be set.
    if getattr(self, "_cleanup_done", False):
        return

    # Cleanup is not truly atomic: released CUDA/GMS resources cannot be
    # rolled back. Keep each handle live until its own release succeeds,
    # so a failed cleanup can be retried without double-freeing resources
    # that were already released.
    model_loader = getattr(self, "model_loader", None)
    if model_loader is not None:
        model_loader.cleanup()
        self.model_loader = None

    self.model = None

    self._release_cuda_graphs()
    self.input_processor = None
    self.input_processor_with_hash = None

    # Each userbuffer is deallocated exactly once, inside the guarded loop
    # below; a second unguarded pass over the same list would double-free.
    ub_buffers = getattr(self, 'ub_buffers', None)
    if ub_buffers:
        remaining_ub_buffers = []
        ub_errors = []
        for u in ub_buffers:
            try:
                ub.ub_deallocate(u.addr)
            except RuntimeError as e:
                # Keep failed buffers attached so a deterministic
                # cleanup() call can retry without double-freeing buffers
                # that were already deallocated successfully.
                remaining_ub_buffers.append(u)
                ub_errors.append(e)
        self.ub_buffers = remaining_ub_buffers or None
        if ub_errors:
            raise RuntimeError(
                "Failed to deallocate one or more userbuffers during "
                "PyTorchModelEngine cleanup") from ub_errors[0]

    # Release model weights.
    release_gc()
    self._cleanup_done = True

def __del__(self) -> None:
    """Finalizer that falls back to :meth:`cleanup` on a best-effort basis.

    The work itself lives in :meth:`cleanup`. Two failure types are
    caught here, reported through ``logger.warning`` and then dropped,
    since a destructor has no reliable way to surface an exception:

    - ``RuntimeError``: raised by :meth:`cleanup` when one or more
      userbuffer deallocations fail.
    - ``AttributeError``: typical for partially-initialized engines torn
      down during interpreter shutdown, after module references have
      already been cleared.

    Any other exception type is not handled here. Callers that need to
    observe failures (``PyExecutor.shutdown``) should invoke
    :meth:`cleanup` directly instead of relying on this method.
    """
    swallowed = (AttributeError, RuntimeError)
    try:
        self.cleanup()
    except swallowed as exc:
        logger.warning(
            "PyTorchModelEngine cleanup failed during destruction: %s", exc)

def _init_max_seq_len(self):
# Allow user to override the inferred max_seq_len with a warning.
Expand Down
Loading
Loading