24 changes: 23 additions & 1 deletion comfy/ldm/ace/model.py
@@ -19,6 +19,7 @@
from torch import nn

import comfy.model_management
import comfy.patcher_extension

from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
@@ -343,7 +344,28 @@ def decode(
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output

def forward(
def forward(self,
x,
timestep,
attention_mask=None,
context: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
lyrics_strength=1.0,
**kwargs
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
controlnet_scale, lyrics_strength, **kwargs)

def _forward(
self,
x,
timestep,
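Every model file in this diff gets the same mechanical change: the public `forward` becomes a thin dispatcher that routes through a `WrapperExecutor`, and the original body is renamed `_forward`, so wrappers registered under `WrappersMP.DIFFUSION_MODEL` in `transformer_options` can run code before and after the diffusion model call. Below is a minimal sketch of how a wrapper would hook in, assuming the usual ComfyUI wrapper convention (the wrapper receives the executor as its first argument and must call it to continue the chain) and the `add_wrapper_with_key` registration helper; the toy model, wrapper, and key are illustrative, not part of this PR.

```python
import torch
import comfy.patcher_extension

class ToyModel(torch.nn.Module):
    def forward(self, x, timestep, transformer_options={}, **kwargs):
        # Same dispatch shape as the real models in this diff.
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(
                comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL,
                transformer_options)
        ).execute(x, timestep, **kwargs)

    def _forward(self, x, timestep, **kwargs):
        return x + timestep  # stand-in for the real model body

def halve_input_wrapper(executor, x, timestep, **kwargs):
    # Wrapper convention: adjust the inputs (or outputs), then call the
    # executor to continue to the next wrapper or to _forward itself.
    return executor(x * 0.5, timestep, **kwargs)

transformer_options = {}
comfy.patcher_extension.add_wrapper_with_key(  # assumed registration helper
    comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL,
    "halve_input",  # illustrative key
    halve_input_wrapper,
    transformer_options)

model = ToyModel()
print(model(torch.ones(2), 1.0, transformer_options))  # tensor([1.5000, 1.5000])
```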
8 changes: 8 additions & 0 deletions comfy/ldm/aura/mmdit.py
@@ -9,6 +9,7 @@

from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.patcher_extension
import comfy.ldm.common_dit

def modulate(x, shift, scale):
@@ -436,6 +437,13 @@ def apply_pos_embeds(self, x, h, w):
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

def forward(self, x, timestep, context, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, transformer_options, **kwargs)

def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE
b, c, h, w = x.shape
8 changes: 8 additions & 0 deletions comfy/ldm/chroma/model.py
@@ -5,6 +5,7 @@
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.patcher_extension
import comfy.ldm.common_dit

from comfy.ldm.flux.layers import (
@@ -253,6 +254,13 @@ def block_wrap(args):
return img

def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)

def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

38 changes: 38 additions & 0 deletions comfy/ldm/cosmos/model.py
@@ -27,6 +27,8 @@
from enum import Enum
import logging

import comfy.patcher_extension

from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
@@ -435,6 +437,42 @@ def forward(
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x,
timesteps,
context,
attention_mask,
fps,
image_size,
padding_mask,
scalar_feature,
data_type,
latent_condition,
latent_condition_sigma,
condition_video_augment_sigma,
**kwargs)

def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# crossattn_emb: torch.Tensor,
# crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Args:
17 changes: 16 additions & 1 deletion comfy/ldm/cosmos/predict2.py
@@ -11,6 +11,7 @@
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms

import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention

def apply_rotary_pos_emb(
@@ -805,7 +806,21 @@ def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
)
return x_B_C_Tt_Hp_Wp

def forward(
def forward(self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, fps, padding_mask, **kwargs)

def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
8 changes: 8 additions & 0 deletions comfy/ldm/flux/model.py
@@ -6,6 +6,7 @@
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
import comfy.patcher_extension

from .layers import (
DoubleStreamBlock,
@@ -214,6 +215,13 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0):
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)

def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)

def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size

19 changes: 18 additions & 1 deletion comfy/ldm/hidream/model.py
@@ -13,6 +13,7 @@

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.patcher_extension
import comfy.ldm.common_dit


@@ -692,7 +693,23 @@ def patchify(self, x, max_seq, img_sizes=None):
raise NotImplementedError
return x, x_masks, img_sizes

def forward(
def forward(self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
encoder_hidden_states_llama3=None,
image_cond=None,
control = None,
transformer_options = {},
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)

def _forward(
self,
x: torch.Tensor,
t: torch.Tensor,
8 changes: 8 additions & 0 deletions comfy/ldm/hunyuan3d/model.py
@@ -7,6 +7,7 @@
SingleStreamBlock,
timestep_embedding,
)
import comfy.patcher_extension


class Hunyuan3Dv2(nn.Module):
@@ -67,6 +68,13 @@ def __init__(
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)

def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, transformer_options, **kwargs)

def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
x = x.movedim(-1, -2)
timestep = 1.0 - timestep
txt = context
8 changes: 8 additions & 0 deletions comfy/ldm/hunyuan_video/model.py
@@ -1,6 +1,7 @@
#Based on Flux code because of weird hunyuan video code license.

import torch
import comfy.patcher_extension
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
@@ -348,6 +349,13 @@ def img_ids(self, x):
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)

def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
8 changes: 8 additions & 0 deletions comfy/ldm/lightricks/model.py
@@ -1,5 +1,6 @@
import torch
from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
@@ -420,6 +421,13 @@ def __init__(self,
self.patchifier = SymmetricPatchifier(1)

def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)

def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})

orig_shape = list(x.shape)
10 changes: 9 additions & 1 deletion comfy/ldm/lumina/model.py
@@ -11,6 +11,7 @@
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.patcher_extension


def modulate(x, scale):
@@ -590,8 +591,15 @@ def patchify_and_embed(

return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

# def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)

# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
12 changes: 7 additions & 5 deletions comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -109,7 +109,7 @@ def forward(self, x):
def modulate(x, shift, scale):
if shift is None:
shift = torch.zeros_like(scale)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))


#################################################################################
@@ -564,10 +564,7 @@ def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_
assert not self.pre_only
attn1 = self.attn.post_attention(attn)
attn2 = self.attn2.post_attention(attn2)
out1 = gate_msa.unsqueeze(1) * attn1
out2 = gate_msa2.unsqueeze(1) * attn2
x = x + out1
x = x + out2
x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
x = x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
@@ -594,6 +591,11 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
)
return self.post_attention(attn, *intermediates)

def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
out1 = gate_msa.unsqueeze(1) * attn1
out2 = gate_msa2.unsqueeze(1) * attn2
x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
return x

def block_mixing(*args, use_checkpoint=True, **kwargs):
if use_checkpoint:
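The two `comfy/ldm/modules/diffusionmodules/mmdit.py` changes are pure refactors of existing math: `modulate` now fuses the shift-and-scale into a single `torch.addcmul`, and the two gated residual adds in `post_attention_x` are folded into the new `gate_cat` helper, which stacks the three terms and sums them. A quick sanity check that both rewrites are numerically equivalent to the originals (toy shapes, illustrative only):

```python
import torch

x = torch.randn(2, 5, 8)                             # (batch, tokens, dim)
shift, scale = torch.randn(2, 8), torch.randn(2, 8)
gate_msa, gate_msa2 = torch.randn(2, 8), torch.randn(2, 8)
attn1, attn2 = torch.randn(2, 5, 8), torch.randn(2, 5, 8)

# modulate: old expression vs. the fused addcmul form
old = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
new = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
assert torch.allclose(old, new, atol=1e-6)

# post_attention_x residuals: two separate adds vs. gate_cat's stack-and-sum
old = x + gate_msa.unsqueeze(1) * attn1 + gate_msa2.unsqueeze(1) * attn2
new = torch.stack([x,
                   gate_msa.unsqueeze(1) * attn1,
                   gate_msa2.unsqueeze(1) * attn2], dim=0).sum(dim=0)
assert torch.allclose(old, new, atol=1e-6)
print("both refactors match the original math")
```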