Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions tensorizer/torch_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@

logger = logging.getLogger(__name__)

if hasattr(torch.serialization, "FILE_LIKE"):
# Pre torch 2.7.1
FileLike = torch.serialization.FILE_LIKE
else:
FileLike = torch.types.FileLike

_tensorizer_file_obj_type: "typing.TypeAlias" = Union[
io.BufferedIOBase,
io.RawIOBase,
Expand All @@ -67,7 +73,7 @@

_wrapper_file_obj_type: "typing.TypeAlias" = Union[
_tensorizer_file_obj_type,
Callable[[torch.types.FileLike], _tensorizer_file_obj_type],
Callable[[FileLike], _tensorizer_file_obj_type],
]

_save_func_type: "typing.TypeAlias" = Callable[
Expand Down Expand Up @@ -397,7 +403,7 @@ def _pickle_attr(name):
_ORIG_TORCH_LOAD: Final[callable] = torch.load


def _infer_tensor_ext_name(f: torch.types.FileLike):
def _infer_tensor_ext_name(f: FileLike):
if isinstance(f, io.BytesIO):
logger.warning(
"Cannot infer .tensors location from io.BytesIO;"
Expand All @@ -418,7 +424,7 @@ def _infer_tensor_ext_name(f: torch.types.FileLike):

@contextlib.contextmanager
def _contextual_torch_filename(
f: torch.types.FileLike,
f: FileLike,
filename_ctx_var: ContextVar[Optional[_wrapper_file_obj_type]],
):
if filename_ctx_var.get() is None:
Expand Down Expand Up @@ -462,7 +468,7 @@ def _contextual_torch_filename(
@functools.wraps(_ORIG_TORCH_SAVE)
def _save_wrapper(
obj: object,
f: torch.types.FileLike,
f: FileLike,
pickle_module: Any = pickle,
*args,
**kwargs,
Expand All @@ -489,7 +495,7 @@ def _save_wrapper(

@functools.wraps(_ORIG_TORCH_LOAD)
def _load_wrapper(
f: torch.types.FileLike,
f: FileLike,
map_location: torch.serialization.MAP_LOCATION = None,
pickle_module: Any = _LOAD_WRAPPER_DEFAULT_MODULE,
*args,
Expand Down Expand Up @@ -550,7 +556,7 @@ def tensorizer_saving(
that dynamically generates the file path or file object based on
the file path or file-like object ``f`` passed to the ``torch.save``
call. When using a callable, it should take a single argument of
the type ``torch.types.FileLike``, and output a type accepted
the type ``FileLike``, and output a type accepted
by a `TensorSerializer`. The default behaviour is to use a callable
that appends ``".tensors"`` to any filename passed as ``f``.
If a provided callable returns ``None``, tensorizer deserialization
Expand Down Expand Up @@ -620,7 +626,7 @@ def tensorizer_loading(
callable that dynamically generates the file path or file object
based on the file path or file-like object `f` passed to the
``torch.load`` call. When using a callable, it should take a single
argument of the type ``torch.types.FileLike``, and output a type
argument of the type ``FileLike``, and output a type
accepted by a `TensorDeserializer`. The default behaviour is to use
a callable that appends ``".tensors"`` to any filename passed as
``f``. If a provided callable returns ``None``, tensorizer
Expand Down