xtuner/v1/model/base.py (31 additions & 1 deletion)
@@ -1,6 +1,7 @@
import json
import math
import pydoc
+import re
from concurrent.futures import Future, ThreadPoolExecutor, wait
from functools import reduce
from importlib import import_module
@@ -79,6 +80,7 @@ class XTunerBaseModelConfig(PydanticBaseModel):
"`dict[str, TorchCompileOption]`: Customize the compile option",
),
] = None
+    hf_key_mapping: Annotated[dict[str, str] | None, "Remapping hf key based on the `to_hf_key_list`"] = None

@property
def hf_config(self) -> PretrainedConfig | None:
@@ -381,10 +383,34 @@ def get_shard_placement(placements: tuple[Placement, ...]) -> Shard | None:
return

load_spec_mapping: dict[str, LoadSpec] = {}
+        hf_key_mapping_missing: set[str] = set()

for name, param in self.state_dict().items():
name = self._clean_param_name(name)
-            hf_keys = self.to_hf_key_list(name)
+            _hf_keys = self.to_hf_key_list(name)

+            if not self.config.hf_key_mapping:
+                hf_keys = _hf_keys
+            else:
+                hf_keys = []
+                for key in _hf_keys:
+                    max_matched_pattern = None
+                    max_match_len = -1
+                    for pattern in self.config.hf_key_mapping:
+                        if (matched := re.search(pattern, key)) is not None:
+                            matched_len = matched.end() - matched.start()
+
+                            if matched_len > max_match_len:
+                                max_match_len = matched_len
+                                max_matched_pattern = pattern
+
+                    if max_matched_pattern is None:
+                        hf_key_mapping_missing.add(key)
+                        hf_keys.append(key)
+                    else:
+                        repl = self.config.hf_key_mapping[max_matched_pattern]
+                        hf_keys.append(re.sub(max_matched_pattern, repl, key))

if isinstance(param, DTensor) and (placement := get_shard_placement(param.placements)) is not None:
dim = placement.dim
_, _offset = compute_local_shape_and_global_offset(param.shape, param.device_mesh, param.placements)
@@ -463,6 +489,10 @@ def get_shard_placement(placements: tuple[Placement, ...]) -> Shard | None:
)
load_spec_mapping[name] = load_spec

+        if hf_key_mapping_missing:
+            logger.info("These hf keys will not be influenced by `hf_key_mapping`:")
+            logger.info(json.dumps(list(hf_key_mapping_missing), indent=2))

self.load_spec_mapping = load_spec_mapping

def _to_float8(
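
For reference, here is a minimal standalone sketch of the remapping rule added above: the pattern in `hf_key_mapping` whose regex match spans the most characters wins, and keys that match no pattern pass through unchanged (and get logged). The `remap_hf_key` helper and the example key names are illustrative only, not part of this PR.

import re

def remap_hf_key(key: str, hf_key_mapping: dict[str, str]) -> str:
    # Pick the pattern whose regex match covers the most characters of `key`.
    best_pattern, best_len = None, -1
    for pattern in hf_key_mapping:
        if (matched := re.search(pattern, key)) is not None:
            matched_len = matched.end() - matched.start()
            if matched_len > best_len:
                best_pattern, best_len = pattern, matched_len
    # No pattern matched: return the key untouched (base.py logs such keys).
    if best_pattern is None:
        return key
    return re.sub(best_pattern, hf_key_mapping[best_pattern], key)

mapping = {r"^model.": "model.language_model."}
print(remap_hf_key("model.layers.0.self_attn.q_proj.weight", mapping))
# -> model.language_model.layers.0.self_attn.q_proj.weight
print(remap_hf_key("lm_head.weight", mapping))
# -> lm_head.weight  (no pattern matches, so the key passes through)
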
xtuner/v1/model/compose/intern_s1/intern_s1_config.py (6 additions & 2 deletions)
@@ -139,10 +139,14 @@ class InternS1Config(InternS1BaseConfig):
norm_type="rms_norm",
)
projector_config: InternS1ProjectorConfig = InternS1ProjectorConfig(vision_hidden_size=3200, text_hidden_size=4096)
-    text_config: MoEConfig = Qwen3MoE235BA22Config(vocab_size=153216)
+    text_config: MoEConfig = Qwen3MoE235BA22Config(
+        vocab_size=153216, hf_key_mapping={r"^model.": "model.language_model."}
+    )


class InternS1MiniConfig(InternS1BaseConfig):
vision_config: InternS1VisionConfig = InternS1VisionConfig()
projector_config: InternS1ProjectorConfig = InternS1ProjectorConfig()
-    text_config: Qwen3Dense8BConfig = Qwen3Dense8BConfig(vocab_size=153216)
+    text_config: Qwen3Dense8BConfig = Qwen3Dense8BConfig(
+        vocab_size=153216, hf_key_mapping={r"^model.": "model.language_model."}
+    )
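
For illustration, the mapping above rewrites the text backbone's HF keys so they are looked up under the composed checkpoint's `model.language_model.` prefix, which is broadly what the `to_hf_key_list_wrapper` hack removed in modeling_intern_s1.py did via `str.replace`. The key names below are typical Qwen3 parameter names, not taken from this diff.

import re

print(re.sub(r"^model.", "model.language_model.", "model.embed_tokens.weight"))
# -> model.language_model.embed_tokens.weight
print(re.sub(r"^model.", "model.language_model.", "lm_head.weight"))
# -> lm_head.weight  (the pattern is anchored, so keys without the prefix are untouched)
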
xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py (0 additions & 8 deletions)
@@ -56,14 +56,6 @@ def __init__(self, config: InternS1BaseConfig):
self.select_layer = config.vision_feature_layer
self.downsample_ratio = config.downsample_ratio

-        # TODO(YHC): This is a hack to make the language model compatible with HF
-        _hf_prefix = "model.language_model."
-        self.language_model.to_hf_key_list = types.MethodType(to_hf_key_list_wrapper(  # type: ignore
-            fn=self.language_model.to_hf_key_list,
-            convertor=lambda x: x.replace('model.', _hf_prefix)),
-            self.language_model)
-        self.language_model._init_load_spec()

self.img_context_token_id = config.image_token_id
self.image_size = config.vision_config.image_size[0]

xtuner/v1/model/compose/internvl/internvl_config.py (9 additions & 3 deletions)
@@ -123,16 +123,22 @@ def hf_config(self):
class InternVL3P5Dense8BConfig(InternVLBaseConfig):
vision_config: InternVLVisionConfig = InternVLVisionConfig()
projector_config: InternVLProjectorConfig = InternVLProjectorConfig()
-    text_config: Qwen3Dense8BConfig = Qwen3Dense8BConfig()
+    text_config: Qwen3Dense8BConfig = Qwen3Dense8BConfig(
+        hf_key_mapping={r"^model.": "model.language_model."},
+    )


class InternVL3P5MoE30BA3Config(InternVLBaseConfig):
vision_config: InternVLVisionConfig = InternVLVisionConfig()
projector_config: InternVLProjectorConfig = InternVLProjectorConfig(text_hidden_size=2049)
-    text_config: Qwen3MoE30BA3Config = Qwen3MoE30BA3Config()
+    text_config: Qwen3MoE30BA3Config = Qwen3MoE30BA3Config(
+        hf_key_mapping={r"^model.": "model.language_model."},
+    )


class InternVL3P5Dense1BConfig(InternVLBaseConfig):
vision_config: InternVLVisionConfig = InternVLVisionConfig()
projector_config: InternVLProjectorConfig = InternVLProjectorConfig(text_hidden_size=1024)
-    text_config: Qwen3Dense0P6BConfig = Qwen3Dense0P6BConfig()
+    text_config: Qwen3Dense0P6BConfig = Qwen3Dense0P6BConfig(
+        hf_key_mapping={r"^model.": "model.language_model."},
+    )
xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py (8 additions & 8 deletions)
@@ -29,14 +29,14 @@ class Qwen3VLForConditionalGeneration(BaseComposeModel):
def __init__(self, config: Qwen3VLBaseConfig):
super().__init__(config) # type: ignore[arg-type]

-        if type(self.language_model) is Qwen3MoE:
-            # TODO(YHC): This is a hack to make the language model compatible with HF
-            _hf_prefix = "model.language_model."
-            self.language_model.to_hf_key_list = types.MethodType(to_hf_key_list_wrapper(  # type: ignore
-                fn=self.language_model.to_hf_key_list,
-                convertor=lambda x: x.replace('model.', _hf_prefix)),
-                self.language_model)
-            self.language_model._init_load_spec()
+        # if type(self.language_model) is Qwen3MoE:
+        #     # TODO(YHC): This is a hack to make the language model compatible with HF
+        #     _hf_prefix = "model.language_model."
+        #     self.language_model.to_hf_key_list = types.MethodType(to_hf_key_list_wrapper(  # type: ignore
+        #         fn=self.language_model.to_hf_key_list,
+        #         convertor=lambda x: x.replace('model.', _hf_prefix)),
+        #         self.language_model)
+        #     self.language_model._init_load_spec()

@property
@override