Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/decima/cli/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
@click.option("--logger", default="wandb", type=str, help="Logger.")
@click.option("--num-workers", default=16, type=int, help="Number of workers.")
@click.option("--seed", default=0, type=int, help="Random seed.")
@click.option("--checkpoint", default=None, type=str, help="Path to a checkpoint to resume training from.")
def cli_finetune(
name,
model,
Expand All @@ -63,6 +64,7 @@ def cli_finetune(
logger,
num_workers,
seed,
checkpoint,
):
"""Finetune the Decima model.

Expand Down Expand Up @@ -102,8 +104,11 @@ def cli_finetune(
)
val_dataset = HDF5Dataset(h5_file=h5_file, ad=ad, key="val", max_seq_shift=0)

if isinstance(device, str) and device.isdigit():
device = int(device)
if isinstance(device, str):
if "," in device:
device = [int(d) for d in device.split(",")]
elif device.isdigit():
device = int(device)

train_params = {
"batch_size": batch_size,
Expand Down Expand Up @@ -137,7 +142,7 @@ def cli_finetune(
run = wandb.init(project="decima", dir=name, name=name)

logger.info("Training")
model.train_on_dataset(train_dataset, val_dataset)
model.train_on_dataset(train_dataset, val_dataset, checkpoint_path=checkpoint)
train_dataset.close()
val_dataset.close()
if logger == "wandb":
Expand Down
75 changes: 52 additions & 23 deletions src/decima/data/write_hdf5.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
import h5py
import numpy as np
from grelu.sequence.format import convert_input_type
from grelu.io.genome import get_genome
from grelu.sequence.format import _BASE_LUT
from grelu.sequence.utils import get_unique_length
from tqdm import tqdm


def write_hdf5(file, ad, pad=0, genome="hg38"):
def write_hdf5(file, ad, pad=0, genome="hg38", batch_size=1000):
"""Write AnnData object to HDF5 file.

Args:
file: Path to the HDF5 file to write
ad: AnnData object containing the data
pad: Amount of padding to add. Defaults to 0
genome: Genome name or path to the genome fasta file. Defaults to "hg38"
batch_size: Number of genes per write batch. Defaults to 1000
"""
# Calculate seq_len
seq_len = get_unique_length(ad.var)
padded_seq_len = seq_len + 2 * pad
n_genes = ad.var.shape[0]
genome_obj = get_genome(genome)

intervals = ad.var[["chrom", "start", "end", "strand"]].copy()
intervals["start"] = intervals["start"] - pad
intervals["end"] = intervals["end"] + pad

with h5py.File(file, "w") as f:
# Metadata
print("Writing metadata")
f.create_dataset("pad", shape=(), data=pad)
f.create_dataset("seq_len", shape=(), data=seq_len)
padded_seq_len = seq_len + 2 * pad
f.create_dataset("padded_seq_len", shape=(), data=padded_seq_len)

# Tasks
Expand All @@ -35,26 +43,47 @@ def write_hdf5(file, ad, pad=0, genome="hg38"):
f.create_dataset("genes", shape=arr.shape, data=arr)

# Labels
arr = np.expand_dims(ad.X.T.astype(np.float32), 2)
print(f"Writing labels array of shape: {arr.shape}")
print("Writing labels")
X = ad.X.toarray() if hasattr(ad.X, "toarray") else np.asarray(ad.X)
arr = np.expand_dims(X.T.astype(np.float32), 2)
print(f" shape: {arr.shape}")
f.create_dataset("labels", shape=arr.shape, dtype=np.float32, data=arr)
del X, arr

# Masks and sequences — written in batches to avoid OOM
print("Writing masks and sequences")
masks_ds = f.create_dataset(
"masks", shape=(n_genes, padded_seq_len), dtype=np.float32
)
seqs_ds = f.create_dataset(
"sequences", shape=(n_genes, padded_seq_len), dtype=np.int8
)

n_batches = (n_genes + batch_size - 1) // batch_size
for b in tqdm(range(n_batches), desc="Batches"):
start_i = b * batch_size
end_i = min(start_i + batch_size, n_genes)
batch_var = ad.var.iloc[start_i:end_i]
batch_iv = intervals.iloc[start_i:end_i]

masks = np.zeros((end_i - start_i, padded_seq_len), dtype=np.float32)
seqs = np.empty((end_i - start_i, padded_seq_len), dtype=np.int8)

for j, (row_var, row_iv) in enumerate(
zip(batch_var.itertuples(), batch_iv.itertuples())
):
masks[j, row_var.gene_mask_start + pad : row_var.gene_mask_end + pad] = 1.0
seq = str(
genome_obj.get_seq(
row_iv.chrom,
row_iv.start + 1,
row_iv.end,
rc=row_iv.strand == "-",
)
).upper()
seqs[j] = _BASE_LUT[np.frombuffer(seq.encode("ascii"), dtype=np.uint8)]

# Gene masks
print("Making gene masks")
shape = (ad.var.shape[0], padded_seq_len)
arr = np.zeros(shape=shape)
for i, row in enumerate(ad.var.itertuples()):
arr[i, row.gene_mask_start + pad : row.gene_mask_end + pad] += 1
print(f"Writing mask array of shape: {arr.shape}")
f.create_dataset("masks", shape=shape, dtype=np.float32, data=arr)

# Sequences
print("Encoding sequences")
arr = ad.var[["chrom", "start", "end", "strand"]].copy()
arr.start = arr.start - pad
arr.end = arr.end + pad
arr = convert_input_type(arr, "indices", genome=genome)
print(f"Writing sequence array of shape: {arr.shape}")
f.create_dataset("sequences", shape=arr.shape, dtype=np.int8, data=arr)
masks_ds[start_i:end_i] = masks
seqs_ds[start_i:end_i] = seqs

print("Done!")
6 changes: 4 additions & 2 deletions src/decima/model/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,12 +313,14 @@ def train_on_dataset(
logger = self.parse_logger()

# Set up trainer
devices = make_list(self.train_params["devices"])
trainer = pl.Trainer(
max_epochs=self.train_params["max_epochs"],
accelerator="gpu",
devices=make_list(self.train_params["devices"]),
devices=devices,
strategy="ddp" if len(devices) > 1 else "auto",
logger=logger,
callbacks=[ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=self.train_params["save_top_k"])],
callbacks=[ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=self.train_params["save_top_k"], save_last=True)],
default_root_dir=self.train_params["save_dir"],
accumulate_grad_batches=self.train_params["accumulate_grad_batches"],
gradient_clip_val=self.train_params["clip"],
Expand Down
Loading