ratschlab/stmdit

STMDiT — Transcriptomics-Conditioned Virtual Tissue Synthesis

STMDiT is a diffusion transformer for histopathology image synthesis. It conditions on both a slide-level pathology embedding (UNI2-h) and a spot-level transcriptomic profile encoded by a frozen single-cell foundation model (CancerFoundation). The model produces 256x256 H&E patches that match the morphology and gene-expression profile of the conditioning spot, supports dual classifier-free guidance to trade off the two modalities, and generalises to held-out tissues without retraining.

Pretrained models

All 28 paper-row checkpoints are released as a single Hugging Face umbrella repo. Each subfolder contains an EMA-only model.pt, the original training_config.yaml, and a per-model card.

Hugging Face: https://huggingface.co/stmdit-anon/stmdit-checkpoints

Download a single model with huggingface_hub:

from huggingface_hub import snapshot_download
ckpt = snapshot_download(
    repo_id="stmdit-anon/stmdit-checkpoints",
    allow_patterns="adaln-ddpm-p06/*",
)
# snapshot_download returns a str; the weights are at Path(ckpt) / "adaln-ddpm-p06" / "model.pt"

row_id                      Paper label
--------------------------  --------------------------
pixcell-b                   PixCell-B
pixcell-flow-b              PixCell-Flow-B
adaln-ddpm-p01              PixCell-GE-B (p=0.1)
adaln-ddpm-p02              PixCell-GE-B-p02
adaln-ddpm-p03              PixCell-GE-B-p03
adaln-ddpm-p05              PixCell-GE-B-p05
adaln-ddpm-p06              PixCell-GE-B-p06
adaln-flow-p01              PixCell-Flow-GE-B (p=0.1)
adaln-flow-p02              PixCell-Flow-GE-B-p02
adaln-flow-p03              PixCell-Flow-GE-B-p03
adaln-flow-p05              PixCell-Flow-GE-B-p05
xattn-direct-p01            XAttn-Direct (p=0.1)
xattn-gsa-p01               XAttn-GSA (p=0.1)
xattn-perceiver-p01         XAttn-Perceiver (p=0.1)
xattn-pma-p01               XAttn-PMA (p=0.1)
xattn-perceiver-p05         XAttn-Perceiver-p05
xattn-perceiver-p06         XAttn-Perceiver-p06
xattn-pma-p05               XAttn-PMA-p05
xattn-pma-p06               XAttn-PMA-p06
ptpl-adaln-p05              PTPL-AdaLN-B (p=0.5)
ptpl-adaln-p06              PTPL-AdaLN-B-p06
ptpl-adaln-p07              PTPL-AdaLN-B-p07
ptpl-xattn-perceiver-p05    PTPL-XAttn-Perceiver-B (p=0.5)
ptpl-xattn-perceiver-p06    PTPL-XAttn-Perceiver-B-p06
ptpl-xattn-perceiver-p07    PTPL-XAttn-Perceiver-B-p07
ptpl-xattn-pma-p05          PTPL-XAttn-PMA-B (p=0.5)
ptpl-xattn-pma-p06          PTPL-XAttn-PMA-B-p06
ptpl-xattn-pma-p07          PTPL-XAttn-PMA-B-p07
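
To fetch several models in one call, the single-model download can be generalised as in the sketch below. The helper names (allow_patterns_for, model_files, download_models) are hypothetical and not part of this repository. The sketch assumes only the subfolder layout described above (model.pt and training_config.yaml under each row_id) and that snapshot_download accepts a list of glob patterns.

```python
from pathlib import Path

REPO_ID = "stmdit-anon/stmdit-checkpoints"


def allow_patterns_for(row_ids):
    """Build one glob per requested subfolder, e.g. 'adaln-ddpm-p06/*'."""
    return [f"{row_id}/*" for row_id in row_ids]


def model_files(snapshot_root, row_id):
    """Locate the weights and config for one row_id inside a local snapshot.

    Assumes the layout described in this README: <root>/<row_id>/model.pt
    and <root>/<row_id>/training_config.yaml.
    """
    root = Path(snapshot_root) / row_id
    return {
        "weights": root / "model.pt",
        "config": root / "training_config.yaml",
    }


def download_models(row_ids):
    """Download the requested subfolders and return their file paths (hypothetical helper)."""
    # Imported lazily so the path helpers above work without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    root = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=allow_patterns_for(row_ids),
    )
    return {row_id: model_files(root, row_id) for row_id in row_ids}
```

For example, download_models(["adaln-ddpm-p06", "xattn-pma-p05"]) would fetch both subfolders and return a dict keyed by row_id.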

Setup

git clone <this-repo>
cd stmdit
pip install -e .
pip install huggingface_hub matplotlib jupyter

You also need local copies of the SD3 VAE (Stability AI) and the CancerFoundation gene-expression encoder. Pass their locations to the inference pipeline as vae_path and cf_model_dir.

Quickstart

from huggingface_hub import snapshot_download
import numpy as np
import torch
import yaml

from source.inference import InferencePipeline
from source.data.gene_vocab import load_gene_vocab

# Download only the adaln-ddpm-p06 subfolder; the return value is the local snapshot root.
repo_root = snapshot_download(
    repo_id="stmdit-anon/stmdit-checkpoints",
    allow_patterns="adaln-ddpm-p06/*",
)
model_dir = f"{repo_root}/adaln-ddpm-p06"

# Read the "model" section of the training config shipped with the checkpoint.
with open(f"{model_dir}/training_config.yaml") as f:
    cfg = yaml.safe_load(f)["model"]

# Demo conditioning features: UNI embeddings, gene expression, and binned gene expression.
feat = np.load("data/demo_features.npz")
uni = torch.from_numpy(feat["uni"]).float()
ge = torch.from_numpy(feat["ge"]).float()
ge_binned = torch.from_numpy(feat["ge_binned"]).long()

pipe = InferencePipeline.from_pretrained(
    checkpoint_path=f"{model_dir}/model.pt",
    vae_path="<path-to-sd3-vae>",
    device="cuda",
    dtype="float32",
    model_type=cfg.get("type", "pixart_ge"),
    variant=cfg.get("variant", "B"),
    ge_encoder_type="cancerfoundation",
    ge_hidden_dim=cfg.get("ge_hidden_dim", 512),
    cf_model_dir="<path-to-cancerfoundation>",
    cf_gene_list=load_gene_vocab(),
    cf_freeze_backbone=True,
)

images = pipe.generate(
    num_samples=4,
    conditioning=uni,
    gene_expression=ge,
    gene_expression_binned=ge_binned,
    guidance_scale=4.0,
    guidance_scale_ge=3.0,
    num_inference_steps=50,
    sampler="ddim",
    seed=42,
)
images[0].save("sample_0.png")
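
The two scales passed to generate (guidance_scale and guidance_scale_ge) correspond to the dual classifier-free guidance mentioned in the introduction. This README does not state how the two are combined, so the sketch below shows one common nested formulation purely as an illustration. The function name dual_cfg, its argument names, and the assumption that guidance_scale weights the UNI embedding while guidance_scale_ge weights gene expression are all hypothetical, not taken from the STMDiT code.

```python
def dual_cfg(eps_uncond, eps_uni, eps_full, scale_uni, scale_ge):
    """Combine three noise predictions with two guidance scales.

    ASSUMED nested form (not confirmed by the STMDiT source):
      eps_uncond : prediction with both conditions dropped
      eps_uni    : prediction with only the UNI embedding
      eps_full   : prediction with UNI embedding and gene expression
    Works on floats, NumPy arrays, or torch tensors alike.
    """
    return (
        eps_uncond
        + scale_uni * (eps_uni - eps_uncond)  # push towards the pathology embedding
        + scale_ge * (eps_full - eps_uni)     # push further towards the expression profile
    )


# Scalar toy values standing in for noise predictions.
guided = dual_cfg(0.0, 1.0, 3.0, scale_uni=4.0, scale_ge=3.0)
```

Under this form, setting both scales to 1 recovers the fully conditioned prediction, and setting both to 0 recovers the unconditional one, so raising one scale relative to the other shifts the sample towards that modality.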

Demo notebook

notebooks/demo_inference.ipynb is a 6-cell notebook that loads the four demo spots from data/demo_features.npz, downloads adaln-ddpm-p06 from Hugging Face, and plots the four generated patches in a 2x2 grid.
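
Outside the notebook, a similar grid can be produced with a few lines of matplotlib. This is a minimal sketch, not the notebook's actual code: grid_shape and save_grid are hypothetical helpers, and images is assumed to be a list of PIL images or arrays such as the one returned by pipe.generate in the Quickstart.

```python
import math


def grid_shape(n):
    """Return (rows, cols) for a near-square grid holding n images."""
    if n < 1:
        raise ValueError("need at least one image")
    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)
    return rows, cols


def save_grid(images, out_path):
    """Plot images on a grid and write the figure to out_path (hypothetical helper)."""
    # Imported lazily; the Agg backend lets this run without a display.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    rows, cols = grid_shape(len(images))
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")  # hide ticks, including on unused cells
    for ax, img in zip(axes.flat, images):
        ax.imshow(img)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
```

With the four samples from the Quickstart, save_grid(images, "grid.png") would lay them out 2x2.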

License

Apache-2.0. See LICENSE.
