Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
553 changes: 553 additions & 0 deletions lightx2v/models/networks/base_model.py

Large diffs are not rendered by default.

227 changes: 38 additions & 189 deletions lightx2v/models/networks/hunyuan_video/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import gc
import glob
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger
from safetensors import safe_open

from lightx2v.models.networks.base_model import BaseTransformerModel
from lightx2v.models.networks.hunyuan_video.infer.feature_caching.transformer_infer import HunyuanTransformerInferTeaCaching, HunyuanVideo15TransformerInferMagCaching
from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer
from lightx2v.models.networks.hunyuan_video.infer.post_infer import HunyuanVideo15PostInfer
Expand All @@ -16,41 +11,17 @@
from lightx2v.models.networks.hunyuan_video.weights.post_weights import HunyuanVideo15PostWeights
from lightx2v.models.networks.hunyuan_video.weights.pre_weights import HunyuanVideo15PreWeights
from lightx2v.models.networks.hunyuan_video.weights.transformer_weights import HunyuanVideo15TransformerWeights
from lightx2v.utils.custom_compiler import CompiledMethodsMixin
from lightx2v.utils.envs import *
from lightx2v.utils.custom_compiler import compiled_method


class HunyuanVideo15Model(BaseTransformerModel):
pre_weight_class = HunyuanVideo15PreWeights
transformer_weight_class = HunyuanVideo15TransformerWeights
post_weight_class = HunyuanVideo15PostWeights

class HunyuanVideo15Model(CompiledMethodsMixin):
def __init__(self, model_path, config, device):
super().__init__()
self.model_path = model_path
self.config = config
self.device = device
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
self.cpu_offload = self.config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.remove_keys = []
super().__init__(model_path, config, device)
self.remove_keys.extend(["byt5_in", "vision_in"])
self.dit_quantized = self.config.get("dit_quantized", False)
if self.dit_quantized:
assert self.config.get("dit_quant_scheme", "Default") in [
"fp8-vllm",
"int8-vllm",
"fp8-q8f",
"int8-q8f",
"fp8-b128-deepgemm",
"fp8-sgl",
"int8-sgl",
"int8-torchao",
"fp8-torchao",
"nvfp4",
"mxfp4",
"mxfp6-mxfp8",
"mxfp8",
]
self._init_infer_class()
self._init_weights()
self._init_infer()
Expand All @@ -67,136 +38,53 @@ def _init_infer_class(self):
else:
raise NotImplementedError

def _init_weights(self):
unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
sensitive_layer = {}
if not self.dit_quantized:
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)

self.original_weight_dict = weight_dict
self.pre_weight = HunyuanVideo15PreWeights(self.config)
self.transformer_weights = HunyuanVideo15TransformerWeights(self.config)
self.post_weight = HunyuanVideo15PostWeights(self.config)
self._apply_weights()

def _apply_weights(self, weight_dict=None):
if weight_dict is not None:
self.original_weight_dict = weight_dict
del weight_dict
gc.collect()
# Load weights into containers
self.pre_weight.load(self.original_weight_dict)
self.transformer_weights.load(self.original_weight_dict)

del self.original_weight_dict
torch.cuda.empty_cache()
gc.collect()

def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
if hasattr(self.transformer_infer, "offload_manager"):
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)

def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.pre_infer.set_scheduler(scheduler)
self.transformer_infer.set_scheduler(scheduler)
self.post_infer.set_scheduler(scheduler)

def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
self._init_offload_manager()

if self.config.get("dit_quantized_ckpt", None):
safetensors_path = self.config["dit_quantized_ckpt"]
else:
safetensors_path = self.model_path

if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
safetensors_path = os.path.dirname(safetensors_path)

weight_dict = {}
for safetensor_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
if self.config["adapter_model_path"] == safetensor_path:
continue
with safe_open(safetensor_path, framework="pt") as f:
logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys():
if any(remove_key in k for remove_key in remove_keys):
continue
if f.get_tensor(k).dtype in [
torch.float16,
torch.bfloat16,
torch.float,
]:
if unified_dtype or all(s not in k for s in sensitive_layer):
weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
else:
weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
else:
weight_dict[k] = f.get_tensor(k).to(self.device)

if self.config.get("dit_quant_scheme", "Default") == "nvfp4":
calib_path = os.path.join(safetensors_path, "calib.pt")
logger.info(f"[CALIB] Loaded calibration data from: {calib_path}")
calib_data = torch.load(calib_path, map_location="cpu")
for k, v in calib_data["absmax"].items():
weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device)
@torch.no_grad()
def _infer_cond_uncond(self, inputs, infer_condition=True):
self.scheduler.infer_condition = infer_condition

return weight_dict
pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)

def _load_ckpt(self, unified_dtype, sensitive_layer):
if self.config.get("dit_original_ckpt", None):
safetensors_path = self.config["dit_original_ckpt"]
else:
safetensors_path = self.config["transformer_model_path"]
if self.config["seq_parallel"]:
pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)

if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)

weight_dict = {}
for file_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
if self.config["adapter_model_path"] == file_path:
continue
logger.info(f"Loading weights from {file_path}")
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
if self.config["seq_parallel"]:
x = self._seq_parallel_post_process(x)

return weight_dict
noise_pred = self.post_infer.infer(x, pre_infer_out)[0]

def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
return noise_pred

if self.device.type != "cpu" and dist.is_initialized():
device = dist.get_rank()
else:
device = str(self.device)
@torch.no_grad()
def _seq_parallel_pre_process(self, pre_infer_out):
seqlen = pre_infer_out.img.shape[1]
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)

with safe_open(file_path, framework="pt", device=device) as f:
return {
key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
for key in f.keys()
if not any(remove_key in key for remove_key in remove_keys)
}
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
pre_infer_out.img = F.pad(pre_infer_out.img, (0, 0, 0, padding_size))

def to_cpu(self):
self.pre_weight.to_cpu()
self.transformer_weights.to_cpu()
pre_infer_out.img = torch.chunk(pre_infer_out.img, world_size, dim=1)[cur_rank]
return pre_infer_out

def to_cuda(self):
self.pre_weight.to_cuda()
self.transformer_weights.to_cuda()
@torch.no_grad()
def _seq_parallel_post_process(self, x):
world_size = dist.get_world_size(self.seq_p_group)
gathered_x = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(gathered_x, x, group=self.seq_p_group)
combined_output = torch.cat(gathered_x, dim=1)
return combined_output

@compiled_method()
@torch.no_grad()
def infer(self, inputs):
if self.cpu_offload:
Expand Down Expand Up @@ -238,42 +126,3 @@ def infer(self, inputs):
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.transformer_weights.non_block_weights_to_cpu()

@torch.no_grad()
def _infer_cond_uncond(self, inputs, infer_condition=True):
self.scheduler.infer_condition = infer_condition

pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)

if self.config["seq_parallel"]:
pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)

x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)

if self.config["seq_parallel"]:
x = self._seq_parallel_post_process(x)

noise_pred = self.post_infer.infer(x, pre_infer_out)[0]

return noise_pred

@torch.no_grad()
def _seq_parallel_pre_process(self, pre_infer_out):
seqlen = pre_infer_out.img.shape[1]
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)

padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
pre_infer_out.img = F.pad(pre_infer_out.img, (0, 0, 0, padding_size))

pre_infer_out.img = torch.chunk(pre_infer_out.img, world_size, dim=1)[cur_rank]
return pre_infer_out

@torch.no_grad()
def _seq_parallel_post_process(self, x):
world_size = dist.get_world_size(self.seq_p_group)
gathered_x = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(gathered_x, x, group=self.seq_p_group)
combined_output = torch.cat(gathered_x, dim=1)
return combined_output
Loading