Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Evo2 SAE Recipe

Train a sparse autoencoder on Evo2 (DNA language model) residual-stream activations.

Pipeline:

```
HF Savanna ckpt --convert--> MBridge ckpt
|
predict_evo2 --embedding-layer N (FASTA in, .pt out)
|
pt_to_parquet shim (.pt -> ActivationStore parquet shards)
|
train.py (TopK SAE)
```

The eval / dashboard stage from the esm2 recipe is intentionally not ported in v1.

## Quick start (1B model, single GPU)

```bash
bash scripts/1b.sh
```

This will:

1. Convert `arcinstitute/savanna_evo2_1b_base` to MBridge format
2. Run `predict_evo2` on the OpenGenome2 organelle FASTA (first split into <=8192 bp chunks by `scripts/chunk_fasta.py`), extracting layer-12 embeddings
3. Convert the .pt outputs to parquet shards
4. Train a TopK SAE (expansion=8, k=32)
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "evo2-sae"
version = "0.1.0"
description = "Sparse Autoencoders for the Evo2 DNA language model"
readme = "README.md"
requires-python = ">=3.10"

dependencies = [
"sae",
"torch>=2.0",
"numpy>=1.20",
"tqdm>=4.60",
"pyarrow>=10.0",
]

# No package code lives here yet — the recipe is just an entry-point for
# scripts/ that depends on the shared `sae` workspace package. Declare no
# packages so setuptools doesn't try to discover anything.
[tool.setuptools]
packages = []

[tool.uv.sources]
sae = { workspace = true }
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/bin/bash
# Evo2 1B SAE pipeline: chunk FASTA -> convert -> predict_evo2 -> pt_to_parquet -> train.
#
# Assumes:
#   - bionemo-recipes/recipes/evo2_megatron has been built (.ci_build.sh); this
#     script activates its .venv, which provides predict_evo2 +
#     evo2_convert_savanna_to_mbridge.
#   - The sae workspace package is importable in that same venv.
#   - HF_TOKEN is set if the Savanna checkpoint repo is gated.
#
# Override any of the ${VAR:-default} settings below by exporting the variable
# before invocation.
#
# Steps 0-3 are resumable: each one is skipped when its output already exists.
# Step 4 (training) always runs.

set -euo pipefail

EVO2_MEGATRON_DIR="${EVO2_MEGATRON_DIR:-/workspace/bionemo-framework/bionemo-recipes/recipes/evo2_megatron}"
RECIPE_DIR="$(cd "$(dirname "$0")/.." && pwd)"

MODEL="${MODEL:-arcinstitute/savanna_evo2_1b_base}"
MODEL_SIZE="${MODEL_SIZE:-evo2_1b_base}"
LAYER="${LAYER:-12}"
# Trained context length. 1B = 8192. Bump for 7B/40B (context-extended).
CHUNK_BP="${CHUNK_BP:-8192}"

FASTA="${FASTA:-/data/interp/evo2/OpenGenome2/fasta/organelles/organelle_sequences.fasta.gz}"
WORK_ROOT="${WORK_ROOT:-/data/interp/evo2}"

# Output locations are keyed by model size and layer so different
# configurations do not overwrite each other.
CKPT_DIR="${WORK_ROOT}/checkpoints/${MODEL_SIZE}_mbridge"
PREDICT_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_pt"
PARQUET_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_parquet"
OUTPUT_DIR="${WORK_ROOT}/sae/${MODEL_SIZE}_layer${LAYER}"

source "${EVO2_MEGATRON_DIR}/.venv/bin/activate"

echo "============================================================"
echo "STEP 0: Chunk FASTA to <=${CHUNK_BP} bp (model trained context)"
echo "============================================================"
# chunk_fasta.py reads .gz directly and writes plain .fasta; no separate gunzip needed.
INPUT_STEM="$(basename "$FASTA")"
INPUT_STEM="${INPUT_STEM%.gz}"
INPUT_STEM="${INPUT_STEM%.fasta}"
CHUNKED_FASTA="${WORK_ROOT}/scratch/${INPUT_STEM}_chunked${CHUNK_BP}.fasta"
if [[ -f "$CHUNKED_FASTA" ]]; then
    echo "Reusing existing chunked FASTA: $CHUNKED_FASTA"
else
    # Write to a side file and rename only after chunk_fasta.py succeeds.
    # Otherwise an interrupted run would leave a truncated file at
    # $CHUNKED_FASTA that the reuse check above would accept on the next run.
    CHUNKED_PARTIAL="${CHUNKED_FASTA}.partial"
    python "${RECIPE_DIR}/scripts/chunk_fasta.py" \
        --input "$FASTA" \
        --output "$CHUNKED_PARTIAL" \
        --window "$CHUNK_BP"
    mv -f "$CHUNKED_PARTIAL" "$CHUNKED_FASTA"
fi
FASTA="$CHUNKED_FASTA"

echo "============================================================"
echo "STEP 1: Convert Savanna -> MBridge"
echo "============================================================"
if [[ ! -f "${CKPT_DIR}/latest_checkpointed_iteration.txt" ]]; then
    evo2_convert_savanna_to_mbridge \
        --savanna-ckpt-path "$MODEL" \
        --mbridge-ckpt-dir "$CKPT_DIR" \
        --model-size "$MODEL_SIZE" \
        --tokenizer-path "${EVO2_MEGATRON_DIR}/tokenizers/nucleotide_fast_tokenizer_512"
else
    echo "Reusing existing checkpoint at $CKPT_DIR"
fi

echo "============================================================"
echo "STEP 2: Extract layer-${LAYER} embeddings (predict_evo2)"
echo "============================================================"
mkdir -p "$PREDICT_DIR"
# NOTE: any existing predictions__*.pt directly under $PREDICT_DIR counts as a
# finished extraction. If a previous predict_evo2 run was interrupted, delete
# $PREDICT_DIR before re-running.
if compgen -G "${PREDICT_DIR}/predictions__*.pt" > /dev/null; then
    echo "Reusing existing .pt files in $PREDICT_DIR"
else
    predict_evo2 \
        --fasta "$FASTA" \
        --ckpt-dir "$CKPT_DIR" \
        --output-dir "$PREDICT_DIR" \
        --embedding-layer "$LAYER" \
        --micro-batch-size 1 \
        --devices 1 \
        --write-interval batch
fi

echo "============================================================"
echo "STEP 3: Convert .pt -> parquet ActivationStore"
echo "============================================================"
if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then
    echo "Reusing existing parquet shards at $PARQUET_DIR"
else
    python "${RECIPE_DIR}/scripts/pt_to_parquet.py" \
        --predict-dir "$PREDICT_DIR" \
        --output "$PARQUET_DIR" \
        --model-name "$MODEL" \
        --layer "$LAYER"
fi

echo "============================================================"
echo "STEP 4: Train TopK SAE"
echo "============================================================"
python "${RECIPE_DIR}/scripts/train.py" \
    --cache-dir "$PARQUET_DIR" \
    --model-path "$MODEL" \
    --layer "$LAYER" \
    --model-type topk \
    --expansion-factor 8 --top-k 32 \
    --auxk 64 --auxk-coef 0.03125 \
    --init-pre-bias \
    --n-epochs 3 \
    --batch-size 4096 \
    --lr 3e-4 \
    --log-interval 50 \
    --no-wandb \
    --output-dir "$OUTPUT_DIR" \
    --checkpoint-dir "${OUTPUT_DIR}/checkpoints" \
    --checkpoint-steps 999999

echo "============================================================"
echo "DONE: SAE checkpoint at ${OUTPUT_DIR}/checkpoints/checkpoint_final.pt"
echo "============================================================"
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Chunk a FASTA into <=N-bp windows so predict_evo2 stays inside the model's trained context.

Evo2 1B was trained with seq_length=8192; longer inputs OOM in the Hyena
fftconv path (intermediates scale super-linearly with L). For 7B/40B raise
--window to whatever those checkpoints were context-extended to.

Non-overlapping windows by default. Each chunk gets a header of the form
">{orig_id}:{start}-{end}" so downstream parquet can be back-mapped.
"""

import argparse
import gzip
from pathlib import Path


def parse_fasta(path: Path):
    """Stream (seq_id, sequence) pairs out of a FASTA file.

    Files whose suffix is ``.gz`` are opened with ``gzip``; anything else is
    opened as plain text. The record id is the first whitespace-delimited
    token after ``>``; the remainder of the header line is dropped. Sequence
    lines are right-stripped and concatenated. Any lines that appear before
    the first header are discarded.
    """
    open_text = open if path.suffix != ".gz" else gzip.open
    current_id = None
    fragments = []
    with open_text(path, "rt") as stream:
        for raw in stream:
            text = raw.rstrip()
            if not text.startswith(">"):
                fragments.append(text)
                continue
            # New header: flush the record collected so far (if any) before
            # starting to accumulate the next one.
            if current_id is not None:
                yield current_id, "".join(fragments)
            current_id = text[1:].split()[0]
            fragments = []
    # The last record has no following header to trigger its flush.
    if current_id is not None:
        yield current_id, "".join(fragments)


def main():
    """Read input FASTA, write non-overlapping <=window-bp chunks to output FASTA.

    Command-line arguments:
        --input:  source FASTA (read via ``parse_fasta``).
        --output: destination FASTA; parent directories are created as needed.
        --window: maximum chunk length in bp (default 8192); must be positive.

    Each chunk is written as ``>{seq_id}:{start}-{end}`` followed by the
    sequence on one line, where ``start``/``end`` are the 0-based, end-exclusive
    slice bounds into the original sequence. Zero-length sequences produce no
    chunks.

    The output is written to ``<output>.tmp`` and renamed into place only once
    complete, so a crash or interrupt never leaves a truncated file at
    ``--output`` for a later "reuse if it exists" check to pick up.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--input", type=Path, required=True)
    p.add_argument("--output", type=Path, required=True)
    p.add_argument("--window", type=int, default=8192)
    args = p.parse_args()

    # range() rejects a zero step with an opaque ValueError, and a negative
    # step would silently yield no chunks at all; fail fast with a clear usage
    # error instead.
    if args.window <= 0:
        p.error(f"--window must be a positive integer, got {args.window}")

    n_in = n_out = bp_out = 0
    args.output.parent.mkdir(parents=True, exist_ok=True)
    # Same directory as the final file so the rename stays on one filesystem.
    tmp_output = args.output.with_name(args.output.name + ".tmp")
    try:
        with open(tmp_output, "w") as out:
            for seq_id, seq in parse_fasta(args.input):
                n_in += 1
                for start in range(0, len(seq), args.window):
                    end = min(start + args.window, len(seq))
                    chunk = seq[start:end]
                    out.write(f">{seq_id}:{start}-{end}\n{chunk}\n")
                    n_out += 1
                    bp_out += len(chunk)
        tmp_output.replace(args.output)
    finally:
        # No-op after a successful rename; removes the partial file on failure.
        tmp_output.unlink(missing_ok=True)

    print(f"Chunked {n_in} sequences -> {n_out} chunks ({bp_out:,} bp) at window={args.window}")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert predict_evo2 .pt outputs to SAE ActivationStore parquet shards.

predict_evo2 with --embedding-layer writes dicts of:
hidden_embeddings: [B, S, H] (bf16)
pad_mask: [B, S] (1 = valid token, 0 = padding)
seq_idx, tokens: metadata, ignored here

We read each file, mask out padding, flatten to [N_tokens, H], and append
to an ActivationStore so train.py's load_activations() can consume it.
"""

import argparse
import json
from pathlib import Path

import torch
from sae.activation_store import ActivationStore, ActivationStoreConfig
from tqdm import tqdm


def main():
    """Convert every predict_evo2 ``predictions__*.pt`` file into ActivationStore shards.

    Recursively collects ``predictions__*.pt`` under ``--predict-dir`` in
    sorted path order. For each file, the rows of ``hidden_embeddings``
    selected by ``pad_mask`` are cast to float32 and appended to the store.
    On finalize, the model name, layer, and number of sequences seen are
    passed as metadata, and the store's metadata is printed as JSON.

    Raises:
        FileNotFoundError: if no ``predictions__*.pt`` file is found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--predict-dir", type=Path, required=True, help="Dir containing predictions__*.pt")
    parser.add_argument("--output", type=Path, required=True, help="ActivationStore output dir")
    parser.add_argument("--model-name", type=str, required=True, help="Stamped into metadata.json")
    parser.add_argument("--layer", type=int, required=True, help="Stamped into metadata.json")
    parser.add_argument("--shard-size", type=int, default=100_000)
    cli = parser.parse_args()

    prediction_files = sorted(cli.predict_dir.rglob("predictions__*.pt"))
    if not prediction_files:
        raise FileNotFoundError(f"No predictions__*.pt under {cli.predict_dir}")

    store_config = ActivationStoreConfig(shard_size=cli.shard_size)
    store = ActivationStore(cli.output, store_config)

    total_sequences = 0
    for pt_path in tqdm(prediction_files, desc="pt->parquet"):
        # NOTE(review): weights_only=False unpickles arbitrary Python objects,
        # so only point this at .pt files from a trusted predict_evo2 run.
        payload = torch.load(pt_path, map_location="cpu", weights_only=False)
        # Per the module docstring: hidden_embeddings is [B, S, H] and
        # pad_mask is [B, S] with 1 marking valid tokens.
        embeddings = payload["hidden_embeddings"]
        valid_tokens = payload["pad_mask"].bool()
        # Boolean-mask indexing drops padded positions and flattens the
        # masked dimensions into one.
        store.append(embeddings[valid_tokens].float())
        total_sequences += embeddings.shape[0]

    run_metadata = {
        "model_name": cli.model_name,
        "layer": cli.layer,
        "n_sequences": total_sequences,
    }
    store.finalize(metadata=run_metadata)
    print(json.dumps(store.metadata, indent=2))


if __name__ == "__main__":
main()
Loading
Loading