Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ venv/
ENV/
env.bak/
venv.bak/
.vscode/

# Spyder project settings
.spyderproject
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ Here is a To-Do list, feel welcome to help to any point along this list. The alr
- [ ] Train our models on toy datasets for different tasks (conditional generation, Image to Image ...)
- [ ] Add possibility to train LORA/DORA
- [x] Add different sampler
- [ ] Add SPRINT (https://arxiv.org/pdf/2510.21986)
- [x] Add SPRINT (https://arxiv.org/pdf/2510.21986)
3 changes: 3 additions & 0 deletions configs/embedder/precomputed.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_target_: diffulab.networks.PrecomputedEmbedder
path_null_embedding: "/path/to/null_embedding.pt"
null_embedding_seq_len: 7
4 changes: 1 addition & 3 deletions configs/model/ddt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ _target_: diffulab.networks.DDT
simple_ddt: true
input_channels: 3
output_channels: 3
input_dim: 512
hidden_dim: 512
context_dim: 512
inner_dim: 512
num_heads: 8
mlp_ratio: 4
patch_size: 2
Expand Down
3 changes: 1 addition & 2 deletions configs/model/dit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ _target_: diffulab.networks.MMDiT
simple_dit: true
input_channels: 3
output_channels: 3
input_dim: 512
hidden_dim: 512
inner_dim: 512
embedding_dim: 512
num_heads: 8
mlp_ratio: 4
Expand Down
16 changes: 16 additions & 0 deletions configs/model/sprint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
_target_: diffulab.networks.SprintDiT
simple_dit: true
input_channels: 3
output_channels: 3
inner_dim: 512
embedding_dim: 512
num_heads: 8
mlp_ratio: 4
patch_size: 2
encoder_depth: 2
deep_layers_depth: 8
decoder_depth: 2
n_classes: 10
classifier_free: false
use_checkpoint: false
drop_rate: 0.75
5 changes: 2 additions & 3 deletions configs/train_imagenet_flow_matching_repa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ trainer:
project_name: imagenet_repa_flow_matching
n_epoch: 150
precision_type: "bf16"
p_classifier_free_guidance: 0.2
p_classifier_free_guidance: 0.1
ema_update_after_step: 0
ema_update_every: 10

model:
input_channels: 32
output_channels: 32
input_dim: 768
hidden_dim: 768
inner_dim: 768
embedding_dim: 256
num_heads: 12
mlp_ratio: 4
Expand Down
6 changes: 2 additions & 4 deletions configs/train_imagenet_repa_txt_to_img.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ trainer:
project_name: txt_to_img_imagenet_repa_flux2
n_epoch: 50
precision_type: "bf16"
p_classifier_free_guidance: 0.2
p_classifier_free_guidance: 0.1
gradient_accumulation_step: 16
val_step_shift: 6.93
compile: false
Expand All @@ -28,9 +28,7 @@ trainer:
model:
input_channels: 128
output_channels: 128
input_dim: 640
hidden_dim: 640
context_dim: 640
inner_dim: 640
num_heads: 10
mlp_ratio: 4
patch_size: 1
Expand Down
62 changes: 62 additions & 0 deletions configs/train_imagenet_repa_txt_to_img_sprint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# configs/train_imagenet_repa_txt_to_img_sprint.yaml
# @package _global_
defaults:
- model: sprint
- diffuser: rectified_flow
- trainer: default
- dataset: imagenet_repa
- dataloader: default
- optimizer: adamw
- vision_tower: flux2
- embedder: precomputed
- _self_

# Override specific settings
trainer:
project_name: txt_to_img_imagenet_multiar_repa_flux2_sprint
n_epoch: 50
precision_type: "bf16"
p_classifier_free_guidance: 0.1
gradient_accumulation_step: 8
val_step_shift: 6.93
compile: true
ema_update_after_step: 0
ema_update_every: 1
ema_rate: 0.9999

model:
input_channels: 128
output_channels: 128
inner_dim: 768
embedding_dim: 768
num_heads: 12
mlp_ratio: 4
patch_size: 1
encoder_depth: 2
deep_layers_depth: 8
n_single_stream_blocks: 8
decoder_depth: 2
classifier_free: true
rope_base: 2000
rope_axes_dim: [16, 24, 24]
n_classes: null
simple_dit: false

diffuser:
extra_args:
logits_normal: true
shift: 4.63

dataloader:
batch_size: 64

# Hydra configuration
hydra:
run:
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
sweep:
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
subdir: ${hydra.job.num}

perceiver_resampler:
use_resampler: false
4 changes: 4 additions & 0 deletions src/diffulab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
PerceiverResampler,
QwenTextEmbedder,
SD3TextEmbedder,
SmolVLMTextEmbedder,
SprintDiT,
UNetModel,
VisionTower,
)
Expand All @@ -32,9 +34,11 @@
"Flux2VAE",
"MMDiT",
"DDT",
"SprintDiT",
"PerceiverResampler",
"QwenTextEmbedder",
"SD3TextEmbedder",
"SmolVLMTextEmbedder",
"UNetModel",
"VisionTower",
"LossFunction",
Expand Down
7 changes: 5 additions & 2 deletions src/diffulab/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .denoisers import DDT, Denoiser, MMDiT, UNetModel
from .embedders import QwenTextEmbedder, SD3TextEmbedder
from .denoisers import DDT, Denoiser, MMDiT, SprintDiT, UNetModel
from .embedders import PrecomputedEmbedder, QwenTextEmbedder, SD3TextEmbedder, SmolVLMTextEmbedder
from .repa import REPA, DinoV2, PerceiverResampler
from .vision_towers import DCAE, Flux2VAE, VisionTower

Expand All @@ -8,8 +8,11 @@
"UNetModel",
"MMDiT",
"DDT",
"SprintDiT",
"SD3TextEmbedder",
"QwenTextEmbedder",
"SmolVLMTextEmbedder",
"PrecomputedEmbedder",
"REPA",
"DinoV2",
"PerceiverResampler",
Expand Down
3 changes: 2 additions & 1 deletion src/diffulab/networks/denoisers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .common import Denoiser
from .ddt import DDT
from .mmdit import MMDiT
from .sprint import SprintDiT
from .unet import UNetModel

__all__ = ["Denoiser", "UNetModel", "MMDiT", "DDT"]
__all__ = ["Denoiser", "UNetModel", "MMDiT", "DDT", "SprintDiT"]
Loading
Loading