STMDiT is a diffusion transformer for histopathology image synthesis that conditions on both a slide-level pathology embedding (UNI2-h) and a spot-level transcriptomic profile encoded by a frozen single-cell foundation model (CancerFoundation). The model produces 256x256 H&E patches that match the morphology and gene-expression profile of the conditioning spot, supports dual classifier-free guidance to trade off the two modalities, and generalises to held-out tissues without retraining.
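To make the dual classifier-free guidance idea concrete, here is a minimal sketch of one common way to combine two guidance signals. It is an assumption, not a description of the STMDiT implementation: the nested form (pathology embedding first, gene expression on top), the function name `dual_cfg`, and all argument names are hypothetical.

```python
def dual_cfg(eps_uncond, eps_uni, eps_full, scale_uni, scale_ge):
    """Combine three noise predictions with two guidance scales.

    Hypothetical nested formulation (an assumption, not taken from the repo):
      eps_uncond: prediction with both conditions dropped
      eps_uni:    prediction with only the pathology embedding
      eps_full:   prediction with pathology embedding and gene expression
    Works on floats, NumPy arrays, or torch tensors alike.
    """
    return (
        eps_uncond
        + scale_uni * (eps_uni - eps_uncond)  # push toward the pathology condition
        + scale_ge * (eps_full - eps_uni)     # push further toward gene expression
    )


# Scalar stand-ins for the three predictions, just to show the arithmetic.
guided = dual_cfg(0.0, 1.0, 3.0, scale_uni=4.0, scale_ge=3.0)
```

With both scales set to 1 this form reduces to the fully conditioned prediction, and with `scale_ge=0` it ignores the transcriptomic signal, which is what lets the two scales trade the modalities off against each other.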
All 28 paper-row checkpoints are released as a single Hugging Face umbrella
repo. Each subfolder contains an EMA-only model.pt, the original
training_config.yaml, and a per-model card.
Hugging Face: https://huggingface.co/stmdit-anon/stmdit-checkpoints
Download a single model with huggingface_hub:
```python
from huggingface_hub import snapshot_download

ckpt = snapshot_download(
    repo_id="stmdit-anon/stmdit-checkpoints",
    allow_patterns="adaln-ddpm-p06/*",
)
# ckpt / "adaln-ddpm-p06" / "model.pt"
```

| row_id | Paper label |
|---|---|
| pixcell-b | PixCell-B |
| pixcell-flow-b | PixCell-Flow-B |
| adaln-ddpm-p01 | PixCell-GE-B (p=0.1) |
| adaln-ddpm-p02 | PixCell-GE-B-p02 |
| adaln-ddpm-p03 | PixCell-GE-B-p03 |
| adaln-ddpm-p05 | PixCell-GE-B-p05 |
| adaln-ddpm-p06 | PixCell-GE-B-p06 |
| adaln-flow-p01 | PixCell-Flow-GE-B (p=0.1) |
| adaln-flow-p02 | PixCell-Flow-GE-B-p02 |
| adaln-flow-p03 | PixCell-Flow-GE-B-p03 |
| adaln-flow-p05 | PixCell-Flow-GE-B-p05 |
| xattn-direct-p01 | XAttn-Direct (p=0.1) |
| xattn-gsa-p01 | XAttn-GSA (p=0.1) |
| xattn-perceiver-p01 | XAttn-Perceiver (p=0.1) |
| xattn-pma-p01 | XAttn-PMA (p=0.1) |
| xattn-perceiver-p05 | XAttn-Perceiver-p05 |
| xattn-perceiver-p06 | XAttn-Perceiver-p06 |
| xattn-pma-p05 | XAttn-PMA-p05 |
| xattn-pma-p06 | XAttn-PMA-p06 |
| ptpl-adaln-p05 | PTPL-AdaLN-B (p=0.5) |
| ptpl-adaln-p06 | PTPL-AdaLN-B-p06 |
| ptpl-adaln-p07 | PTPL-AdaLN-B-p07 |
| ptpl-xattn-perceiver-p05 | PTPL-XAttn-Perceiver-B (p=0.5) |
| ptpl-xattn-perceiver-p06 | PTPL-XAttn-Perceiver-B-p06 |
| ptpl-xattn-perceiver-p07 | PTPL-XAttn-Perceiver-B-p07 |
| ptpl-xattn-pma-p05 | PTPL-XAttn-PMA-B (p=0.5) |
| ptpl-xattn-pma-p06 | PTPL-XAttn-PMA-B-p06 |
| ptpl-xattn-pma-p07 | PTPL-XAttn-PMA-B-p07 |
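Since each row_id doubles as a subfolder name in the umbrella repo, a small helper can turn a downloaded snapshot root and a row_id into concrete file paths. The sketch below is illustrative only: `resolve_checkpoint` and `parse_p_suffix` are hypothetical names, not part of the stmdit package, and reading p06 and p07 on the same tenths convention as p01 (p=0.1) and p05 (p=0.5) is an assumption.

```python
from pathlib import Path


def resolve_checkpoint(snapshot_root, row_id):
    """Return the model.pt and training_config.yaml paths for one row_id.

    Hypothetical helper; assumes the layout <snapshot_root>/<row_id>/...
    """
    model_dir = Path(snapshot_root) / row_id
    paths = {
        "model": model_dir / "model.pt",
        "config": model_dir / "training_config.yaml",
    }
    missing = [str(p) for p in paths.values() if not p.is_file()]
    if missing:
        raise FileNotFoundError(f"missing checkpoint files: {missing}")
    return paths


def parse_p_suffix(row_id):
    """Read a trailing '-pNN' suffix as NN / 10; return None if absent.

    Assumes the tenths convention suggested by the table labels.
    """
    suffix = row_id.rsplit("-", 1)[-1]
    if len(suffix) == 3 and suffix[0] == "p" and suffix[1:].isdigit():
        return int(suffix[1:]) / 10
    return None
```

Calling `resolve_checkpoint(ckpt, "adaln-ddpm-p06")` on the snapshot root returned by `snapshot_download` would fail early with a clear message if the subfolder was filtered out by `allow_patterns`.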
```bash
git clone <this-repo>
cd stmdit
pip install -e .
pip install huggingface_hub matplotlib jupyter
```

You also need a local copy of the SD3 VAE (Stability AI) and the
CancerFoundation gene-expression encoder; their paths feed into the
inference pipeline as vae_path and cf_model_dir.
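Because both assets are local and not fetched automatically, it can help to verify them before building the pipeline. The following is a minimal sketch under assumptions: `check_local_assets` is a hypothetical helper, not part of the stmdit package, and it only checks that the paths exist, not that their contents are valid.

```python
from pathlib import Path


def check_local_assets(vae_path, cf_model_dir):
    """Raise FileNotFoundError if either local dependency is missing.

    Hypothetical pre-flight check; it does not inspect file contents.
    """
    problems = []
    if not Path(vae_path).exists():
        problems.append(f"vae_path not found: {vae_path}")
    if not Path(cf_model_dir).is_dir():
        problems.append(f"cf_model_dir is not a directory: {cf_model_dir}")
    if problems:
        raise FileNotFoundError("; ".join(problems))
```

Running this before constructing the pipeline surfaces a wrong path immediately instead of partway through model loading.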
```python
from huggingface_hub import snapshot_download
import numpy as np
import torch
import yaml

from source.inference import InferencePipeline
from source.data.gene_vocab import load_gene_vocab

# Fetch only the adaln-ddpm-p06 subfolder of the umbrella repo.
repo_root = snapshot_download(
    repo_id="stmdit-anon/stmdit-checkpoints",
    allow_patterns="adaln-ddpm-p06/*",
)
model_dir = f"{repo_root}/adaln-ddpm-p06"

# Read the model section of the original training config.
with open(f"{model_dir}/training_config.yaml") as f:
    cfg = yaml.safe_load(f)["model"]

# Load the precomputed features for the demo spots.
feat = np.load("data/demo_features.npz")
uni = torch.from_numpy(feat["uni"]).float()
ge = torch.from_numpy(feat["ge"]).float()
ge_binned = torch.from_numpy(feat["ge_binned"]).long()

pipe = InferencePipeline.from_pretrained(
    checkpoint_path=f"{model_dir}/model.pt",
    vae_path="<path-to-sd3-vae>",
    device="cuda",
    dtype="float32",
    model_type=cfg.get("type", "pixart_ge"),
    variant=cfg.get("variant", "B"),
    ge_encoder_type="cancerfoundation",
    ge_hidden_dim=cfg.get("ge_hidden_dim", 512),
    cf_model_dir="<path-to-cancerfoundation>",
    cf_gene_list=load_gene_vocab(),
    cf_freeze_backbone=True,
)

images = pipe.generate(
    num_samples=4,
    conditioning=uni,
    gene_expression=ge,
    gene_expression_binned=ge_binned,
    guidance_scale=4.0,
    guidance_scale_ge=3.0,
    num_inference_steps=50,
    sampler="ddim",
    seed=42,
)
images[0].save("sample_0.png")
```

A 6-cell notebook at notebooks/demo_inference.ipynb loads the four demo
spots from data/demo_features.npz, downloads adaln-ddpm-p06 from Hugging
Face, and plots the four generated patches in a 2x2 grid.
Apache-2.0. See LICENSE.