Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 139 additions & 139 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,9 +709,8 @@ def save_pretrained(
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format.
FlashPack is a binary format that allows for faster loading.
Requires the `flashpack` library to be installed.
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
binary format that allows for faster loading. Requires the `flashpack` library to be installed.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
Expand All @@ -732,22 +731,8 @@ def save_pretrained(
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)

weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)

os.makedirs(save_directory, exist_ok=True)
if use_flashpack:
if not is_main_process:
return
from ..utils.flashpack_utils import save_flashpack
save_flashpack(
self,
save_directory,
variant=variant,
)

if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", None)
Expand All @@ -759,67 +744,80 @@ def save_pretrained(
# Only save the model itself if we are using distributed training
model_to_save = self

# Attach architecture to the config
# Save the config
if is_main_process:
model_to_save.save_config(save_directory)

# Save the model
state_dict = model_to_save.state_dict()
if use_flashpack:
if not is_main_process:
return

# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
from ..utils.flashpack_utils import save_flashpack

# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)

for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
save_flashpack(model_to_save, save_directory, variant=variant)
else:
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)

if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
state_dict = model_to_save.state_dict()
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")

# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)

# Save each shard
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)

# Save index file if sharded
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")

# Push to hub if requested (common to both paths)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
Expand Down Expand Up @@ -953,9 +951,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) weights if a compatible `.flashpack` file
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
the standard loading path (for example, `safetensors`).
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
file cannot be used, automatic fallback to the standard loading path (for example, `safetensors`).
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
Expand Down Expand Up @@ -1218,13 +1216,50 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)

else:
flashpack_file = None
if use_flashpack:
flashpack_file = None
if use_flashpack:
try:
flashpack_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant("model.flashpack", variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except EnvironmentError:
flashpack_file = None
logger.warning(
"`use_flashpack` was specified to be True but not flashpack file was found. Resorting to non-flashpack alternatives."
)

if flashpack_file is None:
# in the case it is sharded, we have already the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder or "",
dduf_entries=dduf_entries,
)
elif use_safetensors:
logger.warning("Trying to load model weights with safetensors format.")
try:
flashpack_file = _get_model_file(
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant("model.flashpack", variant),
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
Expand All @@ -1236,68 +1271,33 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except EnvironmentError:
flashpack_file = None

if flashpack_file is None:
# in the case it is sharded, we have already the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder or "",
dduf_entries=dduf_entries,
except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)
elif use_safetensors:
logger.warning("Trying to load model weights with safetensors format.")
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)

except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)

if resolved_model_file is None and not is_sharded:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
if resolved_model_file is None and not is_sharded:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)

if not isinstance(resolved_model_file, list):
resolved_model_file = [resolved_model_file]
if not isinstance(resolved_model_file, list):
resolved_model_file = [resolved_model_file]

# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
Expand All @@ -1323,12 +1323,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

if flashpack_file is not None:
from ..utils.flashpack_utils import load_flashpack

# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
# explicitly materialize parameters before loading to ensure correctness and parity
# with the standard loading path.
if any(p.device.type == "meta" for p in model.parameters()):
model.to_empty(device="cpu")
model.to_empty(device="cpu")
load_flashpack(model, flashpack_file)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
model.eval()
Expand Down Expand Up @@ -1434,12 +1435,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

if output_loading_info:
return model, loading_info

logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")

return model


# Adapted from `transformers`.
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ class implements both a save and loading method. The pipeline is easily reloaded
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip install
flashpack`.
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
install flashpack`.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
Expand Down
Loading