9 changes: 0 additions & 9 deletions tests/model/test_intern_s1.py
@@ -245,9 +245,6 @@ def test_fsdp_text_accuracy(self, device, tol):
)
data_mesh = None

interns1_model.language_model.fully_shard(fsdp_config=fsdp_config)
interns1_model.vision_tower.fully_shard(fsdp_config=fsdp_config)
interns1_model.multi_modal_projector.fully_shard(fsdp_config=fsdp_config)
interns1_model.fully_shard(fsdp_config=fsdp_config)

interns1_model.from_hf(INTERNS1_DENSE_PATH)
@@ -356,9 +353,6 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
data_mesh = init_data_mesh(device, sp_size=sp_size)
sp_mesh = data_mesh["sp"]

interns1_model.language_model.fully_shard(fsdp_config=fsdp_config)
interns1_model.vision_tower.fully_shard(fsdp_config=fsdp_config)
interns1_model.multi_modal_projector.fully_shard(fsdp_config=fsdp_config)
interns1_model.fully_shard(fsdp_config=fsdp_config)

interns1_model.from_hf(INTERNS1_DENSE_PATH)
@@ -412,9 +406,6 @@ def test_save_hf(self, device, tp_size):
syncdir = [tmpdir]
dist.broadcast_object_list(syncdir, src=0)
tmpdir = Path(syncdir[0])
interns1_model.language_model.fully_shard(fsdp_config=fsdp_config)
interns1_model.vision_tower.fully_shard(fsdp_config=fsdp_config)
interns1_model.multi_modal_projector.fully_shard(fsdp_config=fsdp_config)
interns1_model.fully_shard(fsdp_config=fsdp_config)
interns1_model.from_hf(INTERNS1_DENSE_PATH)
interns1_model.save_hf(tmpdir)
30 changes: 4 additions & 26 deletions tests/model/test_qwen3_vl.py
@@ -17,10 +17,7 @@
from xtuner.v1.utils.compile import maybe_compile
from xtuner.v1.utils.test_utils import init_data_mesh
from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
from torch.distributed.fsdp import (
MixedPrecisionPolicy,
fully_shard,
)


QWEN3_VL_MOE_PATH = os.environ["QWEN3_VL_MOE_PATH"]
QWEN3_VL_DENSE_PATH = os.environ["QWEN3_VL_DENSE_PATH"]
@@ -117,7 +114,6 @@ def _test_all(self, hf_model, qwen3vl_model, type, device, sp_size, tol):
expected_loss = output.loss
dist.all_reduce(expected_loss.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)

hf_model.to('cpu')
torch.cuda.empty_cache()

loss_cfg = CELossConfig()
@@ -155,7 +151,6 @@ def _test_all(self, hf_model, qwen3vl_model, type, device, sp_size, tol):
seq_ctx=seq_ctx,
loss_ctx=loss_ctx,
)
qwen3vl_model.to('cpu')
torch.cuda.empty_cache()
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -175,7 +170,7 @@ def test_qwen3vl_run(self, device, sp_size, tol):
QWEN3_VL_DENSE_PATH,
dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="cpu"
device_map="cuda"
).eval()
patch_hf_rms_norm(hf_model)

@@ -185,7 +180,6 @@ def test_qwen3vl_run(self, device, sp_size, tol):

qwen3vl_model.from_hf(QWEN3_VL_DENSE_PATH)
qwen3vl_model.eval()
qwen3vl_model.to('cpu')

self._test_all(hf_model, qwen3vl_model, 'text', device, sp_size, tol)
self._test_all(hf_model, qwen3vl_model, 'image', device, sp_size, tol)
@@ -207,7 +201,7 @@ def test_fsdp_qwen3_run(self, device, sp_size, compile, tol):
QWEN3_VL_DENSE_PATH,
dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="cpu"
device_map="cuda"
).eval()
patch_hf_rms_norm(hf_model)

@@ -222,28 +216,15 @@ def test_fsdp_qwen3_run(self, device, sp_size, compile, tol):
torch_compile=compile
)

qwen3vl_model.language_model.fully_shard(fsdp_config=fsdp_config)

# Strangely, once this is enabled, the image and video unit tests fail.
# qwen3vl_model.vision_tower.fully_shard(fsdp_config=fsdp_config)
# Wrapping the whole ViT as one big FSDP module matches exactly; in practice, FSDP-sharding every layer introduces tiny numerical differences in each computation.
fsdp_mesh = init_world_mesh()
mp_policy = MixedPrecisionPolicy(param_dtype=fsdp_config.param_dtype)
fully_shard(
qwen3vl_model.vision_tower,
mesh=fsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=True
)
qwen3vl_model.vision_tower.fsdp_mesh = fsdp_mesh
qwen3vl_model.vision_tower.fsdp_config = fsdp_config

qwen3vl_model.multi_modal_projector.fully_shard(fsdp_config=fsdp_config)
qwen3vl_model.fully_shard(fsdp_config=fsdp_config)

qwen3vl_model.from_hf(QWEN3_VL_DENSE_PATH)
qwen3vl_model.eval()
qwen3vl_model.to('cpu')

self._test_all(hf_model, qwen3vl_model, 'text', device, sp_size, tol)
self._test_all(hf_model, qwen3vl_model, 'image', device, sp_size, tol)
self._test_all(hf_model, qwen3vl_model, 'video', device, sp_size, tol)
@@ -270,9 +251,6 @@ def test_save_hf(self, device, tp_size):
syncdir = [tmpdir]
dist.broadcast_object_list(syncdir, src=0)
tmpdir = Path(syncdir[0])
qwen3vl_model.language_model.fully_shard(fsdp_config=fsdp_config)
qwen3vl_model.vision_tower.fully_shard(fsdp_config=fsdp_config)
qwen3vl_model.multi_modal_projector.fully_shard(fsdp_config=fsdp_config)
qwen3vl_model.fully_shard(fsdp_config=fsdp_config)
qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH)
qwen3vl_model.save_hf(tmpdir)
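The comment above records why the vision tower gets wrapped as a single FSDP unit instead of per layer: sharding every block separately introduced tiny numerical differences that broke the image/video checks. A minimal, self-contained sketch of that single-unit wrapping, assuming the FSDP2 fully_shard API, an already-initialized process group, and a generic nn.Module standing in for the real vision tower:

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

def shard_vision_tower_as_one_unit(vision_tower: nn.Module) -> nn.Module:
    # One flat 1-D mesh over all ranks; the real test built it via init_world_mesh().
    fsdp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
    # A single fully_shard call over the whole tower: all parameters form one FSDP
    # group, so they are gathered and resharded together rather than block by block.
    fully_shard(vision_tower, mesh=fsdp_mesh, mp_policy=mp_policy, reshard_after_forward=True)
    return vision_tower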
4 changes: 2 additions & 2 deletions xtuner/v1/engine/train_engine.py
@@ -174,7 +174,7 @@ def build_model(self) -> BaseModel:
scaling_granularity_gemm=self.model_cfg.float8_cfg.scaling_granularity_gemm,
scaling_granularity_grouped_gemm=self.model_cfg.float8_cfg.scaling_granularity_grouped_gemm,
)
model = model.fully_shard(self.fsdp_cfg, self.float8_handler)
model = model.fully_shard(self.fsdp_cfg)

if dist.get_rank() == 0:
logger.info(model)
@@ -220,7 +220,7 @@ def grad_accumulation_steps(self, data_batches_len: int):

# this method can be called outside, e.g., at the beginning of compute_actor_logprobs or compute_ref_logprobs during rl training
def maybe_precompute_float8_dynamic_scale_for_fsdp(self):
if self.float8_handler is not None and self.float8_handler.enabled:
if self.float8_handler is not None:
self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model)

def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
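With the enabled flag gone from Float8Handler, constructing a handler is itself the opt-in, so a plain None check is the only gate the engine needs. A tiny sketch of that pattern, using placeholder names rather than the engine's real API:

from typing import Optional

class FakeFloat8Handler:
    def precompute_float8_dynamic_scale_for_fsdp(self, model) -> None:
        print(f"precomputing fp8 scales for {type(model).__name__}")

def maybe_precompute(handler: Optional[FakeFloat8Handler], model) -> None:
    if handler is not None:  # construction is the opt-in; there is no .enabled flag anymore
        handler.precompute_float8_dynamic_scale_for_fsdp(model)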
17 changes: 8 additions & 9 deletions xtuner/v1/engine/vision_compose_train_engine.py
@@ -67,23 +67,20 @@ def build_model(self) -> BaseComposeModel: # type: ignore
scaling_granularity_grouped_gemm=self.model_cfg.projector_config.float8_cfg.scaling_granularity_grouped_gemm,
)

model.language_model.fully_shard(self.fsdp_cfg, self.llm_float8_handler)
model.vision_tower.fully_shard(self.fsdp_cfg, self.vision_float8_handler)
model.multi_modal_projector.fully_shard(self.fsdp_cfg, self.projector_float8_handler)
model = model.fully_shard(self.fsdp_cfg)

if dist.get_rank() == 0:
logger.info(model)

if self.llm_float8_handler:
if self.llm_float8_handler is not None:
self.llm_float8_handler.build_reduce_mesh(
model.language_model, cast(DeviceMesh, model.language_model.fsdp_mesh)
)
if self.vision_float8_handler:
if self.vision_float8_handler is not None:
self.vision_float8_handler.build_reduce_mesh(
model.vision_tower, cast(DeviceMesh, model.vision_tower.fsdp_mesh)
)
if self.projector_float8_handler:
if self.projector_float8_handler is not None:
self.projector_float8_handler.build_reduce_mesh(
model.multi_modal_projector, cast(DeviceMesh, model.multi_modal_projector.fsdp_mesh)
)
@@ -100,11 +97,13 @@ def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16):

# this method can be called outside, e.g., at the beginning of compute_actor_logprobs or compute_ref_logprobs during rl training
def maybe_precompute_float8_dynamic_scale_for_fsdp(self):
if self.llm_float8_handler is not None and self.llm_float8_handler.enabled:
if self.llm_float8_handler is not None:
self.llm_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.language_model)
if self.vision_float8_handler is not None and self.vision_float8_handler.enabled:

if self.vision_float8_handler:
self.vision_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.vision_tower)
if self.projector_float8_handler is not None and self.projector_float8_handler.enabled:

if self.projector_float8_handler:
self.projector_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.multi_modal_projector)

def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
24 changes: 3 additions & 21 deletions xtuner/v1/float8/float8_handler.py
@@ -47,7 +47,6 @@ def __init__(
scaling_granularity_gemm: Optional[ScalingGranularity] = None,
scaling_granularity_grouped_gemm: Optional[ScalingGranularity] = None,
) -> None:
self.enabled = False
torch.serialization.add_safe_globals(
[
WeightWithDynamicTilewiseFloat8CastTensor,
@@ -77,7 +76,6 @@ def __init__(
or scaling_granularity_grouped_gemm == ScalingGranularity.TILEWISE
)
self.is_tensorwise_fp8 = scaling_granularity_gemm == ScalingGranularity.TENSORWISE
self.enabled = True

@staticmethod
def get_num_features_after_pad(tensor_size, fsdp_shard_dim, num_chunks):
@@ -107,15 +105,12 @@ def get_num_features_after_pad(tensor_size, fsdp_shard_dim, num_chunks):
break
return chunk_size * num_chunks

def pad_for_fsdp(self, model: nn.Module, fsdp_mesh: DeviceMesh, callback_after_pad: Callable | None = None):
@staticmethod
def pad_for_fsdp(model: nn.Module, fsdp_mesh: DeviceMesh, callback_after_pad: Callable | None = None):
from xtuner.v1.float8.float8_gmm_tile_wise import TileWiseFloat8GroupedLinear
from xtuner.v1.float8.float8_linear_tensor_wise import TensorWiseFloat8Linear
from xtuner.v1.float8.float8_linear_tile_wise import TileWiseFloat8Linear

if not self.enabled:
logger.warning("Float8 training is not enabled.")
return

for module in model.modules():
if isinstance(module, (TileWiseFloat8Linear, TileWiseFloat8GroupedLinear, TensorWiseFloat8Linear)):
# make fsdp compatible with block-wise fp8
@@ -129,18 +124,14 @@ def pad_for_fsdp(self, model: nn.Module, fsdp_mesh: DeviceMesh, callback_after_p
else:
tensor_size = module.weight.size()
parallel_size = 1
padded_out_features = self.get_num_features_after_pad(tensor_size, 0, fsdp_mesh.size(-1))
padded_out_features = Float8Handler.get_num_features_after_pad(tensor_size, 0, fsdp_mesh.size(-1))
padded_out_features *= parallel_size
module.pad_for_fsdp(padded_out_features=padded_out_features)

if callback_after_pad is not None:
callback_after_pad()

def build_reduce_mesh(self, model: nn.Module, fsdp_mesh: DeviceMesh):
if not self.enabled:
logger.warning("Float8 training is not enabled.")
return

self.fsdp_mesh = fsdp_mesh
if self.is_tilewise_fp8:
if fsdp_mesh.size(-1) >= 2:
@@ -151,9 +142,6 @@ def _build_reduce_mesh_devided_64(self, fsdp_mesh: DeviceMesh):
# To support MoE parameters that FSDP and EP split into dout = n * 128 + 64 (n >= 1):
# the last 64 dims on fsdp rank 0 and the first 64 dims on fsdp rank 1 form one block,
# so computing the absmax requires a max-reduce across the two ranks.
if not self.enabled:
logger.warning("Float8 training is not enabled.")
return
if not self.is_tilewise_fp8:
logger.warning("Scaling granularity is not TILEWISE, no need to build reduce group.")
return
@@ -174,9 +162,6 @@ def _build_reduce_mesh_devided_64(self, fsdp_mesh: DeviceMesh):
self.tilewise_reduce_mesh_devided_64 = device_mesh

def _build_reduce_mesh_mapping(self, model: nn.Module, fsdp_mesh: DeviceMesh):
if not self.enabled:
logger.warning("Float8 training is not enabled.")
return
if not self.is_tilewise_fp8:
logger.warning("Scaling granularity is not TILEWISE, no need to build reduce group.")
return
@@ -225,9 +210,6 @@ def _build_reduce_mesh_mapping(self, model: nn.Module, fsdp_mesh: DeviceMesh):
self.tilewise_reduce_mesh_mapping = tilewise_reduce_mesh_mapping

def precompute_float8_dynamic_scale_for_fsdp(self, model: Union[nn.Module, List[nn.Module]]):
if not self.enabled:
return

models = [model] if isinstance(model, nn.Module) else model

for m in models:
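For orientation: the tile-wise fp8 path scales weights over 128-wide blocks along the output dimension, so when FSDP shards dim 0 each rank's chunk should end on a 128-row boundary. pad_for_fsdp pads out_features to make that true, and the reduce meshes above handle the leftover dout = n * 128 + 64 case, where one block straddles two ranks and the absmax must be max-reduced across them. A rough, illustrative reconstruction of the padding arithmetic, assuming a 128 tile size (the real get_num_features_after_pad may search differently):

import math

TILE = 128  # assumed tile size for the tile-wise fp8 blocks

def padded_out_features(out_features: int, num_chunks: int, tile: int = TILE) -> int:
    """Smallest size that splits into num_chunks equal, tile-aligned chunks."""
    chunk_size = math.ceil(out_features / num_chunks)  # rows per FSDP rank before padding
    chunk_size = math.ceil(chunk_size / tile) * tile   # round each chunk up to a tile multiple
    return chunk_size * num_chunks                     # mirrors the `chunk_size * num_chunks` return above

# e.g. out_features = 2112 = 16 * 128 + 64 sharded over 2 ranks:
# per-rank chunk 1056 -> padded chunk 1152 -> padded total 2304
assert padded_out_features(2112, 2) == 2304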
4 changes: 1 addition & 3 deletions xtuner/v1/model/base.py
@@ -28,7 +28,6 @@
from xtuner.v1.config import FSDPConfig, GenerateConfig
from xtuner.v1.data_proto import SequenceContext
from xtuner.v1.float8.config import Float8Config
from xtuner.v1.float8.float8_handler import Float8Handler
from xtuner.v1.float8.fsdp_utils import (
WeightWithDynamicTensorWiseFloat8CastTensor,
WeightWithDynamicTilewiseFloat8CastTensor,
@@ -274,8 +273,7 @@ def trainable_parameters(self):
def fully_shard(
self,
fsdp_config: FSDPConfig,
float8_handler: Float8Handler | None = None,
) -> "BaseModel":
) -> Self:
"""Fully shard the model parameters."""
raise NotImplementedError

10 changes: 6 additions & 4 deletions xtuner/v1/model/compose/base.py
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Callable
from typing import Callable, Self

import torch
import torch.distributed as dist
@@ -15,7 +15,6 @@
from typing_extensions import override

from xtuner.v1.config import FSDPConfig
from xtuner.v1.float8.float8_handler import Float8Handler
from xtuner.v1.model import BaseModel
from xtuner.v1.model.base import XTunerBaseModelConfig
from xtuner.v1.utils import get_device, get_logger
@@ -96,9 +95,12 @@ def init_weights(self) -> None:
def fully_shard(
self,
fsdp_config: FSDPConfig,
float8_handler: Float8Handler | None = None,
):
) -> Self:
self.fsdp_config = fsdp_config
self.language_model.fully_shard(self.fsdp_config)
self.vision_tower.fully_shard(self.fsdp_config)
self.multi_modal_projector.fully_shard(self.fsdp_config)

# TODO: check whether the remaining modules have already been FSDP-sharded

mp_policy = MixedPrecisionPolicy(param_dtype=fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype)
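The compose-level fully_shard now takes over the per-submodule sharding that callers (and the tests above) used to do by hand: shard each child so it becomes its own FSDP unit, then wrap the root, and return self so the call can be chained. A self-contained toy version of that delegation, assuming the FSDP2 fully_shard API and an initialized process group; the real class also tracks fsdp_mesh, activation checkpointing, and float8 state, which are omitted here:

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from typing_extensions import Self

class ToyComposeModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.language_model = nn.Linear(64, 64)
        self.vision_tower = nn.Linear(64, 64)
        self.multi_modal_projector = nn.Linear(64, 64)

    def fully_shard(self, param_dtype: torch.dtype = torch.bfloat16) -> Self:
        mesh = init_device_mesh("cuda", (dist.get_world_size(),))
        mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype)
        # Shard each sub-module first so each becomes its own FSDP unit ...
        for child in (self.language_model, self.vision_tower, self.multi_modal_projector):
            fully_shard(child, mesh=mesh, mp_policy=mp_policy)
        # ... then wrap the root, which picks up any parameters not already covered.
        fully_shard(self, mesh=mesh, mp_policy=mp_policy)
        return self  # returning Self lets callers chain: model = model.fully_shard(...)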
8 changes: 5 additions & 3 deletions xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py
@@ -1,5 +1,5 @@
import types
from typing import cast
from typing import cast, Self

import torch
import torch.distributed as dist
@@ -63,9 +63,11 @@ def __init__(self, config: InternS1BaseConfig):
def fully_shard(
self,
fsdp_config: FSDPConfig,
float8_handler: Float8Handler | None = None,
):
) -> Self:
self.fsdp_config = fsdp_config
self.language_model.fully_shard(self.fsdp_config)
self.vision_tower.fully_shard(self.fsdp_config)
self.multi_modal_projector.fully_shard(self.fsdp_config)
# TODO: check whether the remaining modules have already been FSDP-sharded

# NOTE: for now, checkpoint_wrapper can only be applied at this point
6 changes: 2 additions & 4 deletions xtuner/v1/model/compose/intern_s1/modeling_projector.py
@@ -1,4 +1,4 @@
from typing_extensions import override
from typing_extensions import override, Self
from torch import nn
import torch

@@ -52,10 +52,8 @@ def to_hf_key_list(self, key: str) -> list[str]:
def fully_shard(
self,
fsdp_config: FSDPConfig,
float8_handler: Float8Handler | None = None,
):
) -> Self:
self.fsdp_config = fsdp_config
assert float8_handler is None
mp_policy = MixedPrecisionPolicy(
param_dtype=fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
)
7 changes: 2 additions & 5 deletions xtuner/v1/model/compose/intern_s1/modeling_vision.py
@@ -1,7 +1,7 @@
from functools import partial
from torch import nn
import torch
from typing import Union, Optional, Callable
from typing import Union, Optional, Callable, Self
from typing_extensions import override
import numpy as np

@@ -343,11 +343,8 @@ def to_hf_key_list(self, key: str) -> list[str]:
def fully_shard(
self,
fsdp_config: FSDPConfig,
float8_handler: Float8Handler | None = None,
):
) -> Self:
self.fsdp_config = fsdp_config
assert float8_handler is None

checkpoint_preserve_rng_state = fsdp_config.checkpoint_preserve_rng_state
if not checkpoint_preserve_rng_state and self.config.drop_path_rate > 0.0:
checkpoint_preserve_rng_state = True
2 changes: 0 additions & 2 deletions xtuner/v1/model/compose/qwen3_vl/modeling_projector.py
@@ -84,10 +84,8 @@ def to_hf_key_list(self, key: str) -> list[str]:
def fully_shard(
self,
fsdp_config: FSDPConfig,
float8_handler: Float8Handler | None = None,
):
self.fsdp_config = fsdp_config
assert float8_handler is None
mp_policy = MixedPrecisionPolicy(
param_dtype=fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
)