Eta0 (Collaborator) commented Jul 29, 2025

Compatibility module for torch.save and torch.load

This change adds a new compatibility module to tensorizer, tensorizer.torch_compat. It provides an interface, implemented with context managers, for using torch.save and torch.load with tensorizer as their backend for the serialization of tensors and tensor storages, while leaving serialization and deserialization of all other objects and metadata to the respective torch functions. A brief description and usage pointers for this module are in the changelog, copied below.

Changelog

  • tensorizer.torch_compat is a new module for using tensorizer as a backend for handling tensor data during standard torch.save and torch.load calls
    • To use tensorizer as a backend for torch.save, wrap the call in the tensorizer_saving context manager
      • The file created must then be loaded using tensorizer_loading
    • To use tensorizer as a backend for torch.load, wrap the call in the tensorizer_loading context manager
      • The file to load must have been created using tensorizer_saving

Highlights

This new module provides several advantages over using tensorizer directly:

  • Easier integration with other library code already using torch.save and torch.load
  • Support for restoring relationships between separate tensors with tied weights upon deserialization (see the sketch below)
  • Support for serializing objects other than tensors, to the same extent as they are supported by torch.save
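
As an illustration of the tied-weights point above, here is a minimal sketch (the module names are hypothetical), assuming tied-weight restoration behaves as described:

import torch
from tensorizer.torch_compat import tensorizer_saving, tensorizer_loading

# Tie an embedding and an output projection to the same parameter
embedding = torch.nn.Embedding(10, 4)
lm_head = torch.nn.Linear(4, 10, bias=False)
lm_head.weight = embedding.weight

with tensorizer_saving():
    torch.save({"embedding": embedding, "lm_head": lm_head}, "tied.pt")

with tensorizer_loading():
    restored = torch.load("tied.pt")

# The tie survives the round trip: both modules share one storage
assert (
    restored["embedding"].weight.data_ptr()
    == restored["lm_head"].weight.data_ptr()
)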

Not supported

tensorizer.torch_compat.tensorizer_loading() does not support device selection following the same rules as the map_location argument to torch.load().
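
As a workaround, the target device can be chosen through the keyword arguments forwarded to the TensorDeserializer, as in the usage example below. This is a sketch of that pattern, not a full replacement for map_location semantics:

from tensorizer.torch_compat import tensorizer_loading
import torch

# Select the device via the deserializer rather than map_location:
with tensorizer_loading(device="cuda:0"):
    obj = torch.load("module.pt")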

Usage

This module is intended to be very simple to use, so that existing code built on the torch serialization functions can be adapted to tensorizer with minimal changes. An instance of torch.nn.Module can be serialized as follows:

import os
import torch
from tensorizer.torch_compat import tensorizer_saving, tensorizer_loading

module: torch.nn.Module = ...

with tensorizer_saving():
    torch.save(module, "module.pt")

assert os.path.exists("module.pt")
assert os.path.exists("module.pt.tensors")

with tensorizer_loading(device="cuda", num_readers=4):
    deserialized_module = torch.load("module.pt")

There are only two public functions in torch_compat: tensorizer_saving() and tensorizer_loading(). The API of both is thoroughly documented in their docstrings, and several supported use cases are demonstrated in the test suite. The docstrings are copied below for reference.

tensorizer_saving()

Context manager that modifies calls to ``torch.save`` to use tensorizer
as a backend for the serialization of tensors and tensor storages.

Tensors are saved in a sidecar file separate from the ``.pt`` file created
by ``torch.save``. To load them again, use the `tensorizer_loading`
context manager paired with ``torch.load``.

Notes:
    This context manager is thread-safe and async-safe. Other threads or
    coroutines executing concurrently while this context is active will not
    be modified.

Args:
    file_obj: The file or file-like object in which to save tensor data,
        separate from the one passed to ``torch.save`` for saving metadata.
        This can be any type accepted by a `TensorSerializer`, or a callable
        that dynamically generates the file path or file object based on
        the file path or file-like object ``f`` passed to the ``torch.save``
        call. When using a callable, it should take a single argument of
        the type ``torch.types.FileLike``, and output a type accepted
        by a `TensorSerializer`. The default behaviour is to use a callable
        that appends ``".tensors"`` to any filename passed as ``f``.
        If a provided callable returns ``None``, tensorizer serialization
        is not used (see the sketch below).
    save_func: An optional callable with the signature
        ``save_func(file_obj, tensors: Iterable[Tensor], kwargs: dict)``
        that may be used to override the default saving logic for tensors.
        `file_obj` and `kwargs` correspond to the ones passed to this
        function. This may be used, for instance, to make serialization
        asynchronous by writing a `save_func` that serializes in
        a background thread or process.
    kwargs: Further keyword arguments to pass to the `TensorSerializer`
        object used to save tensor data.
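
To make the file_obj callable concrete, here is a hedged sketch of a filter that routes only certain torch.save calls through tensorizer, relying on the documented behaviour that returning ``None`` disables tensorizer for that call (the helper name and the ".ckpt" convention are hypothetical):

import os
import torch
from tensorizer.torch_compat import tensorizer_saving

def sidecar_for_checkpoints(f: torch.types.FileLike):
    # Route only checkpoint files through tensorizer; returning None
    # falls back to plain torch.save behaviour for everything else.
    if isinstance(f, (str, os.PathLike)) and os.fsdecode(f).endswith(".ckpt"):
        return os.fsdecode(f) + ".tensors"
    return None

model: torch.nn.Module = ...

with tensorizer_saving(sidecar_for_checkpoints):
    torch.save(model, "epoch1.ckpt")     # tensors land in "epoch1.ckpt.tensors"
    torch.save({"lr": 0.1}, "state.pt")  # handled entirely by torch.save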

tensorizer_loading()

Context manager that modifies calls to ``torch.load`` to use tensorizer
as a backend for the deserialization of tensors and tensor storages.
This is only valid to use when deserializing files that were serialized
using the corresponding `tensorizer_saving` context manager paired with
``torch.save``.

Tensors are saved in a sidecar file separate from the ``.pt`` file created
by ``torch.save``. Both must be present at deserialization time.

Notes:
    This context manager is thread-safe and async-safe. Other threads or
    coroutines executing concurrently while this context is active will not
    be modified.

Args:
    file_obj: The file or file-like object from which to load tensor data,
        separate from the one passed to ``torch.load`` for loading metadata.
        This can be any type accepted by a `TensorDeserializer`, or a
        callable that dynamically generates the file path or file object
        based on the file path or file-like object `f` passed to the
        ``torch.load`` call. When using a callable, it should take a single
        argument of the type ``torch.types.FileLike``, and output a type
        accepted by a `TensorDeserializer`. The default behaviour is to use
        a callable that appends ``".tensors"`` to any filename passed as
        ``f``. If a provided callable returns ``None``, tensorizer
        deserialization is not used.
    load_func: An optional callable with the signature
        ``load_func(file_obj, kwargs: dict) -> Iterable[Tensor]``
        that may be used to override the default loading logic for tensors.
        `file_obj` and `kwargs` correspond to the ones passed to this
        function.
    suppress_weights_only: If set to ``True``, replace ``weights_only=True``
        with ``weights_only=False`` in calls to ``torch.load`` within this
        context. Using ``torch.load`` with tensorizer as a backend is
        incompatible with ``weights_only=True`` because ``torch`` treats
        any load that uses a custom ``pickle_module`` as a non-weights-only
        load, even though tensorizer only loads weights in practice
        (see the example below).
    kwargs: Further keyword arguments to pass to the `TensorDeserializer`
        object used to load tensor data.
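
For example, a call site that insists on passing ``weights_only=True`` can still be pointed at tensorizer by suppressing that argument; a minimal sketch:

import torch
from tensorizer.torch_compat import tensorizer_loading

with tensorizer_loading(suppress_weights_only=True, device="cpu"):
    # Even if third-party code passes weights_only=True here, the
    # context rewrites it to weights_only=False so that the
    # tensorizer backend can run.
    obj = torch.load("module.pt", weights_only=True)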

Async checkpointing example

This module is compatible with async saving. The boilerplate for implementing simple async checkpointing in a transformers.Trainer training loop is shown below. It uses the save_func parameter of tensorizer_saving to route the serializer call through a thread pool managed by the calling code.

from concurrent.futures import ThreadPoolExecutor
import tensorizer

# Import transformers inside a saving context so that any direct
# reference it captures to torch.save points at the tensorizer-aware
# wrapper rather than the unpatched function.
with tensorizer.torch_compat.tensorizer_saving():
    from transformers import Trainer, TrainingArguments

pool = ThreadPoolExecutor(max_workers=1)


def save_func(file_obj, tensors, kwargs):
    serializer = tensorizer.TensorSerializer(file_obj, **kwargs)
    serializer.write_state_dict(tensors)
    serializer.close()


def async_save(file_obj, tensors, kwargs):
    # Copy tensors so that they don't get updated
    # while a write operation is still ongoing
    # (copy=True forces a copy even for tensors already on the CPU).
    tensors = [t.detach().to("cpu", copy=True) for t in tensors]
    pool.submit(save_func, file_obj, tensors, kwargs)


# Any trainer setup code can go here
training_args = TrainingArguments(..., save_safetensors=False)
trainer = Trainer(..., args=training_args)


with tensorizer.torch_compat.tensorizer_saving(
    save_func=async_save, limit_cpu_concurrency=1
):
    try:
        trainer.train()
    finally:
        pool.shutdown(wait=True)
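
Nothing special is needed on the loading side for checkpoints written this way; a normal tensorizer_loading context suffices. A sketch, with a hypothetical checkpoint path:

import torch

with tensorizer.torch_compat.tensorizer_loading(device="cuda"):
    model = torch.load("output/checkpoint-500/pytorch_model.bin")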

Object storage support

These context managers support object storage paths. While providing an exact path directly is simplest, paths can also be derived dynamically through the filename callback argument to tensorizer_saving() and tensorizer_loading(). Below is an example that uploads only tensor weights to object storage, while saving metadata at a local path. It operates entirely through a callback that converts a filename passed to torch.save into a corresponding s3:// URI. A similar hook could be added to a save_func to also convert local file paths for metadata to s3:// URIs. (Note: none of this complexity is necessary if the calling code can choose file paths directly; using callbacks in this way is mainly useful for intercepting third-party library code that calls torch.save and torch.load in inaccessible locations.)

import os
import tensorizer
import torch

BUCKET_NAME: str = "example"


def object_storage_path(f: torch.types.FileLike) -> str:
    filename: str = os.fsdecode(f) + ".tensors"
    return f"s3://{BUCKET_NAME}/{os.path.relpath(filename, ".")}"


with tensorizer.torch_compat.tensorizer_saving(object_storage_path):
    ...
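
The same callback can be reused on the loading side, since tensorizer_loading() accepts a filename callback of the same form:

with tensorizer.torch_compat.tensorizer_loading(object_storage_path):
    ...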

Version update

This PR also updates the code version to 2.11.0a0. The full update to 2.11.0 for the release of this module will come in a subsequent PR.

Future work

This module is subject to the same issue as torch.save with overly large storages backing tensors: full storages are always saved for all tensors present, even if only small views of them actually appear in the tensors being saved. This mirrors torch's implementation, partly incidentally and partly to support tied weights between separate tensors. To likewise support tied weights, we inherit this behaviour, but work is planned to make it less of an issue. That work isn't finished yet, so it is not included in this PR.
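
A small illustration of the inherited torch behaviour (sizes are arbitrary):

import torch

base = torch.zeros(1_000_000)
view = base[:10]  # a tiny view over a large storage

# torch.save serializes the entire 1M-element storage backing the view,
# not just its 10 visible elements, and this module follows the same
# rule so that tied and overlapping tensors stay consistent.
torch.save(view, "view.pt")  # roughly 4 MB on disk, not ~40 bytes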

Eta0 added 6 commits July 24, 2025 00:04
This module allows using tensorizer as a backend for handling
tensor data in calls to torch.save and torch.load by way of
the context managers tensorizer_saving and tensorizer_loading.
This also adds torch_compat to __init__.py,
and refactors the ContextVar import.
The test suite covers activating both tensorizer_saving and tensorizer_loading
at the same time, as well as nesting multiple levels of either
tensorizer_saving or tensorizer_loading with different arguments
for each nested context.
wbrown (Contributor) left a comment:
Fantastic work, Eta.

My recommendation here would be to mention this in the README.md as well. I know that you believe in docstrings, but that by itself will not attract attention to Tensorizer.

It needs to be front and center, and advertised as an example and feature there. That's what search engines and people read first.

wbrown merged commit 34f578c into main on Aug 4, 2025.