Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ version: 2
build:
os: ubuntu-24.04
tools:
python: "mambaforge-22.9"
python: "miniconda-latest"
jobs:
pre_build:
# Since we have to use an older nf-schema to support older nextflows, we need to update the nextflow_schema file
# for it to output our params doc by replacing defs (needed for older nf-schema) to $defs (newer format)
- conda list
- bash ./scripts/pre_docs_install.sh
- cat nextflow_schema.json
- cat docs/params_doc.md
Expand Down
296 changes: 210 additions & 86 deletions bin/combine_annotations.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,84 @@
#!/usr/bin/env python
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
from skbio.io import read as read_sequence
import os
from pathlib import Path
from utils.logger import get_logger

import click
import polars as pl
from skbio.io import read as read_sequence

from utils.logger import get_logger

FASTA_COLUMN = os.getenv("FASTA_COLUMN", "input_fasta")

logger = get_logger(filename=Path(__file__).stem)


def read_and_preprocess(
    path: Path, input_fasta: str, seperator=","
) -> pl.LazyFrame | None:
    """Lazily load one annotation table and tag every row with its source fasta.

    We design input fastas from intermediate steps to be named like:
    "input_fasta___some_information_annotation_file.tsv".

    Parameters
    ----------
    path : Path
        Annotation table to scan.
    input_fasta : str
        Value written into the ``FASTA_COLUMN`` column for every row.
    seperator : str
        Field separator passed to the CSV scanner (sic — name kept for
        backward compatibility with existing callers).

    Returns
    -------
    pl.LazyFrame | None
        The lazy frame with the fasta column appended, or ``None`` when the
        file cannot be parsed (the error is logged rather than raised, so one
        bad file does not abort the whole combine step).
    """
    try:
        lf = pl.scan_csv(path, separator=seperator).with_columns(
            pl.lit(input_fasta).alias(FASTA_COLUMN)
        )
        # Validate here so we can log file-specific parse failures before the final collect.
        lf.collect_schema()
        return lf
    except Exception as e:
        logger.error(f"Error loading DataFrame for input_fasta {input_fasta}: {str(e)}")
        return None


def input_fasta_from_filepath(file_path: Path, splitter="___") -> str:
    """Recover the input-fasta name encoded in an annotation file name.

    File stems are built as "<input_fasta><splitter><extra info>"; everything
    before the first occurrence of *splitter* is the fasta name. When the
    splitter is absent the whole stem is returned.

    Parameters
    ----------
    file_path : Path
        Annotation file path whose stem encodes the fasta name.
    splitter : str
        Delimiter between the fasta name and the rest of the stem
        (default "___"; dbCAN outputs use "_dbCAN").
    """
    return file_path.stem.split(splitter)[0]


def assign_rank(row):
    """Classify one annotation row into confidence rank "A" (best) .. "E".

    Thresholds: kegg/uniref bitScore > 350 gives A/B, either > 60 gives C,
    and any motif-database hit (pfam/dbcan/merops) > 60 gives D.
    Missing scores count as 0.

    NOTE(review): superseded by the expression-based assign_rank_expr below;
    kept for row-wise callers.
    """
    kegg = row.get("kegg_bitScore", 0)
    uniref = row.get("uniref_bitScore", 0)
    if kegg > 350:
        return "A"
    if uniref > 350:
        return "B"
    if kegg > 60 or uniref > 60:
        return "C"
    motif_dbs = ("pfam", "dbcan", "merops")
    if any(row.get(db + "_bitScore", 0) > 60 for db in motif_dbs):
        return "D"
    return "E"
def bit_score_expr(column_name: str) -> pl.Expr:
    """Return an expression for *column_name* with null scores treated as 0."""
    missing_as_zero = pl.col(column_name).fill_null(0)
    return missing_as_zero


def convert_bit_scores_to_numeric(df):
    """Coerce every *_bitScore column to numeric, mapping bad values to NaN.

    Mutates *df* in place and also returns it for convenience.
    NOTE(review): this pandas version is shadowed by the polars
    implementation of the same name defined later in the file.
    """
    score_columns = [c for c in df.columns if "_bitScore" in c]
    for column in score_columns:
        df[column] = pd.to_numeric(df[column], errors="coerce")
    return df
def assign_rank_expr(columns: list[str]) -> pl.Expr:
    """Build a polars expression producing the "rank" column ("A" best .. "E").

    Mirrors assign_rank(): kegg/uniref bitScore > 350 gives A/B, either > 60
    gives C, any motif database (pfam/dbcan/merops) > 60 gives D, else E.
    Score columns missing from *columns* are treated as constant 0 / False.
    """
    def score_or_zero(name: str) -> pl.Expr:
        # Absent columns contribute a literal 0 so the comparisons still work.
        return bit_score_expr(name) if name in columns else pl.lit(0)

    kegg = score_or_zero("kegg_bitScore")
    uniref = score_or_zero("uniref_bitScore")

    motif_candidates = [f"{db}_bitScore" for db in ("pfam", "dbcan", "merops")]
    motif_tests = [bit_score_expr(c) > 60 for c in motif_candidates if c in columns]
    has_motif_hit = pl.any_horizontal(motif_tests) if motif_tests else pl.lit(False)

    ranked = (
        pl.when(kegg > 350)
        .then(pl.lit("A"))
        .when(uniref > 350)
        .then(pl.lit("B"))
        .when((kegg > 60) | (uniref > 60))
        .then(pl.lit("C"))
        .when(has_motif_hit)
        .then(pl.lit("D"))
        .otherwise(pl.lit("E"))
    )
    return ranked.alias("rank")


def convert_bit_scores_to_numeric(lf: pl.LazyFrame) -> pl.LazyFrame:
    """Cast every *_bitScore column to Float64; unparsable values become null."""
    score_columns = [
        name for name in lf.collect_schema().names() if "_bitScore" in name
    ]
    if not score_columns:
        # Nothing to cast — return the frame untouched.
        return lf
    casts = [
        pl.col(name).cast(pl.Float64, strict=False).alias(name)
        for name in score_columns
    ]
    return lf.with_columns(casts)


def count_motifs(gene_faa, motif="(C..CH)", genes_faa_dict=None):
Expand Down Expand Up @@ -89,7 +122,16 @@ def set_gene_data(gene_faa, genes_faa_dict=None):
return genes_faa_dict


def organize_columns(df, special_columns=None):
def genes_dict_to_frame(genes_faa_dict: dict) -> pl.DataFrame:
    """Flatten ``{query_id: {column: value}}`` into a DataFrame.

    Each gene becomes one row carrying its query_id plus all per-gene values;
    an empty dict yields an empty frame with only the query_id column.
    """
    records = [
        {"query_id": query_id, **values}
        for query_id, values in genes_faa_dict.items()
    ]
    if not records:
        return pl.DataFrame(schema={"query_id": pl.String})
    return pl.DataFrame(records)


def organize_columns(df: pl.DataFrame, special_columns=None) -> pl.DataFrame:
if special_columns is None:
special_columns = []
base_columns = [
Expand All @@ -114,7 +156,7 @@ def organize_columns(df, special_columns=None):
if col not in base_columns + kegg_columns + special_columns
]

db_prefixes = set(col.split("_")[0] for col in other_columns)
db_prefixes = sorted(set(col.split("_")[0] for col in other_columns))
sorted_other_columns = []
for prefix in db_prefixes:
prefixed_columns = sorted(
Expand All @@ -126,78 +168,160 @@ def organize_columns(df, special_columns=None):
final_columns_order = (
base_columns + kegg_columns + sorted_other_columns + special_columns
)
return df[final_columns_order]
return df.select(final_columns_order)


@click.command()
@click.option("--annotations_dir", help="Directory of annotation files")
@click.option("--genes_dir", help="Directory genes faa file paths from prodigal")
@click.option(
    "--dbcan_dir",
    help="Directory of run_dbcan hmm_results.tsv and sub_hmm_results.tsv",
)
@click.option("--output", help="Output file path for the combined annotations.")
def combine_annotations(annotations_dir, genes_dir, dbcan_dir, output):
    """Combine annotation files with ranks and avoid duplicating specific columns.

    Reads per-database annotation tables, optional prodigal gene data, and
    optional dbCAN (family + subfamily) HMM results, full-joins them on
    query_id, deduplicates rows, assigns a confidence rank, and writes a
    single tab-separated table to *output*.
    """

    def coalesce_fasta(lf: pl.LazyFrame) -> pl.LazyFrame:
        # After a full join both sides may carry the fasta column; keep the
        # left value unless it is null/empty, then drop the right-hand copy.
        return lf.with_columns(
            pl.when(pl.col(FASTA_COLUMN).is_not_null() & (pl.col(FASTA_COLUMN) != ""))
            .then(pl.col(FASTA_COLUMN))
            .otherwise(pl.col(FASTA_COLUMN + "_right"))
            .alias(FASTA_COLUMN)
        ).drop(FASTA_COLUMN + "_right", strict=False)

    def scan_dbcan(paths) -> pl.LazyFrame:
        # dbCAN outputs are tab-separated, named "<fasta>_dbCAN*_hmm_results.tsv";
        # unreadable files are logged and skipped by read_and_preprocess.
        return pl.concat(
            [
                frame
                for frame in (
                    read_and_preprocess(
                        path,
                        input_fasta=input_fasta_from_filepath(path, splitter="_dbCAN"),
                        seperator="\t",
                    )
                    for path in paths
                )
                if frame is not None
            ],
            how="diagonal_relaxed",
        )

    annotations = sorted(Path(annotations_dir).glob("*")) if annotations_dir else []
    genes_faa = sorted(Path(genes_dir).glob("*")) if genes_dir else []
    dbcan_paths = (
        sorted(Path(dbcan_dir).glob("*dbCAN_hmm_results.tsv")) if dbcan_dir else []
    )
    dbcan_sub_paths = (
        sorted(Path(dbcan_dir).glob("*dbCANsub_hmm_results.tsv")) if dbcan_dir else []
    )

    annotation_frames = [
        frame
        for frame in (
            read_and_preprocess(path, input_fasta=input_fasta_from_filepath(path))
            for path in annotations
        )
        if frame is not None
    ]
    if annotation_frames:
        combined_data_lf = pl.concat(annotation_frames, how="diagonal_relaxed")
    else:
        # No annotation tables: start from an empty frame so joins still work.
        combined_data_lf = pl.LazyFrame(
            schema={"query_id": pl.String, FASTA_COLUMN: pl.String}
        )

    if genes_faa:
        genes_faa_dict = {}
        for gene_path in genes_faa:
            gene_path = str(gene_path)
            count_motifs(gene_path, "(C..CH)", genes_faa_dict=genes_faa_dict)
            set_gene_data(gene_path, genes_faa_dict)
        gene_lf = pl.LazyFrame(list(genes_faa_dict.values())).with_columns(
            query_id=pl.Series(genes_faa_dict.keys())
        )
        gene_lf_cols = gene_lf.collect_schema().names()
        # Gene-derived columns win: drop any same-named annotation columns first.
        columns = [
            col for col in gene_lf_cols if col not in (FASTA_COLUMN, "query_id")
        ]
        combined_data_lf = combined_data_lf.drop(columns, strict=False)
        # Use a full join so genes without hits remain in the output.
        combined_data_lf = combined_data_lf.join(
            gene_lf, how="full", on="query_id", coalesce=True
        )
        combined_data_lf = coalesce_fasta(combined_data_lf)

    if dbcan_paths:
        dbcan_lf = scan_dbcan(dbcan_paths).select(
            pl.col(FASTA_COLUMN),
            pl.col("Target Name").alias("query_id"),
            pl.col("HMM Name").str.strip_suffix(".hmm").alias("dbcan_id"),
            pl.col("i-Evalue").alias("dbcan_i_Evalue"),
        )
        combined_data_lf = combined_data_lf.join(
            dbcan_lf, how="full", on="query_id", coalesce=True
        )
        combined_data_lf = coalesce_fasta(combined_data_lf)

    if dbcan_sub_paths:
        dbcan_sub_lf = scan_dbcan(dbcan_sub_paths).select(
            pl.col(FASTA_COLUMN),
            pl.col("Target Name").alias("query_id"),
            pl.col("Subfam Name").alias("dbcan_sub_id"),
            pl.col("Subfam Composition").alias("dbcan_sub_composition"),
            pl.col("Subfam EC").alias("dbcan_sub_ec"),
            pl.col("Substrate").alias("dbcan_sub_substrate"),
            pl.col("i-Evalue").alias("dbcan_sub_i_Evalue"),
        )
        combined_data_lf = combined_data_lf.join(
            dbcan_sub_lf, how="full", on="query_id", coalesce=True
        )
        combined_data_lf = coalesce_fasta(combined_data_lf)

    combined_data_lf = convert_bit_scores_to_numeric(combined_data_lf)

    # Collapse duplicate (query_id, fasta) rows: keep the first non-null value,
    # except bin-level stats which take the max across rows.
    all_columns = combined_data_lf.collect_schema().names()
    aggregation_exprs = []
    for col in all_columns:
        if col in ["query_id", FASTA_COLUMN]:
            continue
        if col in ["Completeness", "Contamination", "taxonomy"]:
            aggregation_exprs.append(pl.col(col).max().alias(col))
        else:
            aggregation_exprs.append(pl.col(col).first(ignore_nulls=True).alias(col))
    combined_data_lf = combined_data_lf.group_by(["query_id", FASTA_COLUMN]).agg(
        aggregation_exprs
    )
    combined_data_lf = combined_data_lf.with_columns(
        assign_rank_expr(combined_data_lf.collect_schema().names())
    )
    combined_data = combined_data_lf.collect()

    special_columns = ["Completeness", "Contamination", "taxonomy"]
    special_columns = [col for col in special_columns if col in combined_data.columns]
    combined_data = organize_columns(combined_data, special_columns=special_columns)

    sort_columns = [
        col
        for col in [FASTA_COLUMN, "scaffold", "gene_number"]
        if col in combined_data.columns
    ]
    if sort_columns:
        combined_data = combined_data.sort(sort_columns)

    combined_data.write_csv(output, separator="\t")
    logger.info(f"Combined annotations saved to {output}, with corrected gene numbers.")


Expand Down
Loading
Loading