Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions modelopt/torch/opt/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from contextlib import contextmanager

import torch
import transformers
from packaging.version import Version
from transformers import PreTrainedModel, Trainer, TrainerCallback
from transformers import modeling_utils as tf_modeling_utils

Expand Down Expand Up @@ -130,13 +132,18 @@ def _save_pretrained_with_checks(self, save_directory, *args, **kwargs):

# [Fix for huggingface bug] deepspeed zero3 training backend only loads params into the model from
# state_dict, but not buffers. So lets explicitly load the buffers into the model from state_dict.
# The `load_config` parameter was added to `_load_state_dict_into_zero3_model` in transformers 5.0.
# True when the installed transformers release is 5.0 or newer; used to decide
# whether `load_config` may be forwarded to the original zero3 loader.
_TRANSFORMERS_GE_5_0 = Version(transformers.__version__) >= Version("5.0")


def _load_params_and_buffers_into_zero3_model(model_to_load, state_dict, load_config=None):
    """Load buffers, then params, into a DeepSpeed ZeRO-3 model from ``state_dict``.

    Workaround for a HuggingFace bug: the DeepSpeed ZeRO-3 loading path loads only
    parameters from ``state_dict`` and skips buffers, so buffers are explicitly
    loaded here before delegating to the original (cached) transformers
    implementation of ``_load_state_dict_into_zero3_model``.

    Args:
        model_to_load: Model whose buffers/params are populated from ``state_dict``.
        state_dict: Checkpoint state dict containing both params and buffers.
        load_config: Forwarded to the original implementation only on
            transformers >= 5.0, where this parameter was introduced.

    Returns:
        Whatever the original ``_load_state_dict_into_zero3_model`` returns.
    """
    # Explicitly load buffers first; the ZeRO-3 loader delegated to below skips them.
    # Use a set for O(1) membership tests while filtering the state dict.
    buffer_names = {name for name, _ in model_to_load.named_buffers()}
    buffer_state_dict = {k: v for k, v in state_dict.items() if k in buffer_names}
    model_to_load.load_state_dict(buffer_state_dict, strict=False)

    # Delegate to the original (pre-patch) implementation stashed in the cache.
    original_fn = tf_modeling_utils._modelopt_cache["_load_state_dict_into_zero3_model"]
    if _TRANSFORMERS_GE_5_0:
        # `load_config` exists only in transformers >= 5.0.
        return original_fn(model_to_load, state_dict, load_config)
    return original_fn(model_to_load, state_dict)


pretrained_model_patch_methods = [
Expand Down
Loading