Live demo: http://34.34.91.217:8080
A 2-node Decentralized Diffusion Model (DDM) demo on MNIST flow matching:
- Trainium expert (AWS trn1.2xlarge) — digits 0-4
- GPU expert (NVIDIA T4 or larger) — digits 5-9
- CPU router — a small oracle MLP that picks the right expert per denoising step
Both experts use an identical SmallConditionalUNet (6.5M params, 3-level encoder with channel progression 64 → 128 → 256). Same Euler integration, same denoising-step count on both.
| Expert | Path | Digits | Framework |
|---|---|---|---|
| Trainium | outputs/checkpoints/expert_trn1_0_4_v3/epoch_*.pt | 0,1,2,3,4 | PyTorch + torch_xla |
| GPU | outputs/checkpoints/expert_gpu_5_9_v3/step_*/ | 5,6,7,8,9 | JAX / Flax |
| Router | outputs/router_checkpoints/router_mnist_oracle_2way/step_*/ | oracle | JAX / Flax |
Clone and serve — no training required for the demo to work.
```bash
git clone https://github.com/bageldotcom/trainiumgpu.git
cd trainiumgpu
pip install -e .
```

Then follow docs/DEPLOYMENT_GUIDE.md to stand up the three processes
(Trainium expert, GPU expert, router+UI) on their respective hosts.
Browser ──► Router (CPU host, FastAPI + SSE)
│
├──► Trainium expert (digits 0-4) [HTTP :8000]
└──► GPU expert (digits 5-9) [HTTP :8000]
Each step of the reverse diffusion trajectory is:
1. The router receives the current latent `x_t`, the timestep `t`, and the class label `y`.
2. A learned router MLP (or an oracle, for demo purposes) picks the correct expert for this `y`.
3. The chosen expert returns its velocity prediction `v(x_t, t, y)`.
4. The router advances: `x_t = x_t + dt · v`.
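The per-step loop above can be sketched as follows. This is a minimal illustration, not the repo's actual code: `route` stands in for the router MLP (shown here as the demo's oracle rule), and `call_expert` stands in for the HTTP call to the chosen expert.

```python
import numpy as np

def route(y: int) -> int:
    # Oracle routing rule from the demo: digits 0-4 -> expert 0, 5-9 -> expert 1.
    return 0 if y <= 4 else 1

def call_expert(expert_id: int, x_t: np.ndarray, t: float, y: int) -> np.ndarray:
    # Stand-in for the HTTP expert call; here a dummy velocity field
    # that just pulls the latent toward zero.
    return -x_t

def sample(y: int, num_steps: int = 10, shape=(1, 32, 32)) -> np.ndarray:
    # Euler integration of the reverse trajectory, routing every step.
    x_t = np.random.default_rng(0).standard_normal(shape)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        expert_id = route(y)                    # pick the expert for this label
        v = call_expert(expert_id, x_t, t, y)   # expert's velocity prediction
        x_t = x_t + dt * v                      # Euler step
    return x_t
```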
The wire protocol is a compact binary .npz over HTTP POST.
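A pack/unpack round trip for such a protocol might look like the sketch below; the field names (`x_t`, `t`, `y`) are illustrative assumptions, not the repo's actual schema (see src/trainiumgpu/serving/protocol.py for that).

```python
import io
import numpy as np

def pack_request(x_t: np.ndarray, t: float, y: int) -> bytes:
    # Serialize the latent, timestep, and label into an in-memory .npz
    # blob suitable for use as an HTTP POST body.
    buf = io.BytesIO()
    np.savez_compressed(buf, x_t=x_t, t=np.float32(t), y=np.int64(y))
    return buf.getvalue()

def unpack_request(body: bytes):
    # Deserialize the .npz blob back into Python values on the expert side.
    with np.load(io.BytesIO(body)) as npz:
        return npz["x_t"], float(npz["t"]), int(npz["y"])
```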
SmallConditionalUNet, 6.5M params:
- Sinusoidal time embedding (128-dim) + learned class embedding (128-dim)
- Encoder: 3 levels, channels 64 → 128 → 256 at resolutions 32 → 16 → 8 → 4
- Bottleneck: 2× ResBlock at 256 channels, 4×4 spatial
- Decoder: 3 levels mirroring encoder, with skip concatenations
- Flow-matching objective: MSE(v_pred, x0 − noise)
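The objective above can be written out as a short sketch. The linear interpolation convention used for `x_t` here (pure noise at t=0, clean image at t=1) is an assumption for illustration; under it, the target velocity along the path is x0 − noise.

```python
import numpy as np

def make_training_pair(x0: np.ndarray, noise: np.ndarray, t: float):
    # Assumed linear path: x_t interpolates from noise (t=0) to x0 (t=1).
    x_t = (1.0 - t) * noise + t * x0
    target_v = x0 - noise  # constant velocity along the linear path
    return x_t, target_v

def flow_matching_loss(v_pred: np.ndarray, x0: np.ndarray, noise: np.ndarray) -> float:
    # MSE between the predicted velocity and the flow-matching target.
    return float(np.mean((v_pred - (x0 - noise)) ** 2))
```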
├── src/trainiumgpu/
│ ├── experts/
│ │ ├── model_torch.py # PyTorch UNet (Trainium)
│ │ ├── model.py # Flax UNet (GPU)
│ │ └── train.py # Shared training loop (JAX)
│ ├── router/
│ │ ├── model.py # RouterMLP
│ │ ├── train.py # Router training
│ │ ├── inference.py # Router checkpoint loader
│ │ └── expert_client.py # HTTP client that calls an expert
│ ├── serving/protocol.py # Binary npz wire protocol
│ └── demo/
│ ├── app.py # FastAPI + SSE streaming
│ └── static/ # HTML + JS + world-map SVG
├── scripts/
│ ├── serve_expert_torch_neuron.py # Trainium serve
│ ├── serve_expert_mnist.py # GPU (JAX) serve
│ ├── train_expert_torch_neuron.py # Trainium training
│ ├── train_expert_mnist.py # GPU training
│ ├── train_router_mnist.py
│ └── run_router_demo.py
├── outputs/
│ ├── checkpoints/ # expert checkpoints
│ └── router_checkpoints/ # router checkpoints
└── docs/
├── DEPLOYMENT_GUIDE.md # end-to-end deploy
└── TRAINIUM_TRAINING_GUIDE.md # Neuron compiler pitfalls + fixes
A single T2I / class-conditional model has to learn every subdomain of its training set. A DDM splits the data by semantic cluster and gives each cluster its own expert. At inference time a lightweight router picks the right expert per denoising step. In this MNIST toy:
- Expert 0 only ever sees digits 0-4 during training.
- Expert 1 only ever sees digits 5-9.
- Neither expert can generate a digit outside its half; the router makes sure that's never asked of it.
This scales up to real image datasets where each expert can specialize to
a semantic cluster (faces, landscapes, text, etc.) — see
arxiv.org/abs/2507.05300.