Skip to content

Commit 43c64b6

Browse files
Support the LTXAV 2.3 model. (Comfy-Org#12773)
1 parent ac4a943 commit 43c64b6

File tree

10 files changed

+957
-131
lines changed

10 files changed

+957
-131
lines changed

comfy/ldm/lightricks/av_model.py

Lines changed: 158 additions & 27 deletions
Large diffs are not rendered by default.

comfy/ldm/lightricks/embeddings_connector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(
5050
d_head,
5151
context_dim=None,
5252
attn_precision=None,
53+
apply_gated_attention=False,
5354
dtype=None,
5455
device=None,
5556
operations=None,
@@ -63,6 +64,7 @@ def __init__(
6364
heads=n_heads,
6465
dim_head=d_head,
6566
context_dim=None,
67+
apply_gated_attention=apply_gated_attention,
6668
dtype=dtype,
6769
device=device,
6870
operations=operations,
@@ -121,6 +123,7 @@ def __init__(
121123
positional_embedding_max_pos=[4096],
122124
causal_temporal_positioning=False,
123125
num_learnable_registers: Optional[int] = 128,
126+
apply_gated_attention=False,
124127
dtype=None,
125128
device=None,
126129
operations=None,
@@ -145,6 +148,7 @@ def __init__(
145148
num_attention_heads,
146149
attention_head_dim,
147150
context_dim=cross_attention_dim,
151+
apply_gated_attention=apply_gated_attention,
148152
dtype=dtype,
149153
device=device,
150154
operations=operations,

comfy/ldm/lightricks/model.py

Lines changed: 158 additions & 28 deletions
Large diffs are not rendered by default.

comfy/ldm/lightricks/vae/audio_vae.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
CausalityAxis,
1414
CausalAudioAutoencoder,
1515
)
16-
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
16+
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
1717

1818
LATENT_DOWNSAMPLE_FACTOR = 4
1919

@@ -141,7 +141,10 @@ def __init__(self, state_dict: dict, metadata: dict):
141141
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
142142

143143
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
144-
self.vocoder = Vocoder(config=component_config.vocoder)
144+
if "bwe" in component_config.vocoder:
145+
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
146+
else:
147+
self.vocoder = Vocoder(config=component_config.vocoder)
145148

146149
self.autoencoder.load_state_dict(vae_sd, strict=False)
147150
self.vocoder.load_state_dict(vocoder_sd, strict=False)

comfy/ldm/lightricks/vae/causal_audio_autoencoder.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -822,26 +822,23 @@ def __init__(self, config=None):
822822
super().__init__()
823823

824824
if config is None:
825-
config = self._guess_config()
825+
config = self.get_default_config()
826826

827-
# Extract encoder and decoder configs from the new format
828827
model_config = config.get("model", {}).get("params", {})
829-
variables_config = config.get("variables", {})
830828

831-
self.sampling_rate = variables_config.get(
832-
"sampling_rate",
833-
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
829+
self.sampling_rate = model_config.get(
830+
"sampling_rate", config.get("sampling_rate", 16000)
834831
)
835832
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
836833
decoder_config = model_config.get("decoder", encoder_config)
837834

838835
# Load mel spectrogram parameters
839836
self.mel_bins = encoder_config.get("mel_bins", 64)
840-
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
841-
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
837+
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
838+
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
842839

843840
# Store causality configuration at VAE level (not just in encoder internals)
844-
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
841+
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
845842
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
846843
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
847844

@@ -850,44 +847,38 @@ def __init__(self, config=None):
850847

851848
self.per_channel_statistics = processor()
852849

853-
def _guess_config(self):
854-
encoder_config = {
855-
# Required parameters - based on ltx-video-av-1679000 model metadata
850+
def get_default_config(self):
851+
ddconfig = {
852+
"double_z": True,
853+
"mel_bins": 64,
854+
"z_channels": 8,
855+
"resolution": 256,
856+
"downsample_time": False,
857+
"in_channels": 2,
858+
"out_ch": 2,
856859
"ch": 128,
857-
"out_ch": 8,
858-
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
860+
"ch_mult": [1, 2, 4],
859861
"num_res_blocks": 2,
860-
"attn_resolutions": [], # Based on metadata: empty list, no attention
862+
"attn_resolutions": [],
861863
"dropout": 0.0,
862-
"resamp_with_conv": True,
863-
"in_channels": 2, # stereo
864-
"resolution": 256,
865-
"z_channels": 8,
866-
"double_z": True,
867-
"attn_type": "vanilla",
868-
"mid_block_add_attention": False, # Based on metadata: false
864+
"mid_block_add_attention": False,
869865
"norm_type": "pixel",
870-
"causality_axis": "height", # Based on metadata
871-
"mel_bins": 64, # Based on metadata: mel_bins = 64
872-
}
873-
874-
decoder_config = {
875-
# Inherits encoder config, can override specific params
876-
**encoder_config,
877-
"out_ch": 2, # Stereo audio output (2 channels)
878-
"give_pre_end": False,
879-
"tanh_out": False,
866+
"causality_axis": "height",
880867
}
881868

882869
config = {
883-
"_class_name": "CausalAudioAutoencoder",
884-
"sampling_rate": 16000,
885870
"model": {
886871
"params": {
887-
"encoder": encoder_config,
888-
"decoder": decoder_config,
872+
"ddconfig": ddconfig,
873+
"sampling_rate": 16000,
889874
}
890875
},
876+
"preprocessing": {
877+
"stft": {
878+
"filter_length": 1024,
879+
"hop_length": 160,
880+
},
881+
},
891882
}
892883

893884
return config

comfy/ldm/lightricks/vae/causal_video_autoencoder.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
ops = comfy.ops.disable_weight_init
1717

18+
def in_meta_context():
19+
return torch.device("meta") == torch.empty(0).device
20+
1821
def mark_conv3d_ended(module):
1922
tid = threading.get_ident()
2023
for _, m in module.named_modules():
@@ -350,6 +353,10 @@ def __init__(
350353
output_channel = output_channel * block_params.get("multiplier", 2)
351354
if block_name == "compress_all":
352355
output_channel = output_channel * block_params.get("multiplier", 1)
356+
if block_name == "compress_space":
357+
output_channel = output_channel * block_params.get("multiplier", 1)
358+
if block_name == "compress_time":
359+
output_channel = output_channel * block_params.get("multiplier", 1)
353360

354361
self.conv_in = make_conv_nd(
355362
dims,
@@ -395,17 +402,21 @@ def __init__(
395402
spatial_padding_mode=spatial_padding_mode,
396403
)
397404
elif block_name == "compress_time":
405+
output_channel = output_channel // block_params.get("multiplier", 1)
398406
block = DepthToSpaceUpsample(
399407
dims=dims,
400408
in_channels=input_channel,
401409
stride=(2, 1, 1),
410+
out_channels_reduction_factor=block_params.get("multiplier", 1),
402411
spatial_padding_mode=spatial_padding_mode,
403412
)
404413
elif block_name == "compress_space":
414+
output_channel = output_channel // block_params.get("multiplier", 1)
405415
block = DepthToSpaceUpsample(
406416
dims=dims,
407417
in_channels=input_channel,
408418
stride=(1, 2, 2),
419+
out_channels_reduction_factor=block_params.get("multiplier", 1),
409420
spatial_padding_mode=spatial_padding_mode,
410421
)
411422
elif block_name == "compress_all":
@@ -455,6 +466,15 @@ def __init__(
455466
output_channel * 2, 0, operations=ops,
456467
)
457468
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
469+
else:
470+
self.register_buffer(
471+
"last_scale_shift_table",
472+
torch.tensor(
473+
[0.0, 0.0],
474+
device="cpu" if in_meta_context() else None
475+
).unsqueeze(1).expand(2, output_channel),
476+
persistent=False,
477+
)
458478

459479

460480
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@@ -883,6 +903,15 @@ def __init__(
883903
self.scale_shift_table = nn.Parameter(
884904
torch.randn(4, in_channels) / in_channels**0.5
885905
)
906+
else:
907+
self.register_buffer(
908+
"scale_shift_table",
909+
torch.tensor(
910+
[0.0, 0.0, 0.0, 0.0],
911+
device="cpu" if in_meta_context() else None
912+
).unsqueeze(1).expand(4, in_channels),
913+
persistent=False,
914+
)
886915

887916
self.temporal_cache_state={}
888917

@@ -1012,9 +1041,6 @@ def __init__(self):
10121041
super().__init__()
10131042
self.register_buffer("std-of-means", torch.empty(128))
10141043
self.register_buffer("mean-of-means", torch.empty(128))
1015-
self.register_buffer("mean-of-stds", torch.empty(128))
1016-
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
1017-
self.register_buffer("channel", torch.empty(128))
10181044

10191045
def un_normalize(self, x):
10201046
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
@@ -1027,9 +1053,12 @@ def __init__(self, version=0, config=None):
10271053
super().__init__()
10281054

10291055
if config is None:
1030-
config = self.guess_config(version)
1056+
config = self.get_default_config(version)
10311057

1058+
self.config = config
10321059
self.timestep_conditioning = config.get("timestep_conditioning", False)
1060+
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
1061+
self.decode_timestep = config.get("decode_timestep", 0.05)
10331062
double_z = config.get("double_z", True)
10341063
latent_log_var = config.get(
10351064
"latent_log_var", "per_channel" if double_z else "none"
@@ -1044,13 +1073,15 @@ def __init__(self, version=0, config=None):
10441073
latent_log_var=latent_log_var,
10451074
norm_layer=config.get("norm_layer", "group_norm"),
10461075
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
1076+
base_channels=config.get("encoder_base_channels", 128),
10471077
)
10481078

10491079
self.decoder = Decoder(
10501080
dims=config["dims"],
10511081
in_channels=config["latent_channels"],
10521082
out_channels=config.get("out_channels", 3),
10531083
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
1084+
base_channels=config.get("decoder_base_channels", 128),
10541085
patch_size=config.get("patch_size", 1),
10551086
norm_layer=config.get("norm_layer", "group_norm"),
10561087
causal=config.get("causal_decoder", False),
@@ -1060,7 +1091,7 @@ def __init__(self, version=0, config=None):
10601091

10611092
self.per_channel_statistics = processor()
10621093

1063-
def guess_config(self, version):
1094+
def get_default_config(self, version):
10641095
if version == 0:
10651096
config = {
10661097
"_class_name": "CausalVideoAutoencoder",
@@ -1167,8 +1198,7 @@ def encode(self, x):
11671198
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
11681199
return self.per_channel_statistics.normalize(means)
11691200

1170-
def decode(self, x, timestep=0.05, noise_scale=0.025):
1201+
def decode(self, x):
11711202
if self.timestep_conditioning: #TODO: seed
1172-
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
1173-
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
1174-
1203+
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
1204+
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)

0 commit comments

Comments
 (0)