Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/hla_algorithm/hla_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __init__(
@classmethod
def use_config(
cls,
standards_path: Optional[str] = None,
frequencies_path: Optional[str] = None,
standards_path: str | Path | None = None,
frequencies_path: str | Path | None = None,
) -> "HLAAlgorithm":
"""
An alternate constructor that accepts file paths for the configuration.
Expand All @@ -90,11 +90,11 @@ def use_config(
frequencies: Optional[dict[HLA_LOCUS, dict[HLAProteinPair, int]]] = None

if standards_path is not None:
with open(standards_path) as f:
with Path(standards_path).open() as f:
processed_stds = cls.read_hla_standards(f)

if frequencies_path is not None:
with open(frequencies_path) as f:
with Path(frequencies_path).open() as f:
frequencies = cls.read_hla_frequencies(f)

return cls(processed_stds, frequencies)
Expand Down Expand Up @@ -138,9 +138,9 @@ def load_default_hla_standards() -> LoadedStandards:
:return: List of known HLA standards
:rtype: list[HLAStandard]
"""
with open(
with (
HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_standards.yaml"
) as standards_file:
).open() as standards_file:
return HLAAlgorithm.read_hla_standards(standards_file)

FREQUENCY_LOCUS_COLUMNS: dict[HLA_LOCUS, tuple[str, str]] = {
Expand Down Expand Up @@ -202,7 +202,7 @@ def load_default_hla_frequencies() -> dict[HLA_LOCUS, dict[HLAProteinPair, int]]
:rtype: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
"""
hla_freqs: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
with open(HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_frequencies.csv") as f:
with (HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_frequencies.csv").open() as f:
hla_freqs = HLAAlgorithm.read_hla_frequencies(f)
return hla_freqs

Expand Down
12 changes: 9 additions & 3 deletions src/hla_algorithm/interpret_from_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import argparse
import json
import logging
import sys
from pathlib import Path

from .hla_algorithm import HLAAlgorithm
from .interpret_from_json_lib import HLAInput, HLAResult
Expand All @@ -17,14 +19,18 @@ def main():
)
parser.add_argument(
"infile",
type=argparse.FileType("r"),
type=str,
help='Input file containing the JSON input (use "-" to read from stdin)',
)
args: argparse.Namespace = parser.parse_args()

hla_input_str: str = ""
with args.infile:
for line in args.infile:
if args.infile == "-":
input_file = sys.stdin
else:
input_file = Path(args.infile).open()
with input_file:
for line in input_file:
hla_input_str += f"{line}\n"

hla_input: HLAInput = HLAInput(**json.loads(hla_input_str))
Expand Down
16 changes: 9 additions & 7 deletions src/hla_algorithm/interpret_from_json_lib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional
from pathlib import Path

from pydantic import BaseModel, Field

Expand All @@ -24,11 +24,11 @@

class HLAInput(BaseModel):
seq1: str
seq2: Optional[str]
seq2: str | None
locus: HLA_LOCUS
threshold: Optional[int] = None
hla_std_path: Optional[str] = None
hla_freq_path: Optional[str] = None
threshold: int | None = None
hla_std_path: Path | None = None
hla_freq_path: Path | None = None

def check_sequences(self) -> list[str]:
errors: list[str] = []
Expand Down Expand Up @@ -113,7 +113,7 @@ class HLAResult(BaseModel):
alleles_version: str = ""
alleles_last_updated: datetime = Field(default_factory=datetime.now)
b5701: bool = False
dist_b5701: Optional[int] = None
dist_b5701: int | None = None
errors: list[str] = Field(default_factory=list)
all_mismatches: dict[str, HLAMatchAdaptor] = Field(default_factory=dict)

Expand Down Expand Up @@ -144,7 +144,9 @@ def build_from_interpretation(

return HLAResult(
seqs=seqs,
alleles_all=[f"{x[0]} - {x[1]}" for x in sort_allele_pairs(aps.allele_pairs)],
alleles_all=[
f"{x[0]} - {x[1]}" for x in sort_allele_pairs(aps.allele_pairs)
],
alleles_clean=alleles_clean,
alleles_for_mismatches=f"{rep_ap[0]} - {rep_ap[1]}",
mismatches=[str(x) for x in match_details.mismatches],
Expand Down
15 changes: 8 additions & 7 deletions src/hla_algorithm/reformat_old_alleles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import csv
import logging
from datetime import datetime
from pathlib import Path
from typing import cast

import yaml

from .utils import (
GroupedAllele,
HLA_LOCUS,
GroupedAllele,
HLARawStandard,
StoredHLAStandards,
group_identical_alleles,
Expand All @@ -28,22 +29,22 @@ def main():
parser.add_argument(
"a_standards",
help="CSV file containing all HLA-A alleles",
type=str,
type=Path,
)
parser.add_argument(
"b_standards",
help="CSV file containing all HLA-B alleles",
type=str,
type=Path,
)
parser.add_argument(
"c_standards",
help="CSV file containing all HLA-C alleles",
type=str,
type=Path,
)
parser.add_argument(
"--output",
help="filename to store the reformatted standards in YAML",
type=str,
type=Path,
default="reformatted_hla_standards.yaml",
)
parser.add_argument(
Expand Down Expand Up @@ -84,7 +85,7 @@ def main():
grouped_alleles: dict[HLA_LOCUS, list[GroupedAllele]] = {"A": [], "B": [], "C": []}
for locus in ("A", "B", "C"):
logger.info(f"Grouping HLA-{locus} alleles....")
with open(input_filenames_by_locus[locus]) as f:
with input_filenames_by_locus[locus].open() as f:
standards_csv: csv.DictReader = csv.DictReader(
f,
fieldnames=("allele", "exon2", "exon3"),
Expand Down Expand Up @@ -114,7 +115,7 @@ def main():
)

logger.info(f"Writing HLA standards to {args.output}....")
with open(args.output, "w") as f:
with args.output.open("w") as f:
yaml.safe_dump(standards_for_saving.model_dump(), f)

logger.info("Done.")
Expand Down
18 changes: 8 additions & 10 deletions src/hla_algorithm/update_alleles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import Final, Optional, TypedDict, cast

import Bio
Expand Down Expand Up @@ -215,13 +216,13 @@ def main():
parser.add_argument(
"--output",
help="filename to store the unreduced standards (YAML format)",
type=str,
type=Path,
default="hla_standards.yaml",
)
parser.add_argument(
"--checksum",
help="filename to store the MD5 checksum of the retrieved data in",
type=str,
type=Path,
default="hla_nuc.fasta.checksum.txt",
)
parser.add_argument(
Expand All @@ -242,8 +243,7 @@ def main():
parser.add_argument(
"--dump_full_fasta_to",
help="if specified, the full original FASTA file is dumped to the specified path",
type=str,
default="",
type=Path,
)
parser.add_argument(
"--standard_report_interval",
Expand Down Expand Up @@ -278,16 +278,14 @@ def main():
f"{retrieval_datetime}."
)

if args.dump_full_fasta_to != "":
if args.dump_full_fasta_to is not None:
logger.info(f"Dumping the full FASTA file to {args.dump_full_fasta_to}.")
with open(args.dump_full_fasta_to, "w") as f:
f.write(alleles_str)
args.dump_full_fasta_to.write_text(alleles_str)

# Compute the checksum.
md5_calc = hashlib.md5()
md5_calc.update(alleles_str.encode())
with open(args.checksum, "w") as f:
f.write(f"{md5_calc.hexdigest()} {HLA_ALLELES_FILENAME}\n")
args.checksum.write_text(f"{md5_calc.hexdigest()} {HLA_ALLELES_FILENAME}\n")

raw_standards: dict[HLA_LOCUS, list[HLARawStandard]] = collate_standards(
list(Bio.SeqIO.parse(StringIO(alleles_str), "fasta")),
Expand All @@ -313,7 +311,7 @@ def main():

# First, prepare the unreduced YAML output.
logger.info(f"Writing HLA standards to {args.output}....")
with open(args.output, "w") as f:
with args.output.open("w") as f:
yaml.safe_dump(standards_for_saving.model_dump(), f)

logger.info("Done.")
Expand Down
2 changes: 1 addition & 1 deletion src/scripts/measure_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def main():
}
)

with open(args.output_csv, "w") as f:
with args.output_csv.open("w") as f:
resource_summary_writer = csv.DictWriter(
f,
fieldnames=("sample_name", "wall_clock_time", "max_memory_usage_kb"),
Expand Down
Loading