14 changes: 14 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- `tensorizer.torch_compat` is a new module for using `tensorizer` as a backend
for handling tensor data during standard `torch.save` and `torch.load` calls
- To use `tensorizer` as a backend for `torch.save`,
wrap the call in the `tensorizer_saving` context manager
- The file created must then be loaded using `tensorizer_loading`
- To use `tensorizer` as a backend for `torch.load`,
wrap the call in the `tensorizer_loading` context manager
- The file to load must have been created using `tensorizer_saving`

## [2.10.1] - 2025-06-27

### Fixed
@@ -472,6 +485,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- `get_gpu_name`
- `no_init_or_tensor`

[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.10.1...HEAD
[2.10.1]: https://github.com/coreweave/tensorizer/compare/v2.10.0...v2.10.1
[2.10.0]: https://github.com/coreweave/tensorizer/compare/v2.9.3...v2.10.0
[2.9.3]: https://github.com/coreweave/tensorizer/compare/v2.9.2...v2.9.3
191 changes: 191 additions & 0 deletions README.md
@@ -281,6 +281,197 @@ An example command line tool to add or remove encryption from existing
serialized models is also available as
[examples/encryption.py](examples/encrypt_existing.py).

## PyTorch Compatibility

`tensorizer`'s `TensorSerializer` and `TensorDeserializer` classes are designed
to replace the use of `torch.save` and `torch.load` in model saving and loading
pipelines; however, they are not drop-in replacements. The API for
serialization and deserialization with `tensorizer` offers more parameters to
control performance, resource usage, and additional features like encryption,
so it is invoked differently.
For drop-in replacements, see the next section.

The examples below show typical usage of
`torch.save` and `torch.load`, and how each may be replaced with `tensorizer`
serialization.

```py
from tensorizer import TensorDeserializer, TensorSerializer
import torch

model: torch.nn.Module = ...

# Saving with torch.save
state_dict = model.state_dict()
torch.save(state_dict, "model.pt")

# Loading with torch.load
state_dict = torch.load("model.pt", map_location="cuda:0")
model.load_state_dict(state_dict)

# Saving with tensorizer.TensorSerializer
state_dict = model.state_dict()
serializer = TensorSerializer("model.tensors")
serializer.write_state_dict(state_dict)
serializer.close()

# Loading with tensorizer.TensorDeserializer
with TensorDeserializer("model.tensors", device="cuda:0") as state_dict:
model.load_state_dict(state_dict)
```

> [!NOTE]
>
> `TensorDeserializer` is a context manager because it supports lazy loading,
> where the context controls how long its source file remains open to read
> more tensors. This behaviour is optional and can be enabled with
> `TensorDeserializer(..., lazy_load=True)`.
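
For example, a minimal sketch of lazy loading, reusing the `model.tensors`
file from above:

```py
import torch
from tensorizer import TensorDeserializer

model: torch.nn.Module = ...

# With lazy_load=True, each tensor is read from the file on first access,
# so the file stays open until the context exits
with TensorDeserializer(
    "model.tensors", device="cuda:0", lazy_load=True
) as state_dict:
    model.load_state_dict(state_dict)  # Tensors are materialized here
```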

### Drop-In PyTorch Compatibility Layer, `tensorizer.torch_compat`

Note that, because `tensorizer` serializes only tensors and not other Python
types, it is more similar to `safetensors` than to `torch`'s own saving:
`torch` bases its serialization on the `pickle` module, which allows
serialization of arbitrary Python objects.

The `tensorizer.torch_compat` module exists to address this and another common
integration challenge:
- Use case 1: You need to serialize Python objects other than tensors,
like `torch.save` does.
- Use case 2: You need to adapt existing code that uses `torch.save` internally
where it is not easy to swap out, like in an external framework or library.

**`tensorizer.torch_compat` enables calls to `torch.save` and `torch.load`
to use `tensorizer` as a backend for the serialization and deserialization
of tensor data, separate from other data being serialized.**

The interface to `tensorizer.torch_compat` is its two context managers,
`tensorizer_saving` and `tensorizer_loading`. These take similar
arguments to the `TensorSerializer` and `TensorDeserializer` classes,
respectively, and temporarily swap the `torch.save` and `torch.load`
functions for ones with special behaviour while their context is active.
Saving this way produces two files: one for tensors, and one for all other data.

```py
import torch
from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving

model: torch.nn.Module = ...

state_dict = model.state_dict()

# Saving with torch.save, internally using tensorizer.TensorSerializer
with tensorizer_saving("model.pt.tensors"):
torch.save(state_dict, "model.pt")

# Loading with torch.load, internally using tensorizer.TensorDeserializer
with tensorizer_loading("model.pt.tensors", device="cuda:0"):
state_dict = torch.load("model.pt")
model.load_state_dict(state_dict)
```
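
Since only tensor data is diverted to the sidecar file, mixed payloads that
combine tensors with other Python objects round-trip as they would with plain
`torch.save`. A minimal sketch (the `ckpt.pt` filenames are illustrative):

```py
import torch
from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving

model: torch.nn.Module = ...

# Non-tensor values take torch's usual pickle-based path; the tensors
# in the state dict are written to the sidecar .tensors file instead
checkpoint = {"step": 1000, "state_dict": model.state_dict()}

with tensorizer_saving("ckpt.pt.tensors"):
    torch.save(checkpoint, "ckpt.pt")

with tensorizer_loading("ckpt.pt.tensors", device="cuda:0"):
    checkpoint = torch.load("ckpt.pt")

assert checkpoint["step"] == 1000
```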

For existing code that uses `torch.save` or `torch.load` internally, the
recommended usage pattern is to wrap the relevant section of code in one of
these context managers so that it can use `tensorizer` automatically.

For instance, part of adapting a `transformers.Trainer` object to use
`tensorizer` might look like this:

```py
from tensorizer.torch_compat import tensorizer_saving

with tensorizer_saving():
# In case this module saves references to torch.save at import time
import transformers

trainer: transformers.Trainer = ...

with tensorizer_saving():
# This method may call torch.save internally at some point,
# so activating this context around it will intercept it when it does
trainer.train()
```

#### `torch_compat` Usage Considerations

If the filename to use is difficult to determine in advance, the first
`file_obj` argument to `tensorizer_loading` and `tensorizer_saving` may
instead be a callback that receives the path passed to `torch.save` or
`torch.load` and returns the location of the sidecar `.tensors` file.

The `.tensors` path can be anything supported normally in `tensorizer`,
including pre-opened file-like objects and `s3://` URIs.
The default `file_obj` callback simply appends `.tensors` to the path.

```py
import torch
from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving


def tensors_path(f: torch.types.FileLike) -> str | None:
if isinstance(f, str):
return f.replace(".pt", "-tensor-data.tensors", 1)
else:
        # Returning None will save or load normally, without using tensorizer.
# This is useful for file-like objects like io.BytesIO,
# where sidecar files don't make sense.
return None


model: torch.nn.Module = ...
state_dict = model.state_dict()

with tensorizer_saving(tensors_path):
# Will save to model.pt and model-tensor-data.tensors
torch.save(state_dict, "model.pt")

with tensorizer_loading(tensors_path, device="cuda:0"):
# Will load from model.pt and model-tensor-data.tensors
state_dict = torch.load("model.pt")
model.load_state_dict(state_dict)
```

The `tensorizer_saving` and `tensorizer_loading` contexts are also thread-safe
and async-safe, in that their effects are local to one thread and coroutine.
They may also be activated at the same time as each other, or even nested
to temporarily change the arguments one is using.
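
For instance, a minimal sketch of nesting, with hypothetical filenames, where
the innermost active context's arguments take effect:

```py
import torch
from tensorizer.torch_compat import tensorizer_saving

state_dict = ...

with tensorizer_saving("a.pt.tensors"):
    torch.save(state_dict, "a.pt")  # Tensor data goes to a.pt.tensors
    with tensorizer_saving("b.pt.tensors"):
        # The inner context's arguments apply while it is active
        torch.save(state_dict, "b.pt")  # Tensor data goes to b.pt.tensors
```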

> [!WARNING]
>
> Even though `tensorizer` itself only handles data and does not execute
> arbitrary code, `torch.load` still uses the `pickle` module internally.
> Loading untrusted `pickle` files **can** execute arbitrary code, so take
> appropriate precautions when using these wrappers.
>
> Additionally, for technical reasons, `torch.load(..., weights_only=True)`
> is incompatible with these wrappers. `weights_only` can be forced to `False`
> by using `tensorizer_loading(..., suppress_weights_only=True)`,
> but this disables some safety checks in `torch`, so this is opt-in only.
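
As a sketch of that opt-in override, reusing the files from the earlier
example:

```py
import torch
from tensorizer.torch_compat import tensorizer_loading

# suppress_weights_only=True forces weights_only to False within the context,
# disabling some of torch's safety checks; only do this with trusted files
with tensorizer_loading("model.pt.tensors", suppress_weights_only=True):
    state_dict = torch.load("model.pt")
```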

Finally, since the `tensorizer_saving` and `tensorizer_loading` contexts
temporarily swap out the `torch.save` and `torch.load` functions, note that they
will not affect already-saved references to those functions, e.g.:

```py
from tensorizer.torch_compat import tensorizer_saving
from torch import save as original_torch_save

with tensorizer_saving():
# This won't work, but torch.save(..., "model.pt") would work
original_torch_save(..., "model.pt")
```

This can sometimes be worked around by wrapping import blocks
in `tensorizer_saving` and/or `tensorizer_loading` as well.
The wrappers will behave the same as the default `torch.save` and `torch.load`
functions unless their respective contexts are active, so this will usually
have no side effects.
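
A brief sketch of that workaround:

```py
from tensorizer.torch_compat import tensorizer_saving

with tensorizer_saving():
    # Importing inside the context captures the wrapper, so this saved
    # reference behaves like the default torch.save unless a
    # tensorizer_saving context is active when it is called
    from torch import save as torch_save

with tensorizer_saving("model.pt.tensors"):
    torch_save(..., "model.pt")  # Uses tensorizer for tensor data
```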

For additional parameters, caveats, and advanced usage information,
refer to the docstrings for `tensorizer_saving` and `tensorizer_loading` in
the file [tensorizer/torch_compat.py](/tensorizer/torch_compat.py),
or view their function documentation inline in an IDE.

## Benchmarks

You can run your own benchmarks on CoreWeave or your own Kubernetes cluster
3 changes: 2 additions & 1 deletion tensorizer/__init__.py
@@ -1,10 +1,11 @@
-from . import serialization, stream_io, utils
+from . import serialization, stream_io, torch_compat, utils
 from ._version import __version__
 from .serialization import *
 
 __all__ = [
     *serialization.__all__,
     "stream_io",
+    "torch_compat",
     "utils",
     "protobuf",
     "tensors_pb2",
2 changes: 1 addition & 1 deletion tensorizer/_version.py
@@ -1 +1 @@
__version__ = "2.10.1"
__version__ = "2.11.0a0"