4 changes: 4 additions & 0 deletions .gitignore
@@ -14,6 +14,10 @@ plans/
.snakemake/
.snakemake_timestamp

# From uv / local Python envs
.venv/
.venv-*/

# Sensitive environment variables
environment*

75 changes: 74 additions & 1 deletion README.md
@@ -4,7 +4,20 @@ This is a simple repo to provision evolutionary sequence trajectories from Nexts

# Installation

Install various dependencies with pip:
The project ships a `pyproject.toml` + `uv.lock` for [uv](https://docs.astral.sh/uv/),
and a `requirements.txt` for plain pip. Pick one.

**With uv (recommended):**

```
uv sync # creates .venv and installs all locked deps
uv run snakemake --help # run any command in the env without activating
```

`uv sync` is reproducible — it installs exactly the versions in `uv.lock`. To
refresh the lockfile after editing `pyproject.toml`, run `uv lock`.

**With pip:**

```
pip install -r requirements.txt
@@ -48,6 +61,10 @@ snakemake --configfile defaults/viral.yaml --cores 1 -p results # Generate tra
snakemake --configfile defaults/viral.yaml --cores 1 -p upload # Upload results to S3
```

If you set up the env with `uv sync`, prefix any of the snakemake commands in
this README with `uv run` (e.g. `uv run snakemake --configfile ... results`) —
or activate the venv with `source .venv/bin/activate` and use the bare commands.

Running `snakemake` with no target defaults to `results`. Running with no `--configfile` is a no-op (no analyses defined).

## Dataset config files
@@ -138,6 +155,62 @@ snakemake --configfile defaults/viral.yaml --cores 1 -p upload

This uploads `results/` to `s3://{bucket}/{s3_prefix}/`, where `s3_prefix` defaults to `trajectories` (set in `defaults/config.yaml`) and can be overridden per dataset config (e.g. `defaults/trellis.yaml` sets `s3_prefix: trellis-trajectories`). Requires `S3_BUCKET`, `AWS_ACCESS_KEY_ID`, and `AWS_SECRET_ACCESS_KEY` environment variables to be set.

## Downstream: JSONL conversion and cleaning

Two optional downstream stages turn the trajectory shards into json_sdiff JSONL
files for model training, and then apply quality filters.

```bash
# 1. Convert per-(mode,split) shards to JSONL
snakemake --configfile defaults/viral.yaml --cores 8 -p jsonl

# 2. Filter the JSONL by quality thresholds
snakemake --configfile defaults/viral.yaml --cores 8 -p clean
```

`jsonl` runs `scripts/fasta_to_jsonl.py` (vendored from pegasus-evals) in
parallel across every `{mode}-{split}-NNN.tar.zst` shard. Outputs land under
`results/{analysis}/jsonl/`:

```
results/n450-xs/jsonl/
├── shards/ # per-shard resume cache
├── forwards_train.jsonl
├── forwards_test.jsonl
├── pairwise_train.jsonl
├── pairwise_test.jsonl
├── combined_train.jsonl # forwards + pairwise (when trajectory_mode=both)
└── combined_test.jsonl
```

Per-shard outputs in `shards/` act as a resume cache: a rerun skips any shard
whose final `.jsonl` already exists. Delete `shards/` to force a full redo.
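
As a rough illustration of the resume behavior, here is a minimal Python sketch of skip-if-exists combined with an atomic `.partial` -> `.jsonl` rename. The actual logic lives in `scripts/shards_to_jsonl.sh`; the function name `convert_shard` and the `convert` callable are hypothetical stand-ins, not part of the repo.

```python
import os
from pathlib import Path


def convert_shard(shard_name, shards_dir, convert):
    """Convert one shard unless its final .jsonl is already cached.

    `convert` is a hypothetical callable returning the JSONL text for a shard.
    Returns True if work was done, False if the cached output was reused.
    """
    shards_dir = Path(shards_dir)
    final = shards_dir / f"{shard_name}.jsonl"
    if final.exists():
        return False  # finished on an earlier run; skip

    shards_dir.mkdir(parents=True, exist_ok=True)
    partial = final.with_suffix(".partial")
    partial.write_text(convert(shard_name))
    # The rename is atomic on a single filesystem, so an interrupted run
    # leaves only a .partial file and never a truncated .jsonl.
    os.replace(partial, final)
    return True
```

Because only completed shards ever receive the `.jsonl` name, checking for that file is enough to decide whether a shard needs to be redone.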

`clean` runs `scripts/clean_lib.py` (vendored from pegasus-datasets) via
`scripts/run_clean.py` to apply four record-level filters (`max_hunk_len`,
`gap_allele_frac`, `ref_gap_frac`, `mut_density`). Outputs land under
`results/{analysis}/jsonl_clean/` with one `.manifest.csv` per filtered file
capturing drop counts by reason.
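
The following Python sketch shows one way record-level filtering with per-reason drop counts could work. It is not the code in `scripts/clean_lib.py`: it assumes each record already carries one precomputed number per filter, that a record is dropped when a value exceeds its threshold, and that the manifest has `reason,dropped` columns. All three are assumptions made for illustration. The threshold values are the defaults from `defaults/config.yaml`.

```python
import csv
import io
from collections import Counter

# Default thresholds as listed in defaults/config.yaml.
THRESHOLDS = {
    "max_hunk_len": 100,
    "gap_allele_frac": 0.7,
    "ref_gap_frac": 0.3,
    "mut_density": 0.5,
}


def drop_reason(metrics, thresholds=THRESHOLDS):
    """Return the name of the first filter a record fails, or None.

    ASSUMPTION: `metrics` maps each filter name to a precomputed value and
    "fails" means "exceeds the threshold"; clean_lib.py may differ.
    """
    for name, limit in thresholds.items():
        if metrics[name] > limit:
            return name
    return None


def apply_filters(records, thresholds=THRESHOLDS):
    """Split records into kept ones and a Counter of drops by reason."""
    kept, drops = [], Counter()
    for metrics in records:
        reason = drop_reason(metrics, thresholds)
        if reason is None:
            kept.append(metrics)
        else:
            drops[reason] += 1
    return kept, drops


def manifest_csv(drops):
    """Render drop counts as CSV text (hypothetical column layout)."""
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(["reason", "dropped"])
    for reason, count in sorted(drops.items()):
        writer.writerow([reason, count])
    return buf.getvalue()
```

Attributing each dropped record to a single reason keeps the manifest counts summing to the total number of dropped records.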

Defaults live in `defaults/config.yaml` under the `jsonl:` and `clean:` keys.
Per-dataset overrides go under `analysis.{name}.jsonl` / `analysis.{name}.clean`:

```yaml
analysis:
  n450-xs:
    dataset: https://nextstrain.org/groups/trajectories/n450-xs
    gene: nuc
    seq_length: 450
    jsonl:
      max_raw_len: 8000
      context_size: 12
    clean:
      mut_density: 0.3
```
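
The override lookup can be sketched as a standalone Python function. It follows the same precedence as `_ds_setting` in the Snakefile (per-dataset value, then top-level value, then a fallback default). The name `resolve_setting` and the inline `config` dict are illustrative only, and this version is slightly more tolerant of empty (`None`) sections than the original.

```python
def resolve_setting(config, section, key, analysis, default=None):
    """Look up `key`, preferring analysis.{analysis}.{section} over {section}."""
    base = config.get(section) or {}
    per_dataset = (
        ((config.get("analysis") or {}).get(analysis) or {}).get(section) or {}
    )
    if key in per_dataset:
        return per_dataset[key]
    return base.get(key, default)


# Illustrative config: top-level defaults plus one per-dataset override block.
config = {
    "jsonl": {"max_raw_len": 14000, "context_size": 16},
    "clean": {"mut_density": 0.5},
    "analysis": {
        "n450-xs": {
            "jsonl": {"max_raw_len": 8000, "context_size": 12},
            "clean": {"mut_density": 0.3},
        }
    },
}

print(resolve_setting(config, "jsonl", "max_raw_len", "n450-xs"))     # 8000
print(resolve_setting(config, "jsonl", "max_raw_len", "other"))       # 14000
print(resolve_setting(config, "clean", "max_hunk_len", "n450-xs", 100))  # 100
```

An override applies per key, so a dataset that sets only `mut_density` still inherits every other `clean:` value from the top-level block.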

Both targets respect `target_analyses` and `trajectory_mode` exactly like
`results`.
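
To make the effect of `trajectory_mode` concrete, this Python sketch lists the JSONL files one would expect for a single analysis. It is modeled on the target helpers in the Snakefile, but `expected_jsonl_files` is a hypothetical name and the ordering of the returned list is not meant to match Snakemake's.

```python
def expected_jsonl_files(analysis, trajectory_mode):
    """List the JSONL outputs for one analysis under a given trajectory_mode.

    trajectory_mode is one of "forwards", "pairwise", or "both".
    """
    kinds = [m for m in ("forwards", "pairwise") if trajectory_mode in (m, "both")]
    if trajectory_mode == "both":
        kinds.append("combined")  # combined files exist only when both modes run
    return [
        f"results/{analysis}/jsonl/{kind}_{split}.jsonl"
        for kind in kinds
        for split in ("train", "test")
    ]


for path in expected_jsonl_files("n450-xs", "forwards"):
    print(path)
# results/n450-xs/jsonl/forwards_train.jsonl
# results/n450-xs/jsonl/forwards_test.jsonl
```

With `trajectory_mode: both` the same call returns six paths: the forwards, pairwise, and combined files for each of the two splits.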

## Pooled rollup

After granular per-dataset shards are uploaded to S3, the `pooled` target combines every per-source shard under `s3://{bucket}/{s3_prefix}/` (excluding any existing `pooled/` subdir) into a single set of evenly-sized pooled shards for training:
194 changes: 194 additions & 0 deletions Snakefile
@@ -407,6 +407,200 @@ rule usher_provision_alignment_branches:
# End of UShER-specific rules
# ============================================================================

# ============================================================================
# Downstream: FASTA trajectory shards -> json_sdiff JSONL -> cleaned JSONL
# ============================================================================

# Resolve a (possibly per-analysis-overridden) setting from the top-level
# `jsonl:` / `clean:` config blocks.
def _ds_setting(section, key, analysis, default=None):
    base = config.get(section, {}) or {}
    per_ds = (
        config.get("analysis", {})
        .get(analysis, {})
        .get(section, {})
        or {}
    )
    if key in per_ds:
        return per_ds[key]
    return base.get(key, default)


# Which (mode, split) pairs are produced for the configured trajectory_mode.
def _jsonl_mode_splits():
    pairs = []
    if TRAJECTORY_MODE in ("forwards", "both"):
        pairs += [("forwards", "train"), ("forwards", "test")]
    if TRAJECTORY_MODE in ("pairwise", "both"):
        pairs += [("pairwise", "train"), ("pairwise", "test")]
    return pairs


def _jsonl_targets(analyses):
    targets = []
    for analysis in analyses:
        for mode, split in _jsonl_mode_splits():
            targets.append(f"results/{analysis}/jsonl/{mode}_{split}.jsonl")
        # Always emit the combined files when both modes are produced.
        if TRAJECTORY_MODE == "both":
            targets += [
                f"results/{analysis}/jsonl/combined_train.jsonl",
                f"results/{analysis}/jsonl/combined_test.jsonl",
            ]
    return targets


def _clean_kinds():
    kinds = []
    if TRAJECTORY_MODE in ("forwards", "both"):
        kinds.append("forwards")
    if TRAJECTORY_MODE in ("pairwise", "both"):
        kinds.append("pairwise")
    if TRAJECTORY_MODE == "both":
        kinds.append("combined")
    return kinds


def _clean_targets(analyses):
    targets = []
    for analysis in analyses:
        for kind in _clean_kinds():
            for split in ("train", "test"):
                targets.append(f"results/{analysis}/jsonl_clean/{kind}_{split}.jsonl")
    return targets


def _mode_done(analysis, mode):
    """The .done sentinel produced by the upstream trajectory rule for this mode."""
    return (
        f"results/{analysis}/.trajectories.done"
        if mode == "forwards"
        else f"results/{analysis}/.pairwise.done"
    )


# Top-level aggregate targets — invoked as `snakemake jsonl` / `snakemake clean`.
rule jsonl:
    input:
        _jsonl_targets(ANALYSES)

rule clean:
    input:
        _clean_targets(ANALYSES)


rule jsonl_mode_split:
    """Convert all {mode}-{split} shards for one analysis to a single JSONL.

    Drives the per-shard parallel conversion via scripts/shards_to_jsonl.sh,
    which mirrors the resumable logic from the upstream pegasus-evals helper
    (per-shard .jsonl cache under results/{analysis}/jsonl/shards/, atomic
    .partial -> .jsonl renames, skip-if-exists on rerun). The actual FASTA ->
    json_sdiff conversion is done by scripts/fasta_to_jsonl.py (vendored).
    """
    wildcard_constraints:
        mode = "forwards|pairwise",
        split = "train|test",
    input:
        done = lambda wc: _mode_done(wc.analysis, wc.mode),
    output:
        jsonl = "results/{analysis}/jsonl/{mode}_{split}.jsonl",
    params:
        analysis_dir = "results/{analysis}",
        shards_dir = "results/{analysis}/jsonl/shards",
        script = "scripts/shards_to_jsonl.sh",
        fasta_to_jsonl = "scripts/fasta_to_jsonl.py",
        py = lambda wc: _ds_setting("jsonl", "python", wc.analysis, "python"),
        parallel = lambda wc: _ds_setting("jsonl", "parallel", wc.analysis, 16),
        max_raw_len = lambda wc: _ds_setting("jsonl", "max_raw_len", wc.analysis, 14000),
        max_len = lambda wc: _ds_setting("jsonl", "max_len", wc.analysis, "") or "",
        context_size = lambda wc: _ds_setting("jsonl", "context_size", wc.analysis, 16),
        noprompt = lambda wc: "1" if _ds_setting("jsonl", "noprompt", wc.analysis, True) else "0",
    shell:
        """
        ANALYSIS_DIR={params.analysis_dir:q} \
        MODE={wildcards.mode} \
        SPLIT={wildcards.split} \
        OUT_JSONL={output.jsonl:q} \
        SHARDS_DIR={params.shards_dir:q} \
        PY={params.py:q} \
        FASTA_TO_JSONL={params.fasta_to_jsonl:q} \
        PARALLEL={params.parallel} \
        MAX_RAW_LEN={params.max_raw_len} \
        MAX_LEN={params.max_len:q} \
        CONTEXT_SIZE={params.context_size} \
        NOPROMPT={params.noprompt} \
        bash {params.script:q}
        """


rule jsonl_combined:
    """Concatenate forwards + pairwise JSONL for one split.

    Only meaningful when both modes are produced (trajectory_mode == 'both').
    """
    wildcard_constraints:
        split = "train|test",
    input:
        forwards = "results/{analysis}/jsonl/forwards_{split}.jsonl",
        pairwise = "results/{analysis}/jsonl/pairwise_{split}.jsonl",
    output:
        combined = "results/{analysis}/jsonl/combined_{split}.jsonl",
    shell:
        """
        cat {input.forwards:q} {input.pairwise:q} > {output.combined:q}
        """


rule clean_jsonl:
    """Filter a JSONL through scripts/clean_lib.py thresholds.

    Operates on any of {forwards,pairwise,combined}_{train,test}.jsonl.
    Emits the filtered JSONL alongside a per-file manifest CSV with drop
    counts by reason (when clean.emit_manifest is true).
    """
    wildcard_constraints:
        kind = "forwards|pairwise|combined",
        split = "train|test",
    input:
        jsonl = "results/{analysis}/jsonl/{kind}_{split}.jsonl",
    output:
        jsonl = "results/{analysis}/jsonl_clean/{kind}_{split}.jsonl",
        manifest = "results/{analysis}/jsonl_clean/{kind}_{split}.manifest.csv",
    params:
        py = lambda wc: _ds_setting("clean", "python", wc.analysis, "python"),
        max_hunk_len = lambda wc: _ds_setting("clean", "max_hunk_len", wc.analysis, 100),
        gap_allele_frac = lambda wc: _ds_setting("clean", "gap_allele_frac", wc.analysis, 0.7),
        ref_gap_frac = lambda wc: _ds_setting("clean", "ref_gap_frac", wc.analysis, 0.3),
        mut_density = lambda wc: _ds_setting("clean", "mut_density", wc.analysis, 0.5),
        parallel = lambda wc: _ds_setting("clean", "parallel", wc.analysis, 8),
        chunk_size = lambda wc: _ds_setting("clean", "chunk_size", wc.analysis, 2000),
        emit_manifest = lambda wc: bool(_ds_setting("clean", "emit_manifest", wc.analysis, True)),
    shell:
        """
        # Quote the command substitution so paths with spaces survive word splitting.
        mkdir -p "$(dirname {output.jsonl:q})"
        manifest_arg=""
        if [ "{params.emit_manifest}" = "True" ]; then
            manifest_arg="--manifest {output.manifest:q}"
        else
            # Always touch the manifest path so Snakemake sees the declared output.
            : > {output.manifest:q}
        fi
        {params.py} scripts/run_clean.py \
            --input {input.jsonl:q} \
            --output {output.jsonl:q} \
            --max-hunk-len {params.max_hunk_len} \
            --gap-allele-frac {params.gap_allele_frac} \
            --ref-gap-frac {params.ref_gap_frac} \
            --mut-density {params.mut_density} \
            --parallel {params.parallel} \
            --chunk-size {params.chunk_size} \
            $manifest_arg
        """


# ============================================================================

rule upload:
input:
trajectory_targets(ANALYSES)
29 changes: 29 additions & 0 deletions defaults/config.yaml
@@ -1,4 +1,33 @@
s3_prefix: trajectories
trajectory_mode: both

# Downstream FASTA -> json_sdiff JSONL conversion.
# Read by `rule jsonl_mode_split` (built via the `jsonl` target); per-dataset overrides go under analysis.{name}.jsonl.
jsonl:
  # Python interpreter for scripts/fasta_to_jsonl.py (needs numpy).
  # `python` resolves to whatever is on PATH; when running via `uv run snakemake`
  # this is the project venv's interpreter, which has numpy installed.
  python: python
  # Hand-off to fasta_to_jsonl.py:
  max_raw_len: 14000
  max_len: "" # empty -> no length cap (matches fasta_to_jsonl.py default)
  context_size: 16
  noprompt: true
  # Per-shard parallelism inside one (analysis, mode, split) tuple.
  parallel: 16

# Downstream JSONL quality filter.
# Read by `rule clean_jsonl` (built via the `clean` target); per-dataset overrides go under analysis.{name}.clean.
clean:
  # Python interpreter used to run scripts/run_clean.py.
  python: python
  # Thresholds (match clean_lib.py defaults).
  max_hunk_len: 100
  gap_allele_frac: 0.7
  ref_gap_frac: 0.3
  mut_density: 0.5
  parallel: 8
  chunk_size: 2000
  emit_manifest: true

analysis: {}
24 changes: 24 additions & 0 deletions pyproject.toml
@@ -0,0 +1,24 @@
[project]
name = "trajectories"
version = "0.1.0"
description = "Provision evolutionary sequence trajectories from Nextstrain trees."
readme = "README.md"
license = { file = "LICENSE.md" }
requires-python = ">=3.10"
dependencies = [
    "biopython",
    "boto3",
    "nextstrain-augur",
    "numpy",
    "pulp<3.0",
    "requests",
    "snakemake",
    "tqdm",
    "zstandard",
]

# This project ships pipeline scripts invoked by the Snakefile, not an importable
# package — uv should manage the dependency env without trying to build/install
# the project itself.
[tool.uv]
package = false