Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
## [2.12.0] - 2025-08-20

### Added

- Tensors with long dimensions (≥ `2 ** 32` elements in a single dimension)
can now be serialized and deserialized

### Changed

- `tensorizer.utils.CPUMemoryUsage.free` now reports available memory
(`psutil.virtual_memory().available`) rather than free memory
(`psutil.virtual_memory().free`)

### Fixed

- `tensorizer.torch_compat` can now serialize and deserialize tensors that have
Expand Down Expand Up @@ -504,7 +509,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `get_gpu_name`
- `no_init_or_tensor`

[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.11.1...HEAD
[2.12.0]: https://github.com/coreweave/tensorizer/compare/v2.11.1...v2.12.0
[2.11.1]: https://github.com/coreweave/tensorizer/compare/v2.11.0...v2.11.1
[2.11.0]: https://github.com/coreweave/tensorizer/compare/v2.10.1...v2.11.0
[2.10.1]: https://github.com/coreweave/tensorizer/compare/v2.10.0...v2.10.1
Expand Down
2 changes: 1 addition & 1 deletion tensorizer/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.12.0a0"
__version__ = "2.12.0"
67 changes: 30 additions & 37 deletions tensorizer/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def is_long_tensor(self) -> bool:

class _FileFeatureFlags(enum.IntFlag):
encrypted = enum.auto()
versioned_headers = enum.auto()


@dataclasses.dataclass
Expand Down Expand Up @@ -310,7 +311,7 @@ def from_io(
feature_flag_bytes, "little", signed=False
)
feature_flags = _FileFeatureFlags(feature_flag_int)
if not (0 <= feature_flags <= max(_FileFeatureFlags)):
if not (0 <= feature_flags < (max(_FileFeatureFlags) << 1)):
raise ValueError(
f"Unsupported feature flags: {_FileFeatureFlags!r}"
)
Expand Down Expand Up @@ -1838,8 +1839,17 @@ def __init__(
"Tensor is encrypted, but decryption was not requested"
)
self._has_versioned_metadata: bool = (
version_number >= LONG_TENSOR_TENSORIZER_VERSION
_FileFeatureFlags.versioned_headers in self._file_flags
)
if (
self._has_versioned_metadata
and version_number < LONG_TENSOR_TENSORIZER_VERSION
):
raise ValueError(
"Invalid feature flag present in file header for a file"
f" with version {version_number:d}"
f" (flags: {self._file_flags!s})"
)

# The total size of the file.
# WARNING: this is not accurate. This field isn't used in the
Expand Down Expand Up @@ -3904,6 +3914,10 @@ def _synchronize_metadata(self):
if pos + total_length > self._metadata_end:
raise RuntimeError("Metadata overflow")
self._pwrite_bulk(buffers, pos, total_length)
if self._metadata_handler.is_versioned:
self._file_header.feature_flags |= (
_FileFeatureFlags.versioned_headers
)

def _pwrite_bulk(
self, buffers: Sequence[bytes], offset: int, expected_length: int
Expand Down Expand Up @@ -3937,7 +3951,7 @@ class _MetadataHandler:
metadata entries need to be rewritten because of other entries written
in subsequent batches.

Internally, it functions like a state machine.
Internally, it has two states.

In its initial state, it tracks pending metadata entries to be written,
as well as past metadata entries that were already written. It stays
Expand All @@ -3948,55 +3962,32 @@ class _MetadataHandler:
Once it is given any tensor using a metadata scheme newer than V1,
it transitions to its second state. In this state, all
previously-written metadata entries are moved back into a pending state,
and version tags are prepended to every entry. It stays in this state
until the next write operation (i.e. the next call to ``commit()``),
after which it moves into its final state.

In its final state, no more history is saved for previously-written
metadata entries, as historical entries will at this point never again
need to be rewritten. Version tags continue to be prepended
to new entries. It remains in this state forever.
and version tags are prepended to every entry. No more history is saved
for newly-written metadata entries, as historical entries will at
this point never again need to be rewritten.
"""

__slots__ = ("pending", "past", "version", "_pos", "_state")
__slots__ = ("pending", "past", "version", "_pos", "_is_updated")
pending: list
past: list
version: int
_pos: int

class _MetadataHandlerState(enum.Enum):
TRACKING_PAST = 1
STAGING_PAST = 2
NO_PAST = 3

_state: _MetadataHandlerState
V1_TAG: ClassVar[bytes] = b"\x01\x00\x00\x00"

@property
def _tracking_past(self) -> bool:
return self._state is self._MetadataHandlerState.TRACKING_PAST

@property
def _staging_past(self) -> bool:
return self._state is self._MetadataHandlerState.STAGING_PAST

@property
def _no_past(self) -> bool:
return self._state is self._MetadataHandlerState.NO_PAST

def __init__(self):
self.pending = []
self.past = []
self.version = 1
self._pos = 0
self._state = self._MetadataHandlerState.TRACKING_PAST
self._is_updated = False

def submit(self, metadata: bytes, version: int):
if version > self.version:
if self.version == 1:
self._update()
self.version = version
if not self._tracking_past:
if self._is_updated:
self.pending.append(version.to_bytes(4, byteorder="little"))
self.pending.append(metadata)

Expand All @@ -4005,10 +3996,8 @@ def commit(self):
# Successive write positions are not a monotone sequence
pending = self.pending
self.pending = []
if self._tracking_past:
if not self._is_updated:
self.past.extend(pending)
elif self._staging_past:
self._state = self._MetadataHandlerState.NO_PAST
total_length = sum(len(d) for d in pending)
pos = self._pos
self._pos += total_length
Expand All @@ -4017,7 +4006,7 @@ def commit(self):
def _update(self):
# This is only called the one time that self.version is updated
# up from 1, so this should always be in the initial state
assert self._tracking_past
assert not self._is_updated
# At the time this is called, everything in self.past and
# self.pending must be version 1, so no complicated checking is
# needed to figure out what needs to be tagged with a v1 tag
Expand All @@ -4029,7 +4018,11 @@ def _update(self):
self.pending = pending
self.past.clear()
self._pos = 0
self._state = self._MetadataHandlerState.STAGING_PAST
self._is_updated = True

@property
def is_versioned(self) -> bool:
return self.version > 1

def write_tensor(
self,
Expand Down
2 changes: 1 addition & 1 deletion tensorizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def now(cls) -> "CPUMemoryUsage":
p.memory_info().rss for p in process.children(True)
)
vmem = psutil.virtual_memory()
return cls(maxrss, vmem.free)
return cls(maxrss, vmem.available)

def __str__(self):
return "CPU: (maxrss: {:,}MiB F: {:,}MiB)".format(
Expand Down