Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/grelu/io/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,18 @@ def read_sizes(genome: str = "hg38") -> pd.DataFrame:

Args:
genome: Either a genome name to load from genomepy,
or the path to a chromosome sizes file.
the path to a chromosome sizes file (tab-separated, no header),
or the path to a local FASTA file (expects a .sizes file alongside it).

Returns:
A dataframe containing columns "chrom" (chromosome names)
and "size" (chromosome size).
"""
# Get file path
# If the argument is an existing file, treat it as a sizes file directly
if os.path.isfile(genome):
return pd.read_table(
genome, header=None, names=["chrom", "size"], dtype={"chrom": str, "size": int}
)
genome = get_genome(genome).sizes_file
return pd.read_table(
genome, header=None, names=["chrom", "size"], dtype={"chrom": str, "size": int}
Expand Down
1 change: 1 addition & 0 deletions src/grelu/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from grelu.lightning import LightningModel
from grelu.resources.utils import get_meme_file_path, get_blacklist_file
from grelu.resources.wandb import DEFAULT_WANDB_HOST, get_artifact


class DeprecationError(Exception):
Expand Down
19 changes: 12 additions & 7 deletions src/grelu/sequence/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from pandas.api.types import is_categorical_dtype, is_integer_dtype, is_string_dtype
from torch import Tensor

Expand All @@ -28,6 +29,12 @@

INDEX_TO_BASE_HASH: Dict[int, str] = {i: base for i, base in enumerate(ALLOWED_BASES)}

# Vectorized ASCII lookup table: byte value → base index (unknown chars → N=4)
_BASE_LUT = np.full(256, 4, dtype=np.int8)
for _i, _b in enumerate(ALLOWED_BASES):
_BASE_LUT[ord(_b)] = _i
_BASE_LUT[ord(_b.lower())] = _i


def check_intervals(df: pd.DataFrame) -> bool:
"""
Expand Down Expand Up @@ -197,8 +204,9 @@ def intervals_to_strings(

else:
# Extract sequences for multiple intervals
tqdm.pandas(desc="Fetching sequences")
if "strand" in intervals.columns:
seqs = intervals.apply(
seqs = intervals.progress_apply(
lambda row: str(
genome.get_seq(
row["chrom"],
Expand All @@ -210,7 +218,7 @@ def intervals_to_strings(
axis=1,
).tolist()
else:
seqs = intervals.apply(
seqs = intervals.progress_apply(
lambda row: str(
genome.get_seq(row["chrom"], row["start"] + 1, row["end"])
).upper(),
Expand Down Expand Up @@ -241,7 +249,7 @@ def strings_to_indices(

# Convert a single sequence
if isinstance(strings, str):
arr = np.array([BASE_TO_INDEX_HASH[base] for base in strings], dtype=np.int8)
arr = _BASE_LUT[np.frombuffer(strings.encode("ascii"), dtype=np.uint8)]
if add_batch_axis:
return np.expand_dims(arr, 0)
else:
Expand All @@ -253,10 +261,7 @@ def strings_to_indices(
strings
), "All input sequences must have the same length."
return np.stack(
[
np.array([BASE_TO_INDEX_HASH[base] for base in string], dtype=np.int8)
for string in strings
]
[_BASE_LUT[np.frombuffer(s.encode("ascii"), dtype=np.uint8)] for s in strings]
)


Expand Down