Trainium + GPU MNIST DDM Demo

A 2-node Decentralized Diffusion Model (DDM) demo for MNIST flow matching:

  • Trainium expert (AWS trn1.2xlarge) — digits 0-4
  • GPU expert (NVIDIA T4 or larger) — digits 5-9
  • CPU router — a small oracle MLP that picks the right expert per denoising step

Both experts use an identical SmallConditionalUNet architecture (6.5M params; 3-level encoder with channel progression 64 → 128 → 256) and run the same Euler integration with the same denoising-step count.

Pre-trained checkpoints (included)

| Expert   | Path                                                        | Digits    | Framework           |
|----------|-------------------------------------------------------------|-----------|---------------------|
| Trainium | outputs/checkpoints/expert_trn1_0_4_v3/epoch_*.pt           | 0,1,2,3,4 | PyTorch + torch_xla |
| GPU      | outputs/checkpoints/expert_gpu_5_9_v3/step_*/               | 5,6,7,8,9 | JAX / Flax          |
| Router   | outputs/router_checkpoints/router_mnist_oracle_2way/step_*/ | oracle    | JAX / Flax          |

Clone and serve — no training required for the demo to work.

Quick start

```shell
git clone https://github.com/bageldotcom/trainiumgpu.git
cd trainiumgpu
pip install -e .
```

Then follow docs/DEPLOYMENT_GUIDE.md to stand up the three processes (Trainium expert, GPU expert, router+UI) on their respective hosts.

Architecture

```
Browser ──► Router (CPU host, FastAPI + SSE)
                │
                ├──► Trainium expert (digits 0-4)  [HTTP :8000]
                └──► GPU expert      (digits 5-9)  [HTTP :8000]
```

Each step of the reverse diffusion trajectory is:

  1. Router receives current latent x_t, timestep t, and class label y.
  2. A learned router MLP (or oracle for demo purposes) picks the correct expert for this y.
  3. The chosen expert returns its velocity prediction v(x_t, t, y).
  4. Router advances: x_t = x_t + dt · v.
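The loop above can be sketched as follows. This is a minimal illustration, not the repo's actual code: the names `euler_step`, `oracle_route`, and `NUM_STEPS`, and the expert call signature, are all assumptions.

```python
import numpy as np

NUM_STEPS = 50  # illustrative; the real step count is a repo config value

def oracle_route(y: int) -> int:
    """Oracle router: digits 0-4 go to expert 0 (Trainium), 5-9 to expert 1 (GPU)."""
    return 0 if y <= 4 else 1

def euler_step(x_t: np.ndarray, t: float, y: int, experts) -> np.ndarray:
    """One reverse-trajectory step: route on y, get the velocity, advance by dt."""
    v = experts[oracle_route(y)](x_t, t, y)  # chosen expert returns v(x_t, t, y)
    dt = 1.0 / NUM_STEPS
    return x_t + dt * v
```

In the demo the two `experts` entries would be HTTP clients (see `expert_client.py` in the layout below) rather than local callables, but the update rule is the same.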

The wire protocol is a compact binary .npz over HTTP POST.
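A round-trip over that protocol could look like the sketch below, using NumPy's in-memory `.npz` serialization. The function names and exact field set are assumptions, not the repo's `serving/protocol.py`.

```python
import io
import numpy as np

def encode_request(x_t: np.ndarray, t: float, y: int) -> bytes:
    """Pack the step inputs into an in-memory .npz, suitable as an HTTP POST body."""
    buf = io.BytesIO()
    np.savez(buf, x_t=x_t.astype(np.float32), t=np.float32(t), y=np.int64(y))
    return buf.getvalue()

def decode_request(body: bytes):
    """Unpack the .npz body back into arrays on the expert side."""
    arrays = np.load(io.BytesIO(body))
    return arrays["x_t"], float(arrays["t"]), int(arrays["y"])
```

`.npz` keeps the payload compact and framework-neutral, which matters here since one expert is PyTorch and the other is JAX.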

Model

SmallConditionalUNet, 6.5M params:

  • Sinusoidal time embedding (128-dim) + learned class embedding (128-dim)
  • Encoder: 3 levels, channels 64 → 128 → 256 at resolutions 32 → 16 → 8 → 4
  • Bottleneck: 2× ResBlock at 256 channels, 4×4 spatial
  • Decoder: 3 levels mirroring encoder, with skip concatenations
  • Flow-matching objective: MSE(v_pred, x0 − noise)
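The sinusoidal time embedding from the first bullet can be sketched as below (transformer-style sin/cos features; the exact frequency scaling the repo uses is an assumption):

```python
import numpy as np

def sinusoidal_time_embedding(t: float, dim: int = 128) -> np.ndarray:
    """Map a scalar timestep to a dim-vector of sin/cos features."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)  # geometric frequency ladder
    angles = t * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])    # shape (dim,)
```

The class label gets a learned 128-dim embedding instead, since labels are discrete and there is no ordering to encode.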

Repo layout

.
├── src/trainiumgpu/
│   ├── experts/
│   │   ├── model_torch.py        # PyTorch UNet (Trainium)
│   │   ├── model.py              # Flax UNet (GPU)
│   │   └── train.py              # Shared training loop (JAX)
│   ├── router/
│   │   ├── model.py              # RouterMLP
│   │   ├── train.py              # Router training
│   │   ├── inference.py          # Router checkpoint loader
│   │   └── expert_client.py      # HTTP client that calls an expert
│   ├── serving/protocol.py       # Binary npz wire protocol
│   └── demo/
│       ├── app.py                # FastAPI + SSE streaming
│       └── static/               # HTML + JS + world-map SVG
├── scripts/
│   ├── serve_expert_torch_neuron.py   # Trainium serve
│   ├── serve_expert_mnist.py          # GPU (JAX) serve
│   ├── train_expert_torch_neuron.py   # Trainium training
│   ├── train_expert_mnist.py          # GPU training
│   ├── train_router_mnist.py
│   └── run_router_demo.py
├── outputs/
│   ├── checkpoints/                   # expert checkpoints
│   └── router_checkpoints/            # router checkpoints
└── docs/
    ├── DEPLOYMENT_GUIDE.md            # end-to-end deploy
    └── TRAINIUM_TRAINING_GUIDE.md     # Neuron compiler pitfalls + fixes

Why a decentralized diffusion model?

A single T2I / class-conditional model has to learn every subdomain of its training set. A DDM splits the data by semantic cluster and gives each cluster its own expert. At inference time a lightweight router picks the right expert per denoising step. In this MNIST toy:

  • Expert 0 only ever sees digits 0-4 during training.
  • Expert 1 only ever sees digits 5-9.
  • Neither expert can generate a digit outside its half; the router makes sure that's never asked of it.

This scales up to real image datasets where each expert can specialize to a semantic cluster (faces, landscapes, text, etc.) — see arxiv.org/abs/2507.05300.
