Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions comfy/ldm/lightricks/av_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,33 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
additional_args["has_spatial_mask"] = has_spatial_mask

ax, a_latent_coords = self.a_patchifier.patchify(ax)

# Inject reference audio for ID-LoRA in-context conditioning
ref_audio = kwargs.get("ref_audio", None)
ref_audio_seq_len = 0
if ref_audio is not None:
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
if ref_tokens.shape[0] < ax.shape[0]:
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
ref_audio_seq_len = ref_tokens.shape[1]
B = ax.shape[0]

# Compute negative temporal positions matching ID-LoRA convention:
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
p = self.a_patchifier
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
time_offset = ref_end[-1].item() + tpl
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_pos = torch.stack([ref_start, ref_end], dim=-1)

additional_args["ref_audio_seq_len"] = ref_audio_seq_len
additional_args["target_audio_seq_len"] = ax.shape[1]
ax = torch.cat([ref_tokens, ax], dim=1)
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)

ax = self.audio_patchify_proj(ax)

# additional_args.update({"av_orig_shape": list(x.shape)})
Expand Down Expand Up @@ -721,6 +748,14 @@ def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):

# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0 and a_timestep is not None:
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
target_len = kwargs.get("target_audio_seq_len")
if a_timestep.dim() <= 1:
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
if a_timestep is not None:
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten()
Expand Down Expand Up @@ -955,6 +990,13 @@ def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]

# Trim reference audio tokens before unpatchification
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0:
ax = ax[:, ref_audio_seq_len:]
if a_embedded_timestep.shape[1] > 1:
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]

# Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand()
Expand Down
4 changes: 4 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,10 @@ def extra_conds(self, **kwargs):
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)

ref_audio = kwargs.get("ref_audio", None)
if ref_audio is not None:
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)

return out

def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
Expand Down
80 changes: 80 additions & 0 deletions comfy_extras/nodes_lt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.samplers
import comfy.utils
import math
import numpy as np
Expand Down Expand Up @@ -682,6 +683,84 @@ def execute(cls, av_latent) -> io.NodeOutput:
return io.NodeOutput(video_latent, audio_latent)


class LTXVReferenceAudio(io.ComfyNode):
    """Node that attaches reference-audio tokens to conditioning and patches the model.

    The reference clip is encoded with the supplied audio VAE, stored under the
    ``ref_audio`` conditioning key on both positive and negative conditioning,
    and a post-CFG hook is installed on a clone of the model that adds an
    identity-guidance term computed from an extra pass without the reference.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node's identity, inputs and outputs."""
        node_inputs = [
            io.Model.Input("model"),
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
            io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
            io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
            io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
            io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
        ]
        node_outputs = [
            io.Model.Output(),
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
        ]
        return io.Schema(
            node_id="LTXVReferenceAudio",
            display_name="LTXV Reference Audio (ID-LoRA)",
            category="conditioning/audio",
            description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
        """Encode the reference clip, tag both conditionings, and patch a model clone.

        Returns a ``NodeOutput`` holding the patched model clone followed by the
        updated positive and negative conditioning.
        """
        # Encode the reference clip, then fold axes 1 and 3 together per step of axis 2.
        # NOTE(review): the 4-D layout is presumably (batch, channels, time, freq) — confirm against the VAE.
        encoded = audio_vae.encode(reference_audio)
        n_batch, n_chan, n_steps, n_freq = encoded.shape
        tokens = encoded.permute(0, 2, 1, 3).reshape(n_batch, n_steps, n_chan * n_freq)
        ref_audio = {"tokens": tokens}

        positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
        negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})

        # Work on a clone so the caller's model object is left unpatched.
        patched = model.clone()
        scale = identity_guidance_scale
        sampling = patched.get_model_object("model_sampling")
        sigma_start = sampling.percent_to_sigma(start_percent)
        sigma_end = sampling.percent_to_sigma(end_percent)

        def without_reference(entry):
            """Return a shallow copy of a cond entry whose model_conds lack ``ref_audio``."""
            conds = entry.get("model_conds", {}).copy()
            conds.pop("ref_audio", None)
            return {**entry, "model_conds": conds}

        def post_cfg_function(args):
            """Add the identity-guidance term to the CFG result when guidance is active."""
            if scale == 0:
                return args["denoised"]

            sigma = args["sigma"]
            current_sigma = sigma[0].item()
            # Guidance only applies while the current sigma lies inside [sigma_end, sigma_start].
            if sigma_start < current_sigma or sigma_end > current_sigma:
                return args["denoised"]

            with_ref_pred = args["cond_denoised"]
            cond_entries = args["cond"]
            guided = args["denoised"]
            options = args["model_options"].copy()
            latent_in = args["input"]

            # Same positive conditioning, minus the reference audio, for the extra pass.
            stripped_cond = [without_reference(entry) for entry in cond_entries]

            [no_ref_pred] = comfy.samplers.calc_cond_batch(
                args["model"], [stripped_cond], latent_in, sigma, options
            )

            return guided + (with_ref_pred - no_ref_pred) * scale

        patched.set_model_sampler_post_cfg_function(post_cfg_function)

        return io.NodeOutput(patched, positive, negative)


class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
Expand All @@ -697,6 +776,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
LTXVCropGuides,
LTXVConcatAVLatent,
LTXVSeparateAVLatent,
LTXVReferenceAudio,
]


Expand Down
2 changes: 1 addition & 1 deletion manager_requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
comfyui_manager==4.1b6
comfyui_manager==4.1b8
Loading