A lightweight diffusion-based model for binary protein-protein interaction (PPI) structure prediction. Inspired by AlphaFold3 and Boltz-2, but designed to be small (<30M parameters) and trainable on a single consumer GPU. Includes a web frontend for interactive visualization.
Example from a carefully curated held-out test set: predicted vs. ground-truth complex. Green-blue: ground truth. Red-yellow: prediction.
TinyFold predicts the 3D structure of two interacting protein chains using a two-stage approach:
- Stage 1 (Residue Diffusion): Predict residue centroid positions using a diffusion model operating on L tokens (one per residue) instead of 4L atoms.
- Stage 2 (Atom Refinement): Refine centroid predictions to full backbone atom coordinates (N, CA, C, O) using local attention.
This hierarchical design is motivated by the observation that backbone topology is the hard problem, while local bond geometry is well constrained by chemistry.
- Compact: ~28M parameters total (13.8M Stage 1 + 14.2M Stage 2)
- Efficient: Stage 1 operates on residue-level tokens rather than atoms (4x fewer tokens than a four-atom backbone representation)
- Modular: Train stages independently or end-to-end
- Practical: Single GPU training (RTX 4070 Ti, ~8 hours for Stage 1)
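The token reduction matters more than the factor of 4 suggests, because dense self-attention builds an $N \times N$ score matrix. Here is a back-of-the-envelope count, assuming dense attention and an arbitrary example complex of $L = 500$ residues (not a dataset statistic):

$$
\frac{(4L)^2}{L^2} = 16, \qquad L = 500:\quad (4 \cdot 500)^2 = 4 \times 10^{6} \ \text{vs.}\ 500^2 = 2.5 \times 10^{5}\ \text{attention scores per head per layer.}
$$

Stage 2 avoids the same quadratic cost at the atom level by attending only within each residue's 4 atoms, which gives a $4 \times 4$ score block per residue.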
Complex folding couples global arrangement (chain–chain positioning) with local atomic detail (side-chain packing, interface chemistry). We decouple these by predicting structure at two resolutions:
- Residue-level diffusion (global scaffold): sample residue anchors for each chain to capture fold topology and relative orientation in the complex.
- All-atom refinement (local consistency): condition on the residue scaffold to generate atomic coordinates, resolving side chains and interface packing. The current Stage 2 outputs backbone atoms only (N, CA, C, O).
Why this helps:
- Efficiency / compactness: modeling long-range geometry at residue resolution reduces sequence length and degrees of freedom seen by the attention decoder, enabling a smaller model without sacrificing the ability to represent inter-chain organization.
- Interaction learning without explicit pair features: rather than maintaining an explicit pairwise tensor (as in Pairformer-like designs), residue–residue dependencies are learned implicitly through attention over the coarse structural scaffold.
- Stability and controllability: the global scaffold constrains refinement, reducing search complexity for the all-atom stage and making it easier to incorporate constraints (fixed subunits, known domains, interface restraints).
- Partial-known complexes: if one partner is known, keep it fixed and diffuse only the unknown partner’s residue scaffold, then refine. This turns full complex prediction into a cheaper, conditional docking-style problem.
- Pocket-conditioned ligand placement: given a predefined binding pocket, first diffuse a coarse ligand representation/pose relative to pocket anchors, then refine to an all-atom, chemically valid pose.
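The partial-known idea can be sketched as replacement-style conditioning. At every sampling step the fixed partner's centroids are reset to their known coordinates, so only the unknown partner is actually denoised.

This is a minimal NumPy sketch under stated assumptions, not TinyFold's sampler:
- `denoise` is a hypothetical stand-in for the Stage 1 denoiser.
- The update is an EDM-style deterministic Euler step on a decreasing noise schedule `sigmas` (also an assumption).

```python
import numpy as np

def sample_with_fixed_partner(denoise, x_known, fixed_mask, sigmas, rng):
    """Replacement-style conditional sampling (illustrative sketch, not TinyFold's sampler).

    denoise:    hypothetical callable (x [L, 3], sigma) -> predicted clean centroids [L, 3]
    x_known:    [L, 3] centroids; only rows where fixed_mask is True are used
    fixed_mask: [L] bool array, True for residues of the known partner
    sigmas:     decreasing noise levels ending at 0.0
    rng:        numpy Generator used for the initial noise
    """
    x = rng.normal(size=x_known.shape) * sigmas[0]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        # Pin the known partner to its known coordinates before each denoiser call
        x[fixed_mask] = x_known[fixed_mask]
        x0 = denoise(x, sigma)
        # Deterministic Euler step, equivalent to x + (sigma_next - sigma) * (x - x0) / sigma
        x = x0 + (sigma_next / sigma) * (x - x0)
    x[fixed_mask] = x_known[fixed_mask]
    return x
```

The known chain stays in its own frame, so the free chain is generated relative to it. A model trained only on fully noised inputs may need fine-tuning, or noising of the fixed chain to the current level, for this to work well.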
Stage 1 (Residue Diffusion): the diffusion model predicts clean residue centroids from noisy inputs:
- ResidueEncoder (Trunk): Processes sequence, chain IDs, and positions through a 9-layer Transformer. Runs once per sample to produce conditioning tokens.
- DiffusionTransformer (Denoiser): Iteratively denoises centroid positions over T=50 steps using Adaptive LayerNorm conditioning.
- Output: Predicted centroid positions [L, 3]
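Adaptive LayerNorm replaces LayerNorm's fixed learned scale and shift with values regressed from a conditioning signal. This lets trunk tokens (and, typically, the noise level) modulate every denoiser block.

Below is a minimal NumPy sketch. The projection weights `w_scale`/`w_shift`, the `1 + scale` parameterization, and the exact contents of `cond` are assumptions, not TinyFold's code:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token over its feature axis, with no learned affine parameters
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def ada_layer_norm(x, cond, w_scale, w_shift):
    """Adaptive LayerNorm (sketch).

    x:       [L, D] denoiser tokens
    cond:    [L, C] conditioning, e.g. trunk tokens combined with a noise-level embedding
    w_scale: [C, D] hypothetical projection producing a per-token scale
    w_shift: [C, D] hypothetical projection producing a per-token shift
    """
    scale = cond @ w_scale
    shift = cond @ w_shift
    # 1 + scale: zero-initialized projections leave the normalized input unchanged
    return layer_norm(x) * (1.0 + scale) + shift
```

With zero projections the layer reduces to a plain, affine-free LayerNorm. The `1 + scale` form therefore starts each block at identity modulation.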
Stage 2 (Atom Refinement): one-shot (non-iterative) prediction of the 4 backbone atoms per residue:
- GlobalTransformer: 6-layer Transformer captures inter-residue context
- LocalAtomAttention: Attention within each residue's 4 atoms predicts offsets from centroid
- Output: Backbone atom positions [L, 4, 3] (N, CA, C, O)
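The local stage can be pictured in two steps:
- self-attention restricted to the four atom tokens of each residue;
- a linear head that emits a 3D offset from that residue's centroid.

Below is a single-head NumPy sketch. The atom-type embedding `atom_emb`, the weight names, and the single residual update are illustrative assumptions:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def local_atom_attention(res_feats, centroids, atom_emb, wq, wk, wv, w_out):
    """Per-residue atom attention (sketch).

    res_feats: [L, D] residue features from the global transformer
    centroids: [L, 3] residue centroids
    atom_emb:  [4, D] hypothetical embedding per atom type (N, CA, C, O)
    wq, wk, wv: [D, D]; w_out: [D, 3]
    Returns backbone coordinates [L, 4, 3].
    """
    # One token per atom: residue context + atom-type identity -> [L, 4, D]
    h = res_feats[:, None, :] + atom_emb[None, :, :]
    q, k, v = h @ wq, h @ wk, h @ wv
    # Scores only among the 4 atoms of the same residue -> [L, 4, 4]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    h = h + softmax(scores, axis=-1) @ v
    offsets = h @ w_out  # [L, 4, 3] offsets from the centroid
    return centroids[:, None, :] + offsets
```

Reshaping to `[L, 4, D]` makes the attention block-diagonal by construction, so its cost is linear in L. Any cross-residue context must come from the GlobalTransformer features.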
Beyond the primary MSE loss on coordinates, Stage 2 (atom refinement) uses geometry-based auxiliary losses to enforce chemically valid backbone structures. These losses operate on the predicted [L, 4, 3] atom coordinates and are not applied during Stage 1 (residue diffusion), which uses only MSE + distance consistency.
Bond Length: Penalizes deviations from ideal backbone bond lengths:
where distances are in Ångströms.
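A minimal NumPy sketch of such a penalty makes several assumptions:
- A mean squared deviation from textbook Engh and Huber ideal lengths: N–Cα 1.458 Å, Cα–C 1.525 Å, C=O 1.231 Å, and peptide C–N 1.329 Å. Both the constants and the functional form are assumptions.
- The peptide term is masked at the chain boundary, since the two partners are not covalently linked.

```python
import numpy as np

# Assumed ideal lengths in Angstroms (Engh and Huber values); TinyFold's constants may differ
IDEAL = {"N-CA": 1.458, "CA-C": 1.525, "C-O": 1.231, "C-N": 1.329}

def bond_length_loss(x, chain_ids):
    """Mean squared deviation from ideal backbone bond lengths (assumed form).

    x:         [L, 4, 3] coordinates in (N, CA, C, O) order
    chain_ids: [L] chain index per residue; peptide bonds are only scored within a chain
    """
    n, ca, c, o = x[:, 0], x[:, 1], x[:, 2], x[:, 3]

    def dist(a, b):
        return np.linalg.norm(a - b, axis=-1)

    same_chain = chain_ids[:-1] == chain_ids[1:]
    terms = [
        (dist(n, ca) - IDEAL["N-CA"]) ** 2,
        (dist(ca, c) - IDEAL["CA-C"]) ** 2,
        (dist(c, o) - IDEAL["C-O"]) ** 2,
        # Peptide bond C_i -> N_{i+1}, skipped across the chain break
        (dist(c[:-1], n[1:]) - IDEAL["C-N"])[same_chain] ** 2,
    ]
    return float(np.mean(np.concatenate(terms)))
```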
Enforces tetrahedral geometry at Cα and planar geometry at the peptide bond:
The omega dihedral angle $\omega = \text{CA}_i\text{-C}_i\text{-N}_{i+1}\text{-CA}_{i+1}$ should be ~180° (trans) or ~0° (cis):
This allows both trans (~99.5% of peptide bonds) and cis configurations.
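One smooth penalty that is zero at both ω = 180° and ω = 0° is $1 - |\cos\omega|$. The sketch below computes ω with the standard four-point dihedral formula and applies that penalty. The choice of $1 - |\cos\omega|$ is an assumption, not necessarily TinyFold's formula.

```python
import numpy as np

def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle in radians for points of shape [..., 3]."""
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1, axis=-1, keepdims=True)
    # Project b0 and b2 onto the plane perpendicular to the central bond
    v = b0 - np.sum(b0 * b1, axis=-1, keepdims=True) * b1
    w = b2 - np.sum(b2 * b1, axis=-1, keepdims=True) * b1
    x = np.sum(v * w, axis=-1)
    y = np.sum(np.cross(b1, v) * w, axis=-1)
    return np.arctan2(y, x)

def omega_loss(coords, chain_ids):
    """Cis/trans-tolerant omega penalty 1 - |cos(omega)| (assumed form).

    coords: [L, 4, 3] in (N, CA, C, O) order; chain_ids: [L]
    """
    n, ca, c = coords[:, 0], coords[:, 1], coords[:, 2]
    same_chain = chain_ids[:-1] == chain_ids[1:]
    omega = dihedral(ca[:-1], c[:-1], n[1:], ca[1:])[same_chain]
    if omega.size == 0:
        return 0.0
    return float(np.mean(1.0 - np.abs(np.cos(omega))))
```

The penalty peaks at 1 for ω = ±90°, the geometry that neither peptide isomer adopts.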
Ribosomally synthesized proteins are built from L-amino acids (glycine is achiral), which constrains the backbone stereochemistry:
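Carbonyl O Chirality: The carbonyl oxygen must sit on the correct side of the C–N peptide bond. In a trans peptide it is cis to the next Cα (dihedral $\text{O}_i\text{-C}_i\text{-N}_{i+1}\text{-CA}_{i+1} \approx 0°$) and trans to the next residue's amide hydrogen: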
where
Virtual Cβ Chirality (experimental, currently disabled): Even without side chains, L-amino acid handedness can be enforced by computing a virtual Cβ position and checking its improper dihedral:
This loss is implemented but has not been enabled in any training run (weight=0.0).
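The virtual Cβ itself is commonly built from N, Cα, and C using fixed coefficients derived from ideal L-amino-acid geometry; these are, for example, the coefficients used in ProteinMPNN. The sketch shows only that construction. How TinyFold turns it into an improper-dihedral loss is not reproduced here.

```python
import numpy as np

def virtual_cb(n, ca, c):
    """Ideal-geometry virtual C-beta from backbone atoms; each input is [..., 3]."""
    b = ca - n
    c_vec = c - ca
    a = np.cross(b, c_vec)
    # Fixed coefficients place C-beta tetrahedrally with L-amino-acid handedness
    return -0.58273431 * a + 0.56802827 * b - 0.54067466 * c_vec + ca
```

For an ideal residue this lands about 1.53 Å from Cα, with an N–Cα–Cβ angle near the tetrahedral value.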
Distance Consistency: Preserves pairwise Cα distances between the prediction and the ground truth for residue pairs in contact (within 10 Å):
where
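A minimal NumPy sketch of such a term assumes a mean squared error over residue pairs whose ground-truth Cα–Cα distance is below the 10 Å cutoff, with self-pairs excluded. Because it compares internal distances, it is invariant to global rotation and translation, so no superposition is needed:

```python
import numpy as np

def distance_consistency_loss(ca_pred, ca_true, cutoff=10.0):
    """MSE between predicted and true C-alpha distance maps over true contacts (assumed form).

    ca_pred, ca_true: [L, 3]; pairs with true distance >= cutoff (Angstroms) are ignored.
    """
    d_pred = np.linalg.norm(ca_pred[:, None, :] - ca_pred[None, :, :], axis=-1)
    d_true = np.linalg.norm(ca_true[:, None, :] - ca_true[None, :, :], axis=-1)
    # Contact mask from the ground truth, excluding each residue's pair with itself
    mask = (d_true < cutoff) & ~np.eye(len(ca_true), dtype=bool)
    if not mask.any():
        return 0.0
    return float(np.mean((d_pred - d_true)[mask] ** 2))
```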
| Aspect | TinyFold | AlphaFold3 | Boltz-2 |
|---|---|---|---|
| Diffusion target | L residue centroids | All atoms | All atoms |
| Pair features | Implicit in attention | Explicit Pairformer | Explicit Pairformer |
| Per-step alignment | Kabsch (optional) | No | Kabsch |
| Model size | 28M params | 600M+ params | 700M+ params |
| Training hardware | Single GPU | TPU pod | Multi-GPU cluster |
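The per-step alignment row refers to Kabsch superposition. It finds the optimal rotation between two centered point sets from an SVD of their 3 × 3 covariance matrix, with a sign correction that rules out reflections. Below is a minimal NumPy sketch of the superposition itself; where in the sampler TinyFold applies it is not shown.

```python
import numpy as np

def kabsch_align(mobile, target):
    """Rigidly superimpose mobile [N, 3] onto target [N, 3] (least-squares RMSD)."""
    mu_m, mu_t = mobile.mean(axis=0), target.mean(axis=0)
    p, q = mobile - mu_m, target - mu_t
    h = p.T @ q  # 3x3 covariance matrix
    u, _, vt = np.linalg.svd(h)
    # Flip the last axis if needed so the rotation is proper (det = +1)
    d = np.sign(np.linalg.det(vt.T @ u.T))
    r = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    return p @ r.T + mu_t
```

Without the sign correction, a mirror-image prediction could be "aligned" by a reflection, hiding a chirality error.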
- Python 3.10+
- PyTorch 2.0+
- CUDA-capable GPU (12GB+ VRAM)
```bash
# Stage 1: Residue diffusion
python scripts/train_resfold.py \
    --mode stage1_only \
    --n_train 20000 \
    --n_steps 50000 \
    --output_dir outputs/stage1

# Stage 2: Atom refinement (using Stage 1 predictions)
python scripts/train_resfold.py \
    --mode stage2_only \
    --stage1_dir outputs/stage1 \
    --n_steps 20000 \
    --output_dir outputs/stage2
```

```bash
python scripts/predict.py \
    --checkpoint outputs/stage1/best_model.pt \
    --pdb1 chain_a.pdb \
    --pdb2 chain_b.pdb \
    --output prediction.pdb
```

Two frontend entrypoints:
- `web/`: model evaluation UI (browse samples, run/load predictions)
- `web-light/`: static lightweight showcase (few train/test GT vs prediction examples)
Launch the model evaluation frontend:

```bash
cd web
../.venv/Scripts/python.exe server.py
# Open http://127.0.0.1:5001
```

Launch the lightweight showcase (stdlib server, runnable right after clone):

```bash
cd web-light
python server.py --port 5002
# Open http://127.0.0.1:5002
```

See doc/frontend.md for detailed documentation including:
- Generating cached predictions
- Regenerating `web-light` showcase data from `assets/`
I currently only use the DIPS-Plus dataset:
- 28,352 binary protein complexes
- Backbone atoms only (N, CA, C, O)
Download and preprocess:

```bash
python scripts/data/prepare_data.py --output-dir data/processed
```

- Boltz-2 style per-step Kabsch alignment (available in all samplers)
- Proper benchmarking (DockQ, lDDT, interface metrics)
- Web frontend for visualization (`web/` and `web-light/`)
- Energy-based auxiliary losses (Lennard-Jones, electrostatics)
- Extension to small molecules / DNA / other macromolecules