Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,9 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat

#### Alternative Downloads:

[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)

[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)

[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).

Expand Down
2 changes: 1 addition & 1 deletion api_server/routes/internal/internal_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def is_visible_file(entry: os.DirEntry) -> bool:
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)
return web.json_response([f"{entry.name} [{directory_type}]" for entry in sorted_files], status=200)


def get_app(self):
Expand Down
12 changes: 5 additions & 7 deletions comfy/ldm/ernie/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_ro
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

query, key = query.to(x.dtype), key.to(x.dtype)

q_flat = query.reshape(B, S, -1)
k_flat = key.reshape(B, S, -1)

Expand Down Expand Up @@ -161,16 +159,16 @@ def forward(self, x, rotary_pos_emb, temb, attention_mask=None):

residual = x
x_norm = self.adaLN_sa_ln(x)
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
x_norm = x_norm * (1 + scale_msa) + shift_msa

attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
x = residual + gate_msa * attn_out

residual = x
x_norm = self.adaLN_mlp_ln(x)
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
x_norm = x_norm * (1 + scale_mlp) + shift_mlp

return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
return residual + gate_mlp * self.mlp(x_norm)

class ErnieImageAdaLNContinuous(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
Expand All @@ -183,7 +181,7 @@ def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
x = self.norm(x)
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
return x

class ErnieImageModel(nn.Module):
Expand Down
55 changes: 6 additions & 49 deletions comfy/ldm/lightricks/vae/audio_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
import torch
import torchaudio

import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
Expand Down Expand Up @@ -43,30 +40,6 @@ def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":

return cls(autoencoder=audio_config, vocoder=vocoder_config)


class ModelDeviceManager:
"""Manages device placement and GPU residency for the composed model."""

def __init__(self, module: torch.nn.Module):
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.vae_offload_device()
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)

def ensure_model_loaded(self) -> None:
comfy.model_management.free_memory(
self.patcher.model_size(),
self.patcher.load_device,
)
comfy.model_management.load_model_gpu(self.patcher)

def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self.patcher.load_device)

@property
def load_device(self):
return self.patcher.load_device


class AudioLatentNormalizer:
"""Applies per-channel statistics in patch space and restores original layout."""

Expand Down Expand Up @@ -132,23 +105,17 @@ def waveform_to_mel(
class AudioVAE(torch.nn.Module):
"""High-level Audio VAE wrapper exposing encode and decode entry points."""

def __init__(self, state_dict: dict, metadata: dict):
def __init__(self, metadata: dict):
super().__init__()

component_config = AudioVAEComponentConfig.from_metadata(metadata)

vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)

self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else:
self.vocoder = Vocoder(config=component_config.vocoder)

self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)

autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(
AudioPatchifier(
Expand All @@ -168,18 +135,12 @@ def __init__(self, state_dict: dict, metadata: dict):
n_fft=autoencoder_config["n_fft"],
)

self.device_manager = ModelDeviceManager(self)

def encode(self, audio: dict) -> torch.Tensor:
def encode(self, audio, sample_rate=44100) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors."""

waveform = audio["waveform"]
waveform_sample_rate = audio["sample_rate"]
waveform = audio
waveform_sample_rate = sample_rate
input_device = waveform.device
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()

waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels:
if waveform.shape[1] == 1:
Expand All @@ -190,7 +151,7 @@ def encode(self, audio: dict) -> torch.Tensor:
)

mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device
waveform, waveform_sample_rate, device=waveform.device
)

latents = self.autoencoder.encode(mel_spec)
Expand All @@ -204,17 +165,13 @@ def decode(self, latents: torch.Tensor) -> torch.Tensor:
"""Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape

# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()

latents = self.device_manager.move_to_load_device(latents)
latents = self.normalizer.denormalize(latents)

target_shape = self.target_shape_from_latents(original_shape)
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)

waveform = self.run_vocoder(mel_spec)
return self.device_manager.move_to_load_device(waveform)
return waveform

def target_shape_from_latents(self, latents_shape):
batch, _, time, _ = latents_shape
Expand Down
30 changes: 19 additions & 11 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def forward(self, x, emb):
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue

if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
Expand All @@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x)
return x

Expand Down Expand Up @@ -894,6 +895,12 @@ def _forward(self, x, timesteps=None, context=None, y=None, control=None, transf
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')

if "middle_block_after_patch" in transformer_patches:
patch = transformer_patches["middle_block_after_patch"]
for p in patch:
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
"timesteps": timesteps, "transformer_options": transformer_options})
h = out["h"]

for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
Expand All @@ -905,8 +912,9 @@ def _forward(self, x, timesteps=None, context=None, y=None, control=None, transf
for p in patch:
h, hsp = p(h, hsp, transformer_options)

h = th.cat([h, hsp], dim=1)
del hsp
if hsp is not None:
h = th.cat([h, hsp], dim=1)
del hsp
if len(hs) > 0:
output_shape = hs[-1].shape
else:
Expand Down
Loading
Loading