Skip to content
129 changes: 119 additions & 10 deletions malariagen_data/anoph/heterozygosity.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,108 @@ def _sample_count_het(

return sample_id, sample_set, windows, counts

def cohort_count_het(
    self,
    region: Region,
    df_cohort_samples: pd.DataFrame,
    sample_sets: Optional[base_params.sample_sets],
    window_size: het_params.window_size,
    site_mask: Optional[base_params.site_mask],
    chunks: base_params.chunks,
    inline_array: base_params.inline_array,
):
    """Compute windowed heterozygosity counts for multiple samples in a cohort.

    This method efficiently computes heterozygosity for all samples by loading
    SNP data once and computing across all samples, rather than calling
    snp_calls() repeatedly for each sample. This vectorized approach provides
    substantial performance improvements for large cohorts.

    Parameters
    ----------
    region : Region
        Genome region to analyze.
    df_cohort_samples : pd.DataFrame
        Sample metadata dataframe with at least 'sample_id' column.
    sample_sets : str, optional
        Sample set identifier(s).
    window_size : int
        Size of sliding windows (in number of sites) for heterozygosity
        computation.
    site_mask : str, optional
        Site mask to apply.
    chunks : str or int, dict
        Chunk size for dask arrays.
    inline_array : bool
        Whether to inline arrays.

    Returns
    -------
    dict
        Mapping from sample_id to (windows, counts) tuple, where:
        - windows: array of shape (n_windows, 2) with [start, stop] positions
        - counts: array of shape (n_windows,) with heterozygous site counts
          per window

    Raises
    ------
    ValueError
        If the region contains fewer sites than `window_size`.
    """
    debug = self._log.debug

    # Extract sample IDs from the cohort dataframe. This ordering determines
    # the ordering of the samples dimension after the .sel() below, so a
    # positional index into the genotype array is valid for each sample.
    sample_ids = df_cohort_samples["sample_id"].values

    debug("access SNPs for all cohort samples")
    # Load SNP data once for all samples in the cohort - this single call
    # (rather than one call per sample) is the key performance win.
    ds_snps = self.snp_calls(
        region=region,
        sample_sets=sample_sets,
        site_mask=site_mask,
        chunks=chunks,
        inline_array=inline_array,
    )

    # Subset to the cohort samples; after this, position i along the samples
    # dimension corresponds to sample_ids[i].
    ds_snps = ds_snps.set_index(samples="sample_id").sel(samples=sample_ids)

    # SNP positions (shared by all samples).
    pos = ds_snps["variant_position"].values

    # Guard against window_size exceeding the number of available sites,
    # which would yield an empty moving statistic.
    if pos.shape[0] < window_size:
        raise ValueError(
            f"Not enough sites ({pos.shape[0]}) for window size "
            f"({window_size}). Please reduce the window size or "
            f"use different site selection criteria."
        )

    # Compute window coordinates once (identical for all samples).
    windows = allel.moving_statistic(
        values=pos,
        statistic=lambda x: [x[0], x[-1]],
        size=window_size,
    )

    # Access genotypes lazily for all samples.
    gt_data = ds_snps["call_genotype"].data

    # Compute windowed heterozygosity for each sample and collect results.
    results = {}
    for sample_idx, sample_id in enumerate(sample_ids):
        # Slice out a single sample before computing, to avoid
        # materializing the full (variants, samples) array in memory.
        debug(f"Compute heterozygous genotypes for sample {sample_id}")
        gt_sample = allel.GenotypeDaskVector(gt_data[:, sample_idx, :])
        with self._dask_progress(desc="Compute heterozygous genotypes"):
            is_het_sample = gt_sample.is_het().compute()

        # Count heterozygous calls within each moving window.
        counts = allel.moving_statistic(
            values=is_het_sample,
            statistic=np.sum,
            size=window_size,
        )

        results[sample_id] = (windows, counts)

    return results

@property
def _roh_hmm_cache_name(self):
return "roh_hmm_v1"
Expand Down Expand Up @@ -816,18 +918,25 @@ def cohort_heterozygosity(
)
n_samples = len(df_cohort_samples)

# Compute heterozygosity for each sample and take the mean.
# Compute heterozygosity for all samples in the cohort using cohort_count_het().
# This public method loads SNP data once and computes across all samples,
# providing substantial speedup over sequential per-sample processing.
cohort_het_results = self.cohort_count_het(
region=region_prepped,
df_cohort_samples=df_cohort_samples,
sample_sets=sample_sets,
window_size=window_size,
site_mask=site_mask,
chunks=chunks,
inline_array=inline_array,
)

# Compute per-sample means and aggregate.
het_values = []
for sample_id in df_cohort_samples["sample_id"]:
df_het = self.sample_count_het(
sample=sample_id,
region=region_prepped,
window_size=window_size,
site_mask=site_mask,
chunks=chunks,
inline_array=inline_array,
)
het_values.append(df_het["heterozygosity"].mean())
_, counts = cohort_het_results[sample_id]
het_mean = np.mean(counts / window_size)
het_values.append(het_mean)

results.append(
{
Expand Down
73 changes: 73 additions & 0 deletions tests/anoph/test_heterozygosity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random

import bokeh.models
import numpy as np
import pandas as pd
import pytest
from pytest_cases import parametrize_with_cases
Expand Down Expand Up @@ -273,3 +274,75 @@ def test_cohort_heterozygosity(fixture, api: AnophelesHetAnalysis):
assert (df["n_samples"] > 0).all()
assert (df["mean_heterozygosity"] >= 0).all()
assert (df["mean_heterozygosity"] <= 1).all()


@parametrize_with_cases("fixture,api", cases=".")
def test_cohort_count_het_regression(fixture, api: AnophelesHetAnalysis):
    """Regression test: cohort method produces identical results to sequential method.

    Verifies that cohort_count_het() yields heterozygosity values that are
    numerically identical to computing each sample one at a time with the
    sequential per-sample approach.
    """
    from malariagen_data.util import _parse_single_region
    from malariagen_data.anoph import base_params

    # Set up test parameters.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    sample_set = random.choice(all_sample_sets)
    region = random.choice(api.contigs)
    window_size = 20_000

    # Pick a small, non-trivial cohort; fixed random_state keeps the sample
    # selection reproducible across runs.
    df_samples = api.sample_metadata(sample_sets=sample_set)
    n_cohort = min(3, len(df_samples))
    df_cohort_samples = df_samples.sample(
        n=n_cohort, random_state=0
    ).reset_index(drop=True)

    # Parse the region once so both methods see an identical region.
    region_prepped = _parse_single_region(api, region)

    # Method 1: use vectorized method
    cohort_results = api.cohort_count_het(
        region=region_prepped,
        df_cohort_samples=df_cohort_samples,
        sample_sets=sample_set,
        window_size=window_size,
        site_mask=api._default_site_mask,
        chunks=base_params.native_chunks,
        inline_array=True,
    )

    # Method 2: traditional sequential per-sample computation, for comparison.
    sequential_results = {}
    for sid in df_cohort_samples["sample_id"]:
        df_het = api.sample_count_het(
            sample=sid,
            region=region_prepped,
            window_size=window_size,
            site_mask=api._default_site_mask,
            sample_set=sample_set,
        )
        sequential_results[sid] = df_het["heterozygosity"].values

    # Compare the two methods sample by sample.
    for sample_id in df_cohort_samples["sample_id"]:
        windows, counts = cohort_results[sample_id]

        # Window counts divided by window size gives heterozygosity.
        cohort_het = counts / window_size
        sequential_het = sequential_results[sample_id]

        # Shapes must agree before values can be compared elementwise.
        assert len(cohort_het) == len(sequential_het), (
            f"Shape mismatch for sample {sample_id}: "
            f"cohort={len(cohort_het)}, sequential={len(sequential_het)}"
        )

        # Values must be numerically identical (within floating point precision).
        assert np.allclose(cohort_het, sequential_het, rtol=1e-10), (
            f"Values differ for sample {sample_id}. "
            f"Max difference: {np.max(np.abs(cohort_het - sequential_het))}"
        )
Loading