Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Evo2 SAE Recipe

Train a sparse autoencoder on Evo2 (DNA language model) residual-stream activations.

Pipeline:

```
HF Savanna ckpt --convert--> MBridge ckpt
|
extract.py (FASTA in, ActivationStore parquet shards out)
|
train.py (TopK SAE)
```

`extract.py` monkey-patches `predict_evo2`'s writer to stream parquet shards inline,
so there is no `.pt` intermediate and no separate shim step.

The eval / dashboard stage from the esm2 recipe is intentionally not ported in v1.

## Quick start (1B model, 4 GPU)

```bash
bash scripts/1b.sh
```

This will:

1. Convert `arcinstitute/savanna_evo2_1b_base` to MBridge format
2. Run `extract.py` on the OpenGenome2 organelle FASTA, streaming layer-12
activations directly to parquet shards (no `.pt` intermediate)
3. Train a TopK SAE (expansion=8, k=32, auxk=512)

Common overrides:

```bash
# Different layer, different FASTA, tagged output paths
LAYER=22 FASTA=/data/.../prokeuk_25M.fasta RUN_TAG=25M_prokeuk bash scripts/1b.sh

# Skip extraction (assumes parquet already exists at PARQUET_DIR)
TRAIN_ONLY=1 PARQUET_DIR=/data/.../parquet_25M_prokeuk bash scripts/1b.sh

# Sweep hyperparams
LAYER=22 RUN_TAG=auxk2048 AUXK=2048 N_EPOCHS=4 bash scripts/1b.sh
```

See the top of `scripts/1b.sh` for the full list of env-overridable variables
(`MODEL`, `LAYER`, `CHUNK_BP`, `FASTA`, `RUN_TAG`, `MAX_TOKENS`, `MICRO_BATCH`,
`DEVICES`, `EXPANSION_FACTOR`, `TOP_K`, `AUXK`, `AUXK_COEF`,
`DEAD_TOKENS_THRESHOLD`, `N_EPOCHS`, `LR`, `WANDB_API_KEY`, `WANDB_PROJECT`,
`WANDB_RUN_NAME`, `TRAIN_ONLY`).
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "evo2-sae"
version = "0.1.0"
description = "Sparse Autoencoders for the Evo2 DNA language model"
readme = "README.md"
requires-python = ">=3.10"

dependencies = [
"sae",
"torch>=2.0",
"numpy>=1.20",
"tqdm>=4.60",
"pyarrow>=10.0",
]

# No package code lives here yet — the recipe is just an entry-point for
# scripts/ that depends on the shared `sae` workspace package. Declare no
# packages so setuptools doesn't try to discover anything.
[tool.setuptools]
packages = []

[tool.uv.sources]
sae = { workspace = true }
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#!/bin/bash
# Evo2 1B SAE pipeline: convert -> extract (streaming parquet) -> train.
#
# Assumes:
# - bionemo-recipes/recipes/evo2_megatron has been built (.ci_build.sh) and
# its .venv is active, providing predict_evo2 + evo2_convert_savanna_to_mbridge.
# - The sae workspace package is importable in that same venv.
# - HF_TOKEN is set if the Savanna checkpoint repo is gated.
#
# Override any of these by exporting before invocation.

set -euo pipefail

EVO2_MEGATRON_DIR="${EVO2_MEGATRON_DIR:-/workspace/bionemo-framework/bionemo-recipes/recipes/evo2_megatron}"
RECIPE_DIR="$(cd "$(dirname "$0")/.." && pwd)"

MODEL="${MODEL:-arcinstitute/savanna_evo2_1b_base}"
MODEL_SIZE="${MODEL_SIZE:-evo2_1b_base}"
LAYER="${LAYER:-12}"
# Trained context length. 1B = 8192. Bump for 7B/40B (context-extended).
CHUNK_BP="${CHUNK_BP:-8192}"

FASTA="${FASTA:-/data/interp/evo2/OpenGenome2/fasta/organelles/organelle_sequences.fasta.gz}"
WORK_ROOT="${WORK_ROOT:-/data/interp/evo2}"

# Default output paths can be overridden per-run. Set RUN_TAG to suffix the
# activation and SAE paths at once (e.g. RUN_TAG=100M_mixed -> ..._parquet_100M_mixed),
# or override each path individually for full control.
RUN_TAG="${RUN_TAG:-}"
_SUFFIX="${RUN_TAG:+_${RUN_TAG}}"

CKPT_DIR="${CKPT_DIR:-${WORK_ROOT}/checkpoints/${MODEL_SIZE}_mbridge}"
PARQUET_DIR="${PARQUET_DIR:-${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_parquet${_SUFFIX}}"
OUTPUT_DIR="${OUTPUT_DIR:-${WORK_ROOT}/sae/${MODEL_SIZE}_layer${LAYER}${_SUFFIX}}"

source "${EVO2_MEGATRON_DIR}/.venv/bin/activate"

# TRAIN_ONLY=1 skips chunk/convert/extract against a cached parquet.
if [[ "${TRAIN_ONLY:-0}" == "1" ]]; then
echo "============================================================"
echo "TRAIN_ONLY=1 — skipping chunk / convert / extract;"
echo "expecting an existing parquet at: $PARQUET_DIR"
echo "============================================================"
if [[ ! -f "${PARQUET_DIR}/metadata.json" ]]; then
echo "ERROR: TRAIN_ONLY=1 but no parquet at $PARQUET_DIR"
exit 1
fi
else

echo "============================================================"
echo "STEP 0: Chunk FASTA to <=${CHUNK_BP} bp (model trained context)"
echo "============================================================"
# chunk_fasta.py reads .gz directly and writes plain .fasta; no separate gunzip needed.
INPUT_STEM="$(basename "$FASTA")"
INPUT_STEM="${INPUT_STEM%.gz}"
INPUT_STEM="${INPUT_STEM%.fasta}"
CHUNKED_FASTA="${WORK_ROOT}/scratch/${INPUT_STEM}_chunked${CHUNK_BP}.fasta"
if [[ -f "$CHUNKED_FASTA" ]]; then
echo "Reusing existing chunked FASTA: $CHUNKED_FASTA"
else
python "${RECIPE_DIR}/scripts/chunk_fasta.py" \
--input "$FASTA" \
--output "$CHUNKED_FASTA" \
--window "$CHUNK_BP"
fi
FASTA="$CHUNKED_FASTA"

echo "============================================================"
echo "STEP 1: Convert Savanna -> MBridge"
echo "============================================================"
if [[ ! -f "${CKPT_DIR}/latest_checkpointed_iteration.txt" ]]; then
evo2_convert_savanna_to_mbridge \
--savanna-ckpt-path "$MODEL" \
--mbridge-ckpt-dir "$CKPT_DIR" \
--model-size "$MODEL_SIZE" \
--tokenizer-path "${EVO2_MEGATRON_DIR}/tokenizers/nucleotide_fast_tokenizer_512"
else
echo "Reusing existing MBridge checkpoint at $CKPT_DIR"
fi

echo "============================================================"
echo "STEP 2: Extract layer-${LAYER} activations directly to parquet"
echo "============================================================"
# extract.py monkey-patches predict_evo2's writer to stream parquet shards
# inline; no .pt intermediate and no separate shim. MAX_TOKENS=0 = uncapped.
if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then
echo "Reusing existing parquet shards at $PARQUET_DIR"
else
torchrun --nproc_per_node "${DEVICES:-4}" "${RECIPE_DIR}/scripts/extract.py" \
--activation-store-dir "$PARQUET_DIR" \
--max-tokens "${MAX_TOKENS:-0}" \
--model-name "$MODEL" \
--fasta "$FASTA" \
--ckpt-dir "$CKPT_DIR" \
--embedding-layer "$LAYER" \
--micro-batch-size "${MICRO_BATCH:-4}"
fi

fi # end if TRAIN_ONLY

echo "============================================================"
echo "STEP 3: Train TopK SAE"
echo "============================================================"
# Wandb is enabled iff WANDB_API_KEY is in the env. WANDB_PROJECT/RUN can be overridden.
WANDB_FLAGS=("--no-wandb")
if [[ -n "${WANDB_API_KEY:-}" ]]; then
WANDB_FLAGS=(
"--wandb"
"--wandb-project" "${WANDB_PROJECT:-evo2-sae}"
)
if [[ -n "${WANDB_RUN_NAME:-}" ]]; then
WANDB_FLAGS+=("--wandb-run-name" "$WANDB_RUN_NAME")
fi
fi

torchrun --nproc_per_node "${DEVICES:-4}" "${RECIPE_DIR}/scripts/train.py" \
--cache-dir "$PARQUET_DIR" \
--model-path "$MODEL" \
--layer "$LAYER" \
--model-type topk \
--expansion-factor "${EXPANSION_FACTOR:-8}" \
--top-k "${TOP_K:-32}" \
--auxk "${AUXK:-512}" \
--auxk-coef "${AUXK_COEF:-0.03125}" \
--dead-tokens-threshold "${DEAD_TOKENS_THRESHOLD:-10000000}" \
--init-pre-bias \
--n-epochs "${N_EPOCHS:-3}" \
--batch-size 4096 \
--dp-size "${DEVICES:-4}" \
--lr "${LR:-3e-4}" \
--log-interval 50 \
"${WANDB_FLAGS[@]}" \
--output-dir "$OUTPUT_DIR" \
--checkpoint-dir "${OUTPUT_DIR}/checkpoints" \
--checkpoint-steps 999999

echo "============================================================"
echo "DONE: SAE checkpoint at ${OUTPUT_DIR}/checkpoints/checkpoint_final.pt"
echo "============================================================"
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Chunk a FASTA into <=N-bp windows so predict_evo2 stays inside the model's trained context.

Evo2 1B was trained with seq_length=8192; longer inputs OOM in the Hyena
fftconv path (intermediates scale super-linearly with L). For 7B/40B raise
--window to whatever those checkpoints were context-extended to.

Non-overlapping windows by default. Each chunk gets a header of the form
">{orig_id}:{start}-{end}" so downstream parquet can be back-mapped.
"""

import argparse
import gzip
from pathlib import Path


def parse_fasta(path: Path):
"""Yield (seq_id, sequence) tuples from a FASTA file (transparently handles .gz)."""
opener = gzip.open if path.suffix == ".gz" else open
seq_id, parts = None, []
with opener(path, "rt") as f:
for line in f:
line = line.rstrip()
if line.startswith(">"):
if seq_id is not None:
yield seq_id, "".join(parts)
seq_id = line[1:].split()[0]
parts = []
else:
parts.append(line)
if seq_id is not None:
yield seq_id, "".join(parts)


def main():
"""Read input FASTA, write non-overlapping <=window-bp chunks to output FASTA."""
p = argparse.ArgumentParser()
p.add_argument("--input", type=Path, required=True)
p.add_argument("--output", type=Path, required=True)
p.add_argument("--window", type=int, default=8192)
args = p.parse_args()

n_in = n_out = bp_out = 0
args.output.parent.mkdir(parents=True, exist_ok=True)
with open(args.output, "w") as out:
for seq_id, seq in parse_fasta(args.input):
n_in += 1
for start in range(0, len(seq), args.window):
end = min(start + args.window, len(seq))
chunk = seq[start:end]
out.write(f">{seq_id}:{start}-{end}\n{chunk}\n")
n_out += 1
bp_out += len(chunk)

print(f"Chunked {n_in} sequences -> {n_out} chunks ({bp_out:,} bp) at window={args.window}")


if __name__ == "__main__":
main()
Loading
Loading