Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 158 additions & 27 deletions comfy/ldm/lightricks/av_model.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions comfy/ldm/lightricks/embeddings_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
d_head,
context_dim=None,
attn_precision=None,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
Expand All @@ -63,6 +64,7 @@ def __init__(
heads=n_heads,
dim_head=d_head,
context_dim=None,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
Expand Down Expand Up @@ -121,6 +123,7 @@ def __init__(
positional_embedding_max_pos=[4096],
causal_temporal_positioning=False,
num_learnable_registers: Optional[int] = 128,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
Expand All @@ -145,6 +148,7 @@ def __init__(
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
Expand Down
186 changes: 158 additions & 28 deletions comfy/ldm/lightricks/model.py

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions comfy/ldm/lightricks/vae/audio_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
CausalityAxis,
CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE

LATENT_DOWNSAMPLE_FACTOR = 4

Expand Down Expand Up @@ -141,7 +141,10 @@ def __init__(self, state_dict: dict, metadata: dict):
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)

self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else:
self.vocoder = Vocoder(config=component_config.vocoder)

self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
Expand Down
63 changes: 27 additions & 36 deletions comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,26 +822,23 @@ def __init__(self, config=None):
super().__init__()

if config is None:
config = self._guess_config()
config = self.get_default_config()

# Extract encoder and decoder configs from the new format
model_config = config.get("model", {}).get("params", {})
variables_config = config.get("variables", {})

self.sampling_rate = variables_config.get(
"sampling_rate",
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
self.sampling_rate = model_config.get(
"sampling_rate", config.get("sampling_rate", 16000)
)
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
decoder_config = model_config.get("decoder", encoder_config)

# Load mel spectrogram parameters
self.mel_bins = encoder_config.get("mel_bins", 64)
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)

# Store causality configuration at VAE level (not just in encoder internals)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT

Expand All @@ -850,44 +847,38 @@ def __init__(self, config=None):

self.per_channel_statistics = processor()

def _guess_config(self):
encoder_config = {
# Required parameters - based on ltx-video-av-1679000 model metadata
def get_default_config(self):
ddconfig = {
"double_z": True,
"mel_bins": 64,
"z_channels": 8,
"resolution": 256,
"downsample_time": False,
"in_channels": 2,
"out_ch": 2,
"ch": 128,
"out_ch": 8,
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
"ch_mult": [1, 2, 4],
"num_res_blocks": 2,
"attn_resolutions": [], # Based on metadata: empty list, no attention
"attn_resolutions": [],
"dropout": 0.0,
"resamp_with_conv": True,
"in_channels": 2, # stereo
"resolution": 256,
"z_channels": 8,
"double_z": True,
"attn_type": "vanilla",
"mid_block_add_attention": False, # Based on metadata: false
"mid_block_add_attention": False,
"norm_type": "pixel",
"causality_axis": "height", # Based on metadata
"mel_bins": 64, # Based on metadata: mel_bins = 64
}

decoder_config = {
# Inherits encoder config, can override specific params
**encoder_config,
"out_ch": 2, # Stereo audio output (2 channels)
"give_pre_end": False,
"tanh_out": False,
"causality_axis": "height",
}

config = {
"_class_name": "CausalAudioAutoencoder",
"sampling_rate": 16000,
"model": {
"params": {
"encoder": encoder_config,
"decoder": decoder_config,
"ddconfig": ddconfig,
"sampling_rate": 16000,
}
},
"preprocessing": {
"stft": {
"filter_length": 1024,
"hop_length": 160,
},
},
}

return config
Expand Down
48 changes: 39 additions & 9 deletions comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

ops = comfy.ops.disable_weight_init

def in_meta_context():
return torch.device("meta") == torch.empty(0).device

def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
Expand Down Expand Up @@ -350,6 +353,10 @@ def __init__(
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_space":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_time":
output_channel = output_channel * block_params.get("multiplier", 1)

self.conv_in = make_conv_nd(
dims,
Expand Down Expand Up @@ -395,17 +402,21 @@ def __init__(
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
Expand Down Expand Up @@ -455,6 +466,15 @@ def __init__(
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
else:
self.register_buffer(
"last_scale_shift_table",
torch.tensor(
[0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(2, output_channel),
persistent=False,
)


# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
Expand Down Expand Up @@ -883,6 +903,15 @@ def __init__(
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
else:
self.register_buffer(
"scale_shift_table",
torch.tensor(
[0.0, 0.0, 0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(4, in_channels),
persistent=False,
)

self.temporal_cache_state={}

Expand Down Expand Up @@ -1012,9 +1041,6 @@ def __init__(self):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))

def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
Expand All @@ -1027,9 +1053,12 @@ def __init__(self, version=0, config=None):
super().__init__()

if config is None:
config = self.guess_config(version)
config = self.get_default_config(version)

self.config = config
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
self.decode_timestep = config.get("decode_timestep", 0.05)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
Expand All @@ -1044,13 +1073,15 @@ def __init__(self, version=0, config=None):
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
base_channels=config.get("encoder_base_channels", 128),
)

self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
base_channels=config.get("decoder_base_channels", 128),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
Expand All @@ -1060,7 +1091,7 @@ def __init__(self, version=0, config=None):

self.per_channel_statistics = processor()

def guess_config(self, version):
def get_default_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
Expand Down Expand Up @@ -1167,8 +1198,7 @@ def encode(self, x):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)

def decode(self, x, timestep=0.05, noise_scale=0.025):
def decode(self, x):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)

x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
Loading
Loading