A diffusion model training framework for fine-tuning image and video models.
Status: pre-alpha 0.055 -- this is early, expect rough edges and breaking changes.
Serenity is a Python tool for training LoRAs, LyCORIS adapters, embeddings, and full fine-tunes across a bunch of diffusion model architectures. It grew out of wanting a single config-driven pipeline that could handle everything from SD 1.5 to video models without switching tools.
It's not polished software yet. The config surface is large (150+ fields), some model support is more battle-tested than others, and the documentation is catching up to the code. But it works, and if you're comfortable reading config files and don't mind the occasional rough edge, it might be useful to you.
These have actual model adapter code in the repo:
- Stable Diffusion 1.5 -- including inpainting
- SDXL 1.0 -- base and inpainting
- Stable Diffusion 3 / 3.5 -- flow-matching
- Flux 1 -- Dev, Schnell, Fill
- Flux 2 -- including Klein 4B and 9B
- Chroma
- Z-Image
- LTX2 -- video
- HunyuanVideo -- video
- Qwen -- image generation and editing
Not all models are equally tested. SD 1.5, SDXL, and Flux get the most use. Some of the newer ones (Chroma, HunyuanVideo) have less mileage on them.
- Builder: Implements all changes
- Bug Fixer: Reviews each phase for regressions
- Skeptic: Verifies parity with Lightricks reference
Audit of Lightricks' official LTX-2 trainer (github.com/Lightricks/LTX-2/packages/ltx-trainer) revealed several gaps in Serenity's LTX2 training path. This prompt fixes them in order of priority.
Key finding: Timestep encoding is NOT a bug — Serenity pre-multiplies sigma * 1000 before the sinusoidal embedding, while Lightricks passes raw sigma and their ltx-core model internally multiplies by timestep_scale_multiplier=1000. Same value reaches the embedding. Verified.
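The equivalence can be sanity-checked with a toy sinusoidal embedding. The 8-dim embedding and the two wrapper functions below are illustrative, not Serenity's or ltx-core's actual modules:

```python
import math

TIMESTEP_SCALE_MULTIPLIER = 1000  # ltx-core's internal timestep_scale_multiplier

def sinusoidal_embedding(t: float, dim: int = 8, max_period: float = 10000.0) -> list[float]:
    """Toy sinusoidal timestep embedding: half sin, half cos."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

def serenity_style(sigma: float) -> list[float]:
    # caller pre-multiplies; the embedding sees the scaled value directly
    return sinusoidal_embedding(sigma * 1000)

def lightricks_style(sigma: float) -> list[float]:
    # model receives raw sigma and scales internally before embedding
    return sinusoidal_embedding(sigma * TIMESTEP_SCALE_MULTIPLIER)

assert serenity_style(0.37) == lightricks_style(0.37)  # identical value reaches the embedding
```

Either way the sinusoidal embedding receives sigma * 1000, so the two pipelines encode timesteps identically.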
Repos: Serenity at /home/alex/serenity, branch experimental/ui.
Problem: Full finetune of the 22B transformer produces a ~46GB state dict. Saving as a single safetensors file is slow and fragile. Resume only handles single files.
Files to modify:
- serenity/checkpoint/ — find the checkpoint save function used by the training loop
- serenity/cli/native_diffusion.py — or wherever checkpoint resume is triggered
Implementation:
from pathlib import Path

# Note: safetensors.torch.save_model does not shard; huggingface_hub's
# save_torch_model handles max_shard_size and writes the index file.
from huggingface_hub import save_torch_model

def save_sharded_checkpoint(model, save_dir: Path, step: int, max_shard_size: str = "5GB"):
    """Save model weights as sharded safetensors with index."""
    save_dir.mkdir(parents=True, exist_ok=True)
    prefix = f"model_weights_step_{step:05d}"
    save_torch_model(
        model,
        save_dir,
        filename_pattern=f"{prefix}{{suffix}}.safetensors",
        max_shard_size=max_shard_size,
    )
    # save_torch_model writes either:
    # - {prefix}.safetensors (if everything fits in one shard)
    # - {prefix}-00001-of-NNNNN.safetensors + {prefix}.safetensors.index.json (if sharded)

Add detection in the checkpoint loading path:
from pathlib import Path
from safetensors.torch import load_file

def load_checkpoint(path: Path):
    if path.is_dir():
        # Check for a sharded index first
        index_files = list(path.glob("*.safetensors.index.json"))
        if index_files:
            return load_sharded_safetensors(str(path))  # already exists in ltx2_checkpoint.py
        # Fallback: look for a single file
        safetensors_files = sorted(path.glob("*step_*.safetensors"))
        if safetensors_files:
            return load_file(str(safetensors_files[-1]))
    elif path.is_file() and path.suffix == ".safetensors":
        return load_file(str(path))
    raise FileNotFoundError(f"No checkpoint found at {path}")

Verification: Save a dummy 1GB model sharded at 256MB, verify 4+ shard files + index.json are created. Load it back and compare state dicts.
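The directory-vs-file branching above can be exercised without real checkpoints. A minimal sketch, using hypothetical filenames that mimic the sharded naming scheme:

```python
import tempfile
from pathlib import Path

def detect_checkpoint_kind(path: Path) -> str:
    """Mirror of load_checkpoint's branching, returning which path would be taken."""
    if path.is_dir():
        if list(path.glob("*.safetensors.index.json")):
            return "sharded"
        if sorted(path.glob("*step_*.safetensors")):
            return "single-in-dir"
    elif path.is_file() and path.suffix == ".safetensors":
        return "single-file"
    return "missing"

with tempfile.TemporaryDirectory() as d:
    ckpt_dir = Path(d)
    # Simulate a sharded save: two shards plus the index file
    (ckpt_dir / "model_weights_step_00100-00001-of-00002.safetensors").touch()
    (ckpt_dir / "model_weights_step_00100-00002-of-00002.safetensors").touch()
    (ckpt_dir / "model_weights_step_00100.safetensors.index.json").touch()
    assert detect_checkpoint_kind(ckpt_dir) == "sharded"
```

The index check must come first: a sharded directory also contains `*step_*.safetensors` files, and the single-file fallback would otherwise load one shard as if it were the whole state dict.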
Problem: LoRA training the 22B model on 24GB requires quantizing the base. Loading the full model to GPU for quantization OOMs. Lightricks solved this with per-block GPU quantization.
File to create: serenity/training/block_quantize.py
Reference: Lightricks/LTX-2/packages/ltx-trainer/src/ltx_trainer/quantization.py
Implementation:
"""Block-by-block quanto quantization for training.
Quantizes transformer blocks one at a time on GPU then moves back to CPU.
Peak VRAM = one transformer block (~1GB) instead of full model (~46GB).
"""
from optimum.quanto import freeze, quantize
# Modules to exclude from quantization
EXCLUDE_PATTERNS = [
"patchify_proj", "audio_patchify_proj",
"proj_out", "audio_proj_out",
"*adaln*", "time_proj", "timestep_embedder*",
"caption_projection*", "audio_caption_projection*",
"*norm*",
]
SKIP_ROOT_MODULES = {
"patchify_proj", "audio_patchify_proj",
"proj_out", "audio_proj_out",
"audio_caption_projection",
}
def quantize_for_training(model, precision="int8-quanto", device="cuda"):
"""Quantize model block-by-block for LoRA training."""
weight_quant = _get_quanto_dtype(precision)
original_device = next(model.parameters()).device
original_dtype = next(model.parameters()).dtype
if hasattr(model, "transformer_blocks"):
# Block-by-block quantization
for i, block in enumerate(model.transformer_blocks):
block.to(device, dtype=original_dtype)
quantize(block, weights=weight_quant, exclude=EXCLUDE_PATTERNS)
freeze(block)
block.to("cpu")
print(f" Quantized block {i+1}/{len(model.transformer_blocks)}")
# Quantize remaining non-block modules
for name, module in model.named_children():
if name == "transformer_blocks" or name in SKIP_ROOT_MODULES:
continue
module.to(device, dtype=original_dtype)
quantize(module, weights=weight_quant, exclude=EXCLUDE_PATTERNS)
freeze(module)
module.to("cpu")
else:
model.to(device)
quantize(model, weights=weight_quant, exclude=EXCLUDE_PATTERNS)
freeze(model)
model.to(original_device)
return modelWire it up: In the LTX2 training config, add quantization: str | None = None. In the training setup, call quantize_for_training when set.
Exclude patterns are from Lightricks — these are the layers that shouldn't be quantized (input/output projections, timestep embeddings, norms). Copy them exactly.
Verification: Load LTX2 transformer to CPU, quantize with this function, confirm VRAM never exceeds 3GB during quantization. Run a single training step with LoRA on the quantized model.
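As a back-of-envelope check on why block-wise quantization fits: the parameter count is the 22B from above, but the block count of 48 is an assumed round number for illustration, not the model's exact value.

```python
def gib(n_bytes: float) -> float:
    """Convert a byte count to GiB."""
    return n_bytes / 2**30

params_total = 22e9          # ~22B parameter transformer
bytes_per_param_bf16 = 2     # BF16 weights
n_blocks = 48                # assumed block count for illustration

full_model = gib(params_total * bytes_per_param_bf16)
one_block = gib(params_total / n_blocks * bytes_per_param_bf16)

print(f"full model on GPU: ~{full_model:.0f} GiB")   # far beyond a 24GB card
print(f"one block at a time: ~{one_block:.1f} GiB")  # comfortably under 3GB headroom
```

This is why the verification above caps VRAM at 3GB: one block plus quanto's temporary buffers should stay well below it.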
Problem: Lightricks V2V strategy supports downscaled reference videos for IC-LoRA. Reference positions are scaled up to match target coordinate space. Serenity doesn't handle this — if ref and target have different spatial dims, RoPE positions will be wrong.
File to modify: serenity/cli/diffusion_losses.py — the family == "ltx2" I2V branch
Implementation:
In the I2V conditioning section (around line 1086+), after computing ref_coords and target_coords:
# Infer reference downscale factor
if ref_height != latent_height or ref_width != latent_width:
    if latent_height % ref_height != 0 or latent_width % ref_width != 0:
        raise ValueError(
            f"Target dims ({latent_height}x{latent_width}) must be exact multiples "
            f"of ref dims ({ref_height}x{ref_width})"
        )
    scale_h = latent_height // ref_height
    scale_w = latent_width // ref_width
    if scale_h != scale_w:
        raise ValueError(f"Non-uniform scale: h={scale_h}, w={scale_w}")
    # Scale ref positions to target coordinate space
    ref_coords = ref_coords.clone()
    ref_coords[:, 1, ...] *= scale_h  # height axis
    ref_coords[:, 2, ...] *= scale_w  # width axis

Also: Save reference_downscale_factor in checkpoint metadata when saving IC-LoRA weights (follow Lightricks pattern).
Verification: Train IC-LoRA with ref at half resolution of target. Verify positions are scaled. Compare with Lightricks output.
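The downscale-factor inference reduces to simple integer arithmetic; a standalone sketch (the function name is illustrative, not Serenity's API):

```python
def infer_ref_scale(latent_h: int, latent_w: int, ref_h: int, ref_w: int) -> int:
    """Infer the uniform downscale factor between reference and target latents."""
    if latent_h % ref_h or latent_w % ref_w:
        raise ValueError("target dims must be exact multiples of ref dims")
    scale_h, scale_w = latent_h // ref_h, latent_w // ref_w
    if scale_h != scale_w:
        raise ValueError(f"non-uniform scale: h={scale_h}, w={scale_w}")
    return scale_h

# Half-resolution reference: a ref position (h=3, w=5) maps to (6, 10) in target space
scale = infer_ref_scale(64, 96, 32, 48)
assert scale == 2
assert (3 * scale, 5 * scale) == (6, 10)
```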
Problem: Serenity's _load_text_encoder in ltx2.py catches all exceptions and silently returns (None, None). On a 3090, loading Gemma3 12B BF16 will OOM (~24GB). Should offer 8-bit fallback instead.
File to modify: serenity/models/ltx2.py — _load_text_encoder method (~line 568)
Implementation:
def _load_text_encoder(
    self,
    model_name_or_path: str,
    dtype: torch.dtype,
    load_in_8bit: bool = False,
) -> tuple[Any, Any]:
    """Load Gemma3 text encoder and tokenizer.

    Args:
        load_in_8bit: Use bitsandbytes 8-bit quantization to fit on 24GB.
    """
    from transformers import AutoTokenizer

    local_only = Path(model_name_or_path).exists()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=True,
        local_files_only=local_only,
    )
    if load_in_8bit:
        try:
            from transformers import BitsAndBytesConfig, Gemma3ForConditionalGeneration
        except ImportError as exc:
            raise ImportError(
                "8-bit text encoder requires bitsandbytes: pip install bitsandbytes"
            ) from exc
        print(f"[ltx2] loading Gemma3 text encoder in 8-bit from {model_name_or_path}")
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
            model_name_or_path,
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            local_files_only=local_only,
        )
    else:
        from transformers import Gemma3ForConditionalGeneration
        print(f"[ltx2] loading Gemma3 text encoder from {model_name_or_path}")
        text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
            model_name_or_path,
            torch_dtype=dtype,
            local_files_only=local_only,
        )
    text_encoder.requires_grad_(False)
    return text_encoder, tokenizer

Also: Thread load_in_8bit from the training config down through load_pipeline → _load_text_encoder. Add text_encoder_8bit: bool = False to the LTX2 training config.
Remove: The suppress(Exception) wrapper on audio VAE and vocoder loading too (lines 341-349, 352-359 in ltx2.py). Replace with specific except (KeyError, RuntimeError) that logs the actual error.
Verification: Load with load_in_8bit=True on a 24GB card — should succeed with ~8GB VRAM for text encoder. Load without 8-bit on a card with <24GB free — should raise a clear OOM error, not silently return None.
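The intended fallback policy can be captured as a small decision helper. This is a sketch, not Serenity code, and the GiB thresholds are rough assumptions rather than measured values:

```python
def pick_text_encoder_mode(free_vram_gib: float,
                           bf16_needed_gib: float = 24.0,
                           int8_needed_gib: float = 12.0) -> str:
    """Decide how to load the Gemma3 text encoder given free VRAM.

    Thresholds are illustrative assumptions: ~24 GiB for BF16 Gemma3 12B,
    roughly half that for 8-bit.
    """
    if free_vram_gib >= bf16_needed_gib:
        return "bf16"
    if free_vram_gib >= int8_needed_gib:
        return "8bit"
    raise RuntimeError(
        f"need ~{int8_needed_gib} GiB free even for 8-bit; "
        f"only {free_vram_gib:.1f} GiB available"
    )

assert pick_text_encoder_mode(40.0) == "bf16"
assert pick_text_encoder_mode(20.0) == "8bit"  # a 3090/4090-class card with headroom used
```

The point is the error behavior: when neither mode fits, the caller gets a loud RuntimeError instead of the current silent (None, None).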
Problem: Serenity's LTXVideoTransformer already has self.gradient_checkpointing and wired up torch.utils.checkpoint.checkpoint calls in the forward pass (confirmed in the code). But the training config doesn't expose a toggle to enable it.
File to modify: Training config for LTX2.
Implementation: Add enable_gradient_checkpointing: bool = False to the LTX2 training config. In the training setup, call:
if config.get("enable_gradient_checkpointing", False):
    pipeline.transformer.gradient_checkpointing = True
    pipeline.transformer.enable_gradient_checkpointing()

This is complementary to Stagehand — use gradient checkpointing for blocks that ARE on GPU while Stagehand swaps the rest.
Verification: Enable gradient checkpointing, run training step, verify peak VRAM is lower than without.
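A rough model of why checkpointing lowers peak activation memory. All numbers here (block count, per-block activation size, input fraction) are assumed for illustration only:

```python
def activation_mem(n_blocks: int, act_per_block_gib: float, checkpointed: bool) -> float:
    """Crude peak activation-memory estimate in GiB.

    Without checkpointing, every block's activations stay live for backward.
    With per-block checkpointing, only block inputs are saved and one block's
    activations exist at a time during recompute (~one extra forward of compute).
    """
    if checkpointed:
        input_frac = 0.1  # assumed: a block's input is ~10% of its full activation set
        return act_per_block_gib + n_blocks * act_per_block_gib * input_frac
    return n_blocks * act_per_block_gib

full = activation_mem(48, 0.5, checkpointed=False)  # 24.0 GiB
ckpt = activation_mem(48, 0.5, checkpointed=True)   # ~2.9 GiB
assert ckpt < full
```

This matches the verification above: with checkpointing enabled, peak VRAM during a training step should drop measurably, at the cost of extra recompute time.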
File: serenity/training/ltx2/trainer.py — currently 19 lines, just a docstring saying "all logic lives in diffusion_losses.py".
Action: Delete the file. Update serenity/training/ltx2/__init__.py to remove any import of it. The docstring information is now captured in the utils.py module docstring and the audit docs.
Verification: grep -r "from serenity.training.ltx2.trainer" serenity/ returns nothing.
Problem: Serenity and Lightricks compute video/audio RoPE coordinates independently. Current audit confirms they're mathematically identical for all three RoPE paths:
- Video self-attn: same scale factors (8, 32, 32), same causal offset formula (+1 - 8), same fps division, same meshgrid
- Audio self-attn: same audio_latent_downsample_factor=4, same causal formula (mel + 1 - 4).clamp(0), same hop_length=160 / sampling_rate=16000
- Cross-attn (a2v/v2a): both slice coords[:, 0:1, :] — temporal only
But there's no automated test to catch future drift. Also: ltx-core orders scale factors as (time, width, height) while Serenity uses (time, height, width) — harmless now since both spatial factors are 32, but a latent bug.
File to create: serenity/tests/test_ltx2_rope_parity.py
Implementation:
"""Verify Serenity's LTX2 RoPE coords match the Lightricks reference implementation."""
import torch
import pytest
# ── Video reference ──
def _lightricks_reference_video_coords(
batch_size, num_frames, height, width,
patch_size_t=1, patch_size=1,
scale_factors=(8, 32, 32), # (time, height, width) — Serenity order
causal_offset=1,
fps=25.0,
device="cpu",
):
"""Reimplementation of ltx-core's get_patch_grid_bounds + get_pixel_coords(causal_fix=True) + fps scaling.
This is a standalone reference — does NOT import ltx-core.
"""
grid_f = torch.arange(0, num_frames, patch_size_t, dtype=torch.float32, device=device)
grid_h = torch.arange(0, height, patch_size, dtype=torch.float32, device=device)
grid_w = torch.arange(0, width, patch_size, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
patch_starts = torch.stack(grid, dim=0) # [3, F, H, W]
patch_size_delta = torch.tensor([patch_size_t, patch_size, patch_size],
dtype=torch.float32, device=device).view(3, 1, 1, 1)
patch_ends = patch_starts + patch_size_delta
latent_coords = torch.stack([patch_starts, patch_ends], dim=-1) # [3, F, H, W, 2]
latent_coords = latent_coords.flatten(1, 3).unsqueeze(0).expand(batch_size, -1, -1, -1)
# Scale to pixel space
scale_tensor = torch.tensor(scale_factors, device=device).view(1, -1, 1, 1)
pixel_coords = latent_coords * scale_tensor
# Causal fix: temporal axis
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + causal_offset - scale_factors[0]).clamp(min=0)
# FPS scaling
pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps
return pixel_coords
# ── Audio reference ──
def _lightricks_reference_audio_coords(
batch_size, num_frames,
audio_latent_downsample_factor=4,
causal_offset=1,
hop_length=160,
sampling_rate=16000,
patch_size_t=1,
shift=0,
device="cpu",
):
"""Reimplementation of ltx-core AudioPatchifier._compute_audio_timings.
Standalone reference — does NOT import ltx-core.
Returns [B, 1, num_patches, 2] in seconds.
"""
def _latent_time_in_sec(start, end):
frame = torch.arange(start, end, dtype=torch.float32, device=device)
mel = frame * audio_latent_downsample_factor
mel = (mel + causal_offset - audio_latent_downsample_factor).clamp(min=0)
return mel * hop_length / sampling_rate
start_s = _latent_time_in_sec(shift, num_frames + shift)
end_s = _latent_time_in_sec(shift + patch_size_t, num_frames + shift + patch_size_t)
start_s = start_s.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) # [B, 1, T]
end_s = end_s.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) # [B, 1, T]
return torch.stack([start_s, end_s], dim=-1) # [B, 1, T, 2]
# ── Video tests ──
@pytest.mark.parametrize("num_frames,height,width,fps", [
(5, 30, 40, 25.0), # typical small video
(1, 64, 64, 25.0), # single frame (image)
(17, 48, 64, 24.0), # odd frame count
(33, 30, 52, 30.0), # non-standard fps
])
def test_video_coords_parity(num_frames, height, width, fps):
"""Serenity's prepare_video_coords must match the Lightricks reference."""
from serenity.models.ltx2_transformer import LTXRotaryEmbedding
batch_size = 2
rope = LTXRotaryEmbedding(
dim=4096,
modality="video",
patch_size=1,
patch_size_t=1,
scale_factors=(8, 32, 32),
causal_offset=1,
)
serenity_coords = rope.prepare_video_coords(
batch_size, num_frames, height, width, device="cpu", fps=fps,
)
reference_coords = _lightricks_reference_video_coords(
batch_size, num_frames, height, width, fps=fps,
)
assert serenity_coords.shape == reference_coords.shape, (
f"Shape mismatch: {serenity_coords.shape} vs {reference_coords.shape}"
)
torch.testing.assert_close(serenity_coords, reference_coords, atol=1e-6, rtol=1e-6)
def test_single_frame_causal_clamp():
"""First frame temporal coord must clamp to 0, not go negative."""
from serenity.models.ltx2_transformer import LTXRotaryEmbedding
rope = LTXRotaryEmbedding(
dim=4096, modality="video", patch_size=1, patch_size_t=1,
scale_factors=(8, 32, 32), causal_offset=1,
)
coords = rope.prepare_video_coords(1, 1, 16, 16, device="cpu", fps=25.0)
assert (coords[:, 0, :, :] >= 0).all()
# ── Audio tests ──
@pytest.mark.parametrize("num_frames,shift", [
(25, 0), # 1 second of audio at 25 latent fps
(1, 0), # single audio frame
(50, 0), # 2 seconds
(25, 10), # with shift (continuation)
])
def test_audio_coords_parity(num_frames, shift):
"""Serenity's prepare_audio_coords must match the Lightricks reference."""
from serenity.models.ltx2_transformer import LTXRotaryEmbedding
batch_size = 2
rope = LTXRotaryEmbedding(
dim=1024,
modality="audio",
patch_size=1,
patch_size_t=1,
scale_factors=[4],
causal_offset=1,
sampling_rate=16000,
hop_length=160,
)
serenity_coords = rope.prepare_audio_coords(
batch_size, num_frames, device="cpu", shift=shift,
)
reference_coords = _lightricks_reference_audio_coords(
batch_size, num_frames, shift=shift,
)
assert serenity_coords.shape == reference_coords.shape, (
f"Shape mismatch: {serenity_coords.shape} vs {reference_coords.shape}"
)
torch.testing.assert_close(serenity_coords, reference_coords, atol=1e-6, rtol=1e-6)
def test_audio_first_frame_causal_clamp():
"""First audio frame temporal coord must clamp to 0."""
from serenity.models.ltx2_transformer import LTXRotaryEmbedding
rope = LTXRotaryEmbedding(
dim=1024, modality="audio", patch_size=1, patch_size_t=1,
scale_factors=[4], causal_offset=1, sampling_rate=16000, hop_length=160,
)
coords = rope.prepare_audio_coords(1, 1, device="cpu")
assert (coords >= 0).all()
# ── Cross-attention RoPE tests ──
def test_cross_attn_uses_temporal_only():
"""Cross-attention RoPE should use only temporal dimension of coords.
Both Serenity and Lightricks slice video_coords[:, 0:1, :] for cross-attn.
Verify the slice produces 1D temporal coords, not full 3D.
"""
from serenity.models.ltx2_transformer import LTXRotaryEmbedding
rope = LTXRotaryEmbedding(
dim=4096, modality="video", patch_size=1, patch_size_t=1,
scale_factors=(8, 32, 32), causal_offset=1,
)
coords = rope.prepare_video_coords(1, 5, 16, 16, device="cpu", fps=25.0)
# Full coords: [B, 3, num_patches, 2]
assert coords.shape[1] == 3
# Cross-attn slice: [B, 1, num_patches, 2] — temporal only
ca_coords = coords[:, 0:1, :]
assert ca_coords.shape[1] == 1
def test_rope_output_parity():
"""Full RoPE (cos, sin) output must be deterministic and match between runs."""
from serenity.models.ltx2_transformer import LTXRotaryEmbedding
rope = LTXRotaryEmbedding(
dim=4096, modality="video", patch_size=1, patch_size_t=1,
scale_factors=(8, 32, 32), causal_offset=1,
)
coords = rope.prepare_video_coords(1, 5, 16, 16, device="cpu", fps=25.0)
cos1, sin1 = rope(coords, device="cpu")
cos2, sin2 = rope(coords, device="cpu")
torch.testing.assert_close(cos1, cos2)
torch.testing.assert_close(sin1, sin2)Key thing: Both _lightricks_reference_video_coords and _lightricks_reference_audio_coords are standalone reimplementations of Lightricks' logic (no ltx-core import). If Serenity's output ever diverges from these, the tests catch it. Audio parity covers the causal offset, hop_length/sampling_rate conversion, and the shift parameter for continuation segments.
Verification: pytest serenity/tests/test_ltx2_rope_parity.py -v — all 12+ tests pass.
After each phase, run:
cd /home/alex/serenity
python -c "from serenity.models.ltx2 import LTX2Model; print('LTX2Model imports OK')"
python -m pytest serenity/tests/test_ltx2_*.py -x -v 2>&1 | tail -20

- torch.compile per-block for training — good optimization but not correctness-critical. Save for a perf pass.
- EmbeddingsProcessor vs process_text_through_connectors — functionally equivalent based on code review. Verify numerically if results seem off.
- ltx-core scale factor ordering — (time, width, height) vs Serenity's (time, height, width). Both spatial factors are 32, so it's a no-op. Log as known divergence; don't fix unless asymmetric VAE scaling appears.

- LoRA -- configurable rank, alpha, layer filtering
- LyCORIS -- LoKR, LoHa, Tucker decomposition, DoRA, RS-LoRA
- Embedding / Textual Inversion -- SD 1.5 and SDXL only
- Full Fine-Tune -- needs 24 GB+ VRAM
- VAE Fine-Tune
- 45+ optimizers (standard, 8-bit, adaptive LR, schedule-free, research)
- Multiple LR schedulers, loss functions, noise strategies
- Aspect-ratio bucketing and disk-based latent/text-encoder caching
- Gradient checkpointing, layer offloading, quantization (NF4, INT8, FP8)
- DDP multi-GPU support
- EMA weight averaging
- Mixed precision with stochastic rounding
- Sample generation during training
- YAML and JSON configs with automatic enum coercion
- Checkpoint saving in SafeTensors, Diffusers, CKPT, ComfyUI formats
- Resume from backup
git clone https://github.com/CodeAlexx/Serenity.git
cd Serenity
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt

Needs Python 3.10+, PyTorch 2.0+ with CUDA, and an NVIDIA GPU (8 GB minimum for LoRA, 24 GB+ for full fine-tune).
# config.yaml
model_type: sdxl
training_method: lora
transformer_path: "stabilityai/stable-diffusion-xl-base-1.0"
output_dir: "output/"
output_model_destination: "output/my_lora.safetensors"
learning_rate: 1.0e-4
epochs: 10
batch_size: 1
resolution: "1024"
train_dtype: BFLOAT_16
lora_rank: 16
lora_alpha: 1.0
optimizer:
  optimizer: ADAMW
concepts:
  - name: "my_subject"
    path: "training_data/my_subject"
    prompt:
      source: "txt"

python -m serenity.cli.commands train config.yaml

See docs/USER_GUIDE.md for the full configuration reference, docs/FEATURES.md for the complete feature reference, and docs/ERIQUANT.md for the quantization guide.
serenity/
├── core/ # Config, interfaces, trainer loop
├── models/ # Model adapters (30+)
├── training/ # Optimizers, schedulers, losses, EMA, precision, distributed
├── pipeline/ # Dataset loading, bucketing, caching, augmentations, masks
├── data/ # Dataloader
├── checkpoint/ # Save, load, resume, format conversion
├── memory/ # Layer offloading, memory prediction
├── stagehand/ # Block-swapping GPU memory orchestration
├── sampling/ # Sample generation during training
├── adapters/ # LoRA, LyCORIS adapter layer
├── presets/ # VRAM-tier preset configs
├── cli/ # Command-line interface
├── ui/ # Desktop TUI application
├── utils/ # Utilities
└── tests/ # Tests (1900+)
eriquant/ # EriQuant quantization toolkit
docs/ # Documentation
| Model | LoRA | Full FT | Embedding | Inpainting |
|---|---|---|---|---|
| SD 1.5 | Yes | Yes | Yes | Yes |
| SDXL 1.0 | Yes | Yes | Yes | Yes |
| SD 3 / 3.5 | Yes | Yes | -- | -- |
| Flux 1 | Yes | Yes | -- | Yes (Fill) |
| Flux 2 | Yes | Yes | -- | -- |
| Flux 2 Klein | Yes | Yes | -- | -- |
| Chroma | Yes | Yes | -- | -- |
| Z-Image | Yes | Yes | -- | -- |
| LTX2 | Yes | Yes | -- | -- |
| HunyuanVideo | Yes | Yes | -- | -- |
| Qwen | Yes | Yes | -- | -- |
- Pre-alpha software. APIs and config schema will change.
- No GUI -- CLI and config files only.
- AMD/ROCm support is untested and probably broken.
- Some model/training-method combinations may have bugs that haven't been found yet.
- Documentation may lag behind actual code behavior.
Contributions welcome. This is pre-alpha, so expect things to move around. Open an issue before starting large changes so we can discuss the approach.
See LICENSE file.
Built on PyTorch, Diffusers, StageHand, SquaredQ, SerenityBoard, LyCORIS, and many other open-source libraries.