Skip to content

Commit 943b3b6

Browse files
comfyanonymous, kijai, and rattus128
authored
HunyuanVideo 1.5 (Comfy-Org#10819)
* init * update * Update model.py * Update model.py * remove print * Fix text encoding * Prevent empty negative prompt Really doesn't work otherwise * fp16 works * I2V * Update model_base.py * Update nodes_hunyuan.py * Better latent rgb factors * Use the correct sigclip output... * Support HunyuanVideo1.5 SR model * whitespaces... * Proper latent channel count * SR model fixes This also still needs timesteps scheduling based on the noise scale, can be used with two samplers too already * vae_refiner: roll the convolution through temporal Work in progress. Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. * Support HunyuanVideo15 latent resampler * fix * Some cleanup Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> * Proper hyvid15 I2V channels Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> * Fix TokenRefiner for fp16 Otherwise x.sum has infs, just in case only casting if input is fp16, I don't know if necessary. * Bugfix for the HunyuanVideo15 SR model * vae_refiner: roll the convolution through temporal II Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. Added support for encoder, lowered to 1 latent frame to save more VRAM, made work for Hunyuan Image 3.0 (as code shared). Fixed names, cleaned up code. * Allow any number of input frames in VAE. * Better VAE encode mem estimation. * Lowvram fix. * Fix hunyuan image 2.1 refiner. * Fix mistake. * Name changes. * Rename. * Whitespace. * Fix. * Fix. --------- Co-authored-by: kijai <40791699+kijai@users.noreply.github.com> Co-authored-by: Rattus <rattus128@gmail.com>
1 parent 10e90a5 commit 943b3b6

15 files changed

Lines changed: 779 additions & 128 deletions

File tree

comfy/latent_formats.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,66 @@ class HunyuanImage21Refiner(LatentFormat):
611611
latent_dimensions = 3
612612
scale_factor = 1.03682
613613

614+
def process_in(self, latent):
    """Pack a (B, C, T, H, W) latent into the refiner's doubled-channel layout.

    The first temporal frame is duplicated so the frame count becomes even,
    then consecutive frame pairs are folded into the channel dimension,
    halving T and doubling C. Inverse of process_out.
    """
    scaled = latent * self.scale_factor
    # Duplicate frame 0 along the temporal axis (dim 2).
    padded = torch.cat((scaled[:, :, :1], scaled), dim=2)
    # (B, C, T+1, H, W) -> (B, T+1, C, H, W) so frame pairs are adjacent.
    padded = padded.permute(0, 2, 1, 3, 4)
    b, frames, c, h, w = padded.shape
    # Merge each pair of frames into one frame with doubled channels.
    packed = padded.reshape(b, frames // 2, c * 2, h, w)
    return packed.permute(0, 2, 1, 3, 4).contiguous()
622+
623+
def process_out(self, latent):
    """Unpack a doubled-channel latent back to (B, C, T, H, W); inverse of process_in.

    Each frame's channel dimension is split back into two temporal frames,
    and the duplicated first frame added by process_in is dropped.
    """
    z = latent / self.scale_factor
    # (B, 2C, T/2, H, W) -> (B, T/2, 2C, H, W)
    z = z.permute(0, 2, 1, 3, 4)
    b, f, c, h, w = z.shape
    # Split each doubled channel group into two consecutive frames.
    # The original reshape -> identity permute(0,1,2,3,4,5) -> reshape chain
    # is fused into this single equivalent reshape.
    z = z.reshape(b, f * 2, c // 2, h, w)
    z = z.permute(0, 2, 1, 3, 4)
    # Drop the duplicated first frame.
    return z[:, :, 1:]
632+
633+
class HunyuanVideo15(LatentFormat):
    """Latent format for HunyuanVideo 1.5: 32-channel 3D (video) latents."""
    # Per-channel RGB projection factors (32 latent channels x RGB) used to
    # render quick latent previews without running the VAE decoder.
    latent_rgb_factors = [
        [ 0.0568, -0.0521, -0.0131],
        [ 0.0014, 0.0735, 0.0326],
        [ 0.0186, 0.0531, -0.0138],
        [-0.0031, 0.0051, 0.0288],
        [ 0.0110, 0.0556, 0.0432],
        [-0.0041, -0.0023, -0.0485],
        [ 0.0530, 0.0413, 0.0253],
        [ 0.0283, 0.0251, 0.0339],
        [ 0.0277, -0.0372, -0.0093],
        [ 0.0393, 0.0944, 0.1131],
        [ 0.0020, 0.0251, 0.0037],
        [-0.0017, 0.0012, 0.0234],
        [ 0.0468, 0.0436, 0.0203],
        [ 0.0354, 0.0439, -0.0233],
        [ 0.0090, 0.0123, 0.0346],
        [ 0.0382, 0.0029, 0.0217],
        [ 0.0261, -0.0300, 0.0030],
        [-0.0088, -0.0220, -0.0283],
        [-0.0272, -0.0121, -0.0363],
        [-0.0664, -0.0622, 0.0144],
        [ 0.0414, 0.0479, 0.0529],
        [ 0.0355, 0.0612, -0.0247],
        [ 0.0147, 0.0264, 0.0174],
        [ 0.0438, 0.0038, 0.0542],
        [ 0.0431, -0.0573, -0.0033],
        [-0.0162, -0.0211, -0.0406],
        [-0.0487, -0.0295, -0.0393],
        [ 0.0005, -0.0109, 0.0253],
        [ 0.0296, 0.0591, 0.0353],
        [ 0.0119, 0.0181, -0.0306],
        [-0.0085, -0.0362, 0.0229],
        [ 0.0005, -0.0106, 0.0242]
    ]

    # Bias added to the RGB preview projection.
    latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
    latent_channels = 32
    # 3 spatial-temporal latent dimensions (T, H, W) -> video model.
    latent_dimensions = 3
    # NOTE(review): same scale factor as HunyuanImage21Refiner above —
    # presumably shared VAE scaling; confirm against the model release.
    scale_factor = 1.03682
673+
614674
class Hunyuan3Dv2(LatentFormat):
615675
latent_channels = 64
616676
latent_dimensions = 1

comfy/ldm/hunyuan_video/model.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import comfy.ldm.modules.diffusionmodules.mmdit
77
from comfy.ldm.modules.attention import optimized_attention
88

9-
109
from dataclasses import dataclass
1110
from einops import repeat
1211

@@ -42,6 +41,8 @@ class HunyuanVideoParams:
4241
guidance_embed: bool
4342
byt5: bool
4443
meanflow: bool
44+
use_cond_type_embedding: bool
45+
vision_in_dim: int
4546

4647

4748
class SelfAttentionRef(nn.Module):
@@ -157,7 +158,10 @@ def forward(
157158
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
158159
# m = mask.float().unsqueeze(-1)
159160
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
160-
c = x.sum(dim=1) / x.shape[1]
161+
if x.dtype == torch.float16:
162+
c = x.float().sum(dim=1) / x.shape[1]
163+
else:
164+
c = x.sum(dim=1) / x.shape[1]
161165

162166
c = t + self.c_embedder(c.to(x.dtype))
163167
x = self.input_embedder(x)
@@ -196,11 +200,15 @@ class HunyuanVideo(nn.Module):
196200
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
197201
super().__init__()
198202
self.dtype = dtype
203+
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
204+
199205
params = HunyuanVideoParams(**kwargs)
200206
self.params = params
201207
self.patch_size = params.patch_size
202208
self.in_channels = params.in_channels
203209
self.out_channels = params.out_channels
210+
self.use_cond_type_embedding = params.use_cond_type_embedding
211+
self.vision_in_dim = params.vision_in_dim
204212
if params.hidden_size % params.num_heads != 0:
205213
raise ValueError(
206214
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -266,6 +274,18 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
266274
if final_layer:
267275
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
268276

277+
# HunyuanVideo 1.5 specific modules
278+
if self.vision_in_dim is not None:
279+
from comfy.ldm.wan.model import MLPProj
280+
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
281+
else:
282+
self.vision_in = None
283+
if self.use_cond_type_embedding:
284+
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
285+
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
286+
else:
287+
self.cond_type_embedding = None
288+
269289
def forward_orig(
270290
self,
271291
img: Tensor,
@@ -276,6 +296,7 @@ def forward_orig(
276296
timesteps: Tensor,
277297
y: Tensor = None,
278298
txt_byt5=None,
299+
clip_fea=None,
279300
guidance: Tensor = None,
280301
guiding_frame_index=None,
281302
ref_latent=None,
@@ -331,12 +352,31 @@ def forward_orig(
331352

332353
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
333354

355+
if self.cond_type_embedding is not None:
356+
self.cond_type_embedding.to(txt.device)
357+
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
358+
txt = txt + cond_emb.to(txt.dtype)
359+
334360
if self.byt5_in is not None and txt_byt5 is not None:
335361
txt_byt5 = self.byt5_in(txt_byt5)
362+
if self.cond_type_embedding is not None:
363+
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
364+
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
365+
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
366+
else:
367+
txt = torch.cat((txt, txt_byt5), dim=1)
336368
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
337-
txt = torch.cat((txt, txt_byt5), dim=1)
338369
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
339370

371+
if clip_fea is not None:
372+
txt_vision_states = self.vision_in(clip_fea)
373+
if self.cond_type_embedding is not None:
374+
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
375+
txt_vision_states = txt_vision_states + cond_emb
376+
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
377+
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
378+
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
379+
340380
ids = torch.cat((img_ids, txt_ids), dim=1)
341381
pe = self.pe_embedder(ids)
342382

@@ -430,20 +470,20 @@ def img_ids_2d(self, x):
430470
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
431471
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
432472

433-
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
473+
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
434474
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
435475
self._forward,
436476
self,
437477
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
438-
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
478+
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
439479

440-
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
480+
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
441481
bs = x.shape[0]
442482
if len(self.patch_size) == 3:
443483
img_ids = self.img_ids(x)
444484
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
445485
else:
446486
img_ids = self.img_ids_2d(x)
447487
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
448-
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
488+
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
449489
return out
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
5+
import model_management, model_patcher
6+
7+
class SRResidualCausalBlock3D(nn.Module):
    """Residual block: three causal 3D convolutions with SiLU between them."""

    def __init__(self, channels: int):
        super().__init__()
        # Build conv/activation stack with the same Sequential indices as a
        # hand-written conv-silu-conv-silu-conv (state-dict compatible).
        layers = []
        for i in range(3):
            layers.append(VideoConv3d(channels, channels, kernel_size=3))
            if i < 2:
                layers.append(nn.SiLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip connection around the whole conv stack.
        return self.block(x) + x
20+
21+
class SRModel3DV2(nn.Module):
    """Super-resolution model: conv-in, a chain of residual causal 3D blocks,
    conv-out, with an optional global residual when shapes match."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 64,
        num_blocks: int = 6,
        global_residual: bool = False,
    ):
        super().__init__()
        self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
        residual_blocks = [SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)]
        self.blocks = nn.ModuleList(residual_blocks)
        self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
        self.global_residual = bool(global_residual)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.in_conv(x)
        for residual_block in self.blocks:
            h = residual_block(h)
        h = self.out_conv(h)
        # Global skip only applies when in/out shapes coincide.
        if self.global_residual and h.shape == x.shape:
            h = h + x
        return h
45+
46+
47+
class Upsampler(nn.Module):
    """Latent upsampler: conv-in plus a channel-repeat skip, a sequence of
    ResnetBlock stages, then RMS-norm / SiLU / conv-out."""

    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: tuple[int, ...],
        num_res_blocks: int = 2,
    ):
        super().__init__()
        self.num_res_blocks = num_res_blocks
        self.block_out_channels = block_out_channels
        self.z_channels = z_channels

        prev = block_out_channels[0]
        self.conv_in = VideoConv3d(z_channels, prev, kernel_size=3)

        # One stage per entry in block_out_channels; each stage holds
        # num_res_blocks + 1 ResnetBlocks (first one changes channel count).
        self.up = nn.ModuleList()
        for target in block_out_channels:
            stage = nn.Module()
            stage.block = nn.ModuleList(
                [ResnetBlock(in_channels=prev if j == 0 else target,
                             out_channels=target,
                             temb_channels=0,
                             conv_shortcut=False,
                             conv_op=VideoConv3d, norm_op=RMS_norm)
                 for j in range(num_res_blocks + 1)])
            self.up.append(stage)
            prev = target

        self.norm_out = RMS_norm(prev)
        self.conv_out = VideoConv3d(prev, out_channels, kernel_size=3)

    def forward(self, z):
        """Upsample latent z of shape (B, C, T, H, W)."""
        # Skip path: repeat z's channels up to block_out_channels[0]
        # (assumes block_out_channels[0] is a multiple of z_channels).
        repeats = self.block_out_channels[0] // (self.z_channels)
        x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

        for stage in self.up:
            for resnet_block in stage.block:
                x = resnet_block(x)

        return self.conv_out(F.silu(self.norm_out(x)))
96+
97+
# Maps an SR model variant name to the architecture class that implements it.
UPSAMPLERS = {
    "720p": SRModel3DV2,
    "1080p": Upsampler,
}
101+
102+
class HunyuanVideo15SRModel:
    """Loader/runner for a HunyuanVideo 1.5 super-resolution latent model.

    Wraps one of the UPSAMPLERS architectures in a ModelPatcher so it can be
    moved between the VAE compute device and its offload device on demand.
    """

    def __init__(self, model_type, config):
        """
        Args:
            model_type: key into UPSAMPLERS ("720p" or "1080p").
            config: keyword arguments forwarded to the model class constructor.

        Raises:
            ValueError: if model_type is not a known upsampler variant.
        """
        self.load_device = model_management.vae_device()
        offload_device = model_management.vae_offload_device()
        self.dtype = model_management.vae_dtype(self.load_device)
        self.model_class = UPSAMPLERS.get(model_type)
        # Fail fast with a clear message instead of a later TypeError on None.
        if self.model_class is None:
            raise ValueError(
                "Unknown HunyuanVideo 1.5 SR model type: {} (expected one of {})".format(
                    model_type, sorted(UPSAMPLERS)))
        self.model = self.model_class(**config).eval()

        self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        """Load a state dict into the model (strict key matching)."""
        return self.model.load_state_dict(sd, strict=True)

    def get_sd(self):
        """Return the model's state dict."""
        return self.model.state_dict()

    def resample_latent(self, latent):
        """Run the SR model on a latent, loading the model onto the compute device first."""
        model_management.load_model_gpu(self.patcher)
        # NOTE(review): latent is moved to the device but not cast to self.dtype;
        # presumably the patcher handles dtype — confirm against callers.
        return self.model(latent.to(self.load_device))

0 commit comments

Comments
 (0)