Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,104 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
model.eval()

if device_map is not None:
# Fix: handle remaining meta tensors before dispatch.
# When loading from single-file checkpoints (e.g., GGUF), some model parameters
# or buffers may not be present in the checkpoint and remain on the meta device
# after `load_model_dict_into_meta`. This causes `dispatch_model` to fail with
# "Cannot copy out of meta tensor" errors.
if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device

_target_device = list(device_map.values())[0] if device_map else "cpu"

# Collect submodules that have non-persistent buffers still on meta.
# These buffers are computed during __init__ (e.g., RoPE sinusoidal embeddings)
# and were never saved in the checkpoint. We need to re-materialize them by
# re-creating the submodule outside of init_empty_weights context.
_modules_to_reinit = set()
for _name, _buf in list(model.named_buffers()):
if _buf.device == torch.device("meta"):
_parts = _name.rsplit(".", 1)
_parent_path = _parts[0]
_buf_name = _parts[-1]
_submodule = model.get_submodule(_parent_path)
_is_persistent = _buf_name not in _submodule._non_persistent_buffers_set
if not _is_persistent:
_modules_to_reinit.add(_parent_path)
else:
logger.warning(
f"Buffer '{_name}' is still on meta device after loading. "
f"Initializing to zeros."
)
set_module_tensor_to_device(
model, _name, _target_device,
value=torch.zeros(
_buf.shape, dtype=_buf.dtype, device=_target_device
),
)

# Re-create submodules with non-persistent meta buffers.
# This recomputes deterministic buffers (e.g., sinusoidal positional embeddings)
# that were created as meta tensors under init_empty_weights().
for _mod_path in _modules_to_reinit:
_submodule = model.get_submodule(_mod_path)
_parent_path, _child_name = _mod_path.rsplit(".", 1) if "." in _mod_path else ("", _mod_path)
_parent = model.get_submodule(_parent_path) if _parent_path else model
_cls = type(_submodule)

# Reconstruct using config attributes stored on the submodule
_init_args = {}
import inspect
_sig = inspect.signature(_cls.__init__)
for _param_name, _param in _sig.parameters.items():
if _param_name == "self":
continue
if hasattr(_submodule, _param_name):
_init_args[_param_name] = getattr(_submodule, _param_name)
elif _param.default is not inspect.Parameter.empty:
pass # use default
else:
break # can't reconstruct, fall back below
else:
try:
logger.info(
f"Re-creating submodule '{_mod_path}' ({_cls.__name__}) "
f"to materialize non-persistent buffers."
)
_new_submodule = _cls(**_init_args)
setattr(_parent, _child_name, _new_submodule)
continue
except Exception as e:
logger.warning(
f"Failed to re-create '{_mod_path}': {e}. "
f"Falling back to zero-initialization."
)

# Fallback: zero-init any remaining meta buffers in this submodule
for _name, _buf in list(model.named_buffers()):
if _name.startswith(_mod_path) and _buf.device == torch.device("meta"):
set_module_tensor_to_device(
model, _name, _target_device,
value=torch.zeros(
_buf.shape, dtype=_buf.dtype, device=_target_device
),
)

# Handle any remaining meta parameters (should not happen with correct key mapping)
for _name, _param in list(model.named_parameters()):
if _param.device == torch.device("meta"):
logger.warning(
f"Parameter '{_name}' is still on meta device after loading. "
f"This likely indicates an incomplete checkpoint key mapping. "
f"Initializing to zeros."
)
set_module_tensor_to_device(
model, _name, _target_device,
value=torch.zeros(
_param.shape, dtype=_param.dtype, device=_target_device
),
)

device_map_kwargs = {"device_map": device_map}
dispatch_model(model, **device_map_kwargs)

Expand Down