Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1903,14 +1903,25 @@ def device(self) -> torch.device:
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
# Not cached: with group offloading, the effective device changes per-forward as groups onload/offload.
return get_parameter_device(self)

@property
def dtype(self) -> torch.dtype:
    """
    `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).

    The value is memoized in `self.__dict__["_cached_dtype"]` to avoid rescanning every parameter on
    each access. The cache is invalidated in `_apply`, which all dtype-changing entry points
    (`.to()`, `.half()`, `.float()`, etc.) flow through.
    """
    # Fast path: return the memoized dtype if present. Stored directly in `__dict__` so it does not
    # interact with `nn.Module`'s parameter/buffer/submodule registration machinery.
    cached = self.__dict__.get("_cached_dtype")
    if cached is not None:
        return cached
    # Cache miss: compute once from the parameters and memoize for subsequent accesses.
    cached = get_parameter_dtype(self)
    self.__dict__["_cached_dtype"] = cached
    return cached

def _apply(self, fn, *args, **kwargs):
    """Delegate to `nn.Module._apply` after dropping the memoized dtype.

    Every dtype/device-changing entry point (`.to()`, `.cpu()`, `.cuda()`, `.half()`, ...)
    funnels through `_apply`, so this is the single choke point at which the cached dtype
    could go stale — discard it before the conversion runs.
    """
    try:
        del self.__dict__["_cached_dtype"]
    except KeyError:
        # Nothing cached yet — nothing to invalidate.
        pass
    return super()._apply(fn, *args, **kwargs)

def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
Expand Down
Loading