Skip to content
204 changes: 204 additions & 0 deletions malariagen_data/anoph/cnv_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import warnings
from bisect import bisect_left, bisect_right
from typing import Dict, List, Optional, Tuple, Union

import dask
import dask.array as da
import numba
import numpy as np
import pandas as pd
import xarray as xr
Expand Down Expand Up @@ -689,6 +692,162 @@ def cnv_discordant_read_calls(

return ds

@_check_types
@doc(
    summary="""
Compute modal copy number by gene, from HMM data.
""",
    returns="""
A dataset of modal copy number per gene and associated data.
""",
)
def gene_cnv(
    self,
    region: base_params.regions,
    sample_sets: Optional[base_params.sample_sets] = None,
    sample_query: Optional[base_params.sample_query] = None,
    sample_query_options: Optional[base_params.sample_query_options] = None,
    max_coverage_variance: cnv_params.max_coverage_variance = cnv_params.max_coverage_variance_default,
    chunks: base_params.chunks = base_params.native_chunks,
    inline_array: base_params.inline_array = base_params.inline_array_default,
) -> xr.Dataset:
    # NOTE: the user-facing docstring is produced by the @doc decorator above,
    # so this function deliberately carries comments only.

    # Parse the region parameter into a list of Region objects, then drop the
    # raw value so that only the parsed form can be used below.
    regions: List[Region] = _parse_multi_region(self, region)
    del region

    # Every region is processed with the same sample selection and data
    # access settings; only the region itself varies between calls.
    shared_kwargs = dict(
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        max_coverage_variance=max_coverage_variance,
        chunks=chunks,
        inline_array=inline_array,
    )

    # One dataset per region, in the order the regions were given.
    per_region = [self._gene_cnv(region=r, **shared_kwargs) for r in regions]

    # Join the per-region results along the "genes" dimension.
    return _simple_xarray_concat(per_region, dim="genes")

def _gene_cnv(
    self,
    *,
    region,
    sample_sets,
    sample_query,
    sample_query_options,
    max_coverage_variance,
    chunks,
    inline_array,
):
    """Compute modal copy number per gene and sample for a single region.

    Genes are taken from the genome annotations for ``region``; CNV HMM
    copy number calls are then loaded for the span running from the smallest
    gene start to the largest gene end, and for each gene the most frequent
    copy number over the HMM windows overlapping it is computed per sample.

    Parameters
    ----------
    region
        A ``Region`` object; anything else is rejected.
    sample_sets, sample_query, max_coverage_variance, chunks, inline_array
        Passed through unchanged to ``self.cnv_hmm``.
    sample_query_options
        Passed to ``self.cnv_hmm`` and also used as keyword arguments for
        the gene-type query; ``None`` (or any falsy value) is treated as
        an empty dict.

    Returns
    -------
    xarray.Dataset
        Dataset with dimensions ``genes`` and ``samples`` holding gene
        annotations, the number of windows per gene, ``CN_mode`` and
        ``CN_mode_count``, plus per-sample coverage variance fields copied
        from the HMM dataset.

    Raises
    ------
    TypeError
        If ``region`` is not a ``Region`` instance.
    ValueError
        If the region contains no features of the configured gene type.
    """
    # Only an already-parsed Region is accepted here.
    if not isinstance(region, Region):
        raise TypeError(
            f"Expected region to be a Region object, "
            f"got {type(region).__name__}: {region!r}"
        )

    # Pull the annotations for the region and keep only gene records.
    # NOTE(review): sample_query_options is forwarded to this genome-features
    # query as well as to cnv_hmm below -- confirm that this is intended.
    features = self.genome_features(region=region)
    sample_query_options = sample_query_options or {}
    df_genes = features.query(
        f"type == '{self._gff_gene_type}'", **sample_query_options
    )

    # Without any genes there is nothing to summarise, so fail loudly.
    n_genes = len(df_genes)
    if n_genes == 0:
        raise ValueError(
            f"No genes of type '{self._gff_gene_type}' found in region {region}. "
            f"Cannot compute gene CNV without genes."
        )

    # Widen or narrow the CNV request so that it spans every selected gene.
    gene_starts = df_genes["start"]
    gene_ends = df_genes["end"]
    cnv_region = Region(region.contig, gene_starts.min(), gene_ends.max())

    # Open the HMM calls for that span.
    ds_hmm = self.cnv_hmm(
        region=cnv_region,
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        max_coverage_variance=max_coverage_variance,
        chunks=chunks,
        inline_array=inline_array,
    )

    # Window coordinates and copy number calls, loaded in a single compute
    # so the progress display covers all three arrays together.
    lazy_pos = ds_hmm["variant_position"].data
    lazy_end = ds_hmm["variant_end"].data
    lazy_cn = ds_hmm["call_CN"].data.astype("int8", casting="same_kind")
    with self._dask_progress(desc="Load CNV HMM data"):
        pos, end, cn = dask.compute(lazy_pos, lazy_end, lazy_cn)

    # Per-gene accumulators, one entry appended per gene.
    windows, modes, counts = [], [], []

    for gene in self._progress(
        df_genes.itertuples(),
        desc="Compute modal gene copy number",
        total=n_genes,
    ):
        # Windows overlapping the gene: from the first window whose end is
        # at or after the gene start, up to (excluding) the first window
        # whose start lies beyond the gene end.
        first = bisect_left(end, gene.start)
        stop = bisect_right(pos, gene.end)
        windows.append(stop - first)

        # Modal copy number, and how many windows support it, per sample.
        gene_modes, gene_counts = _cn_mode(cn[first:stop], vmax=12)
        modes.append(gene_modes)
        counts.append(gene_counts)

    # Stack the per-gene rows into (genes, samples) arrays.
    windows = np.array(windows)
    modes = np.vstack(modes)
    counts = np.vstack(counts)

    # Assemble the output dataset.
    coords = {
        "gene_id": (["genes"], df_genes["ID"].values),
        "sample_id": (["samples"], ds_hmm["sample_id"].values),
    }
    data_vars = {
        "gene_contig": (["genes"], df_genes["contig"].values),
        "gene_start": (["genes"], gene_starts.values),
        "gene_end": (["genes"], gene_ends.values),
        "gene_windows": (["genes"], windows),
        "gene_name": (
            ["genes"],
            df_genes[self._gff_gene_name_attribute].values,
        ),
        "gene_strand": (["genes"], df_genes["strand"].values),
        "gene_description": (["genes"], df_genes["description"].values),
        "CN_mode": (["genes", "samples"], modes),
        "CN_mode_count": (["genes", "samples"], counts),
        "sample_coverage_variance": (
            ["samples"],
            ds_hmm["sample_coverage_variance"].values,
        ),
        "sample_is_high_variance": (
            ["samples"],
            ds_hmm["sample_is_high_variance"].values,
        ),
    }
    return xr.Dataset(coords=coords, data_vars=data_vars)

@_check_types
@doc(
summary="Plot CNV HMM data for a single sample, using bokeh.",
Expand Down Expand Up @@ -1095,3 +1254,48 @@ def plot_cnv_hmm_heatmap(
if show:
bkplt.show(fig)
return fig


@numba.njit("Tuple((int8, int64))(int8[:], int8)")
def _cn_mode_1d(a, vmax):
    """Return the most frequent value in ``a`` and its number of occurrences.

    Only values within ``[0, vmax]`` are counted. If no value falls inside
    that range (including when ``a`` is empty) the result is ``(-1, 0)``.
    When several values share the highest count, the smallest one is
    returned.
    """
    # Occurrence tally, indexed directly by value.
    tally = np.zeros(vmax + 1, dtype=numba.int64)

    # Best answer seen so far.
    mode = numba.int8(-1)
    mode_count = numba.int64(0)

    for v in a:
        # Values outside [0, vmax] do not contribute.
        if v < 0 or v > vmax:
            continue

        c = tally[v] + 1
        tally[v] = c

        # Take over on a strictly higher count; on an equal count prefer the
        # lower value, for consistency with scipy.stats.
        if c > mode_count or (c == mode_count and v < mode):
            mode = v
            mode_count = c

    return mode, mode_count


@numba.njit("Tuple((int8[:], int64[:]))(int8[:, :], int8)")
def _cn_mode(a, vmax):
    """Compute the mode of every column of the 2-D array ``a``.

    Returns a pair of 1-D arrays with one entry per column: the modal value
    as computed by ``_cn_mode_1d``, and the number of times it occurs.
    """
    n_cols = a.shape[1]

    # One output slot per column.
    modes = np.zeros(n_cols, dtype=numba.int8)
    counts = np.zeros(n_cols, dtype=numba.int64)

    # Each column is summarised independently of the others.
    for j in range(n_cols):
        modes[j], counts[j] = _cn_mode_1d(a[:, j], vmax)

    return modes, counts
Loading
Loading