Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions docs/docs/usage/wandb_sweeps.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# How to use W&B Sweeps with ModelGenerator for hyperparameter tuning

## Caveats
W&B agents cannot launch multi-node training jobs, which causes great difficulties integrating W&B Sweeps with ModelGenerator. This guide is based on a hacky workaround that introduces many limitations.

### The workaround
An agent is configured to exit immediately after retrieving the next set of hyperparamenters and outputing the complete training command to stdout. This command is then executed on each node without being monitored by an active agent.

### Limitations
1. All agent functionalities are lost. It is not possible to use agent to start/stop/resume/update training runs. Users must manually terminate training runs or implement early-stopping mechanisms.
2. Failed runs have to be re-run manually using your own sbatch scripts. The command for that run is availale in stdout of the failed run.
3. Parameter importance plots use wrong parameters by default, it can be manually fixed by selecting the right parameter names in your mgen config.

>**NOTE**: Before proceeding, please make sure that your training job uses **WandbLogger**.
## SLURM
### Step 1: create a wandb sweep
The default `slurm_sweep.yaml` creates a wandb sweep with the training command `mgen fit --config .local/test.yaml` under the project `autotune-test`. Please modify it to suit your experiments. Key values to change are **project**, **command** and **parameters**.

Run the following command to create a wandb sweep:
```bash
wandb sweep scripts/wandb_sweep/slurm_sweep.yaml
```
Take a note of your sweep ID for step 2. It looks like `<entity>/<project>/<id>` and is found in the output: `wandb: Run sweep agent with: wandb agent`
### Step 2: submit the next training job to SLURM
Similar to step 1, you need to edit `slurm_agent.sh` for your experiment. The most important changes are **WANDB_PROJECT** and **SWEEP_ID**.

The following command creates one sweep agent that runs training with the next set of hyperparamenters.
```bash
sbatch scripts/wandb_sweep/slurm_agent.sh
```

>**TIPS**: To queue your other sweep runs, use `sbatch --dependency`. To launch your other sweep runs in parallel, use `sbatch --array=1-X` where `X` is the number of parallel runs.
1 change: 1 addition & 0 deletions docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ nav:
- usage/exporting_models.md
- usage/reproducing_experiments.md
- usage/embedding_caching.md
- usage/wandb_sweeps.md
- Tutorials:
- tutorials/kfold_cross_validation.md
- tutorials/finetuning_scheduler.md
Expand Down
2 changes: 2 additions & 0 deletions experiments/AIDO.Cell/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
model = model.to(device).to(torch.float16)

adata = ad.read_h5ad('../../modelgenerator/cell-downstream-tasks/zheng/zheng_train.h5ad')
if not adata.obs_names.is_unique:
adata.obs_names_make_unique()

batch_np = adata[:batch_size].X.toarray()
batch_tensor = torch.from_numpy(batch_np).to(torch.float16).to(device)
Expand Down
1 change: 1 addition & 0 deletions experiments/AIDO.Cell/tutorial_target_id.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
"sc.pp.log1p(adata)\n",
"\n",
"# Clustering + UMAP\n",
"sc.pp.pca(adata)\n",
"sc.pp.neighbors(adata)\n",
"sc.tl.leiden(adata, flavor='igraph', n_iterations=2, resolution=0.5)\n",
"sc.tl.umap(adata)\n",
Expand Down
4 changes: 4 additions & 0 deletions huggingface/aido.cell/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
print("Loading input data...")
try:
adata = ad.read_h5ad(INPUT_FILE)
# Check for duplicate obs_names
if not adata.obs_names.is_unique:
print("⚠ Duplicate AnnData obs_names detected. Automatically applying obs_names_make_unique().")
adata.obs_names_make_unique()
print(f"✓ Loaded data with {adata.n_obs} cells and {adata.n_vars} genes\n")
except Exception as e:
print(f"Error loading data: {e}")
Expand Down
4 changes: 4 additions & 0 deletions huggingface/aido.cell/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@
print("Loading input data...")
try:
adata = ad.read_h5ad(INPUT_FILE)
# Check for duplicate obs_names
if not adata.obs_names.is_unique:
print("⚠ Duplicate AnnData obs_names detected. Automatically applying obs_names_make_unique().")
adata.obs_names_make_unique()
print(f"✓ Loaded data with {adata.n_obs} cells and {adata.n_vars} genes\n")
except Exception as e:
print(f"Error loading data: {e}")
Expand Down
105 changes: 0 additions & 105 deletions modelgenerator/backbones/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,27 +146,6 @@ def setup(self):
for name, param in self.encoder.named_parameters():
param.requires_grad = False

def process_batch(self, batch, device, add_special_tokens=True, **kwargs):
"""Processes a batch of sequences to model input format.

Args:
batch (List[str]): List of input sequences.
device (torch.device): Device to move the data to.
add_special_tokens (bool, optional): Whether to add special tokens. Defaults to True.

Returns:
Dict: A dictionary containing required args for forward pass.
"""
seq_tokenized = self.tokenize(
batch["sequences"], padding=True, add_special_tokens=add_special_tokens, **kwargs
)
for k, v in seq_tokenized.items():
if v is not None:
if torch.is_tensor(v):
seq_tokenized[k] = v.to(dtype=torch.long, device=device)
else:
seq_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)
return seq_tokenized

def forward(
self,
Expand Down Expand Up @@ -1555,28 +1534,6 @@ def get_decoder(self) -> nn.Module:
"""
return _Identity()

def process_batch(
self, batch: dict, device: torch.device, add_special_tokens: bool = True, **kwargs
):
"""Processes a batch of sequences to model input format.

Args:
batch (dict): A dictionary containing input sequences.
device (torch.device): Device to move the data to.

Returns:
Dict: A dictionary containing required args for forward pass.
"""
seq_tokenized = self.tokenize(
batch["sequences"], padding=True, add_special_tokens=add_special_tokens, **kwargs
)
for k, v in seq_tokenized.items():
if v is not None:
if torch.is_tensor(v):
seq_tokenized[k] = v.to(dtype=torch.long, device=device)
else:
seq_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)
return seq_tokenized

def tokenize(
self,
Expand Down Expand Up @@ -1821,28 +1778,6 @@ def get_decoder(self) -> nn.Module:
"""
return _Identity()

def process_batch(
self, batch: dict, device: torch.device, add_special_tokens: bool = True, **kwargs
):
"""Processes a batch of sequences to model input format.

Args:
batch (dict): A dictionary containing input sequences.
device (torch.device): Device to move the data to.

Returns:
Dict: A dictionary containing required args for forward pass.
"""
seq_tokenized = self.tokenize(
batch["sequences"], padding=True, add_special_tokens=add_special_tokens, **kwargs
)
for k, v in seq_tokenized.items():
if v is not None:
if torch.is_tensor(v):
seq_tokenized[k] = v.to(dtype=torch.long, device=device)
else:
seq_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)
return seq_tokenized

def tokenize(self, sequences: list[str], **kwargs) -> dict:
"""Tokenizes a list of sequences
Expand Down Expand Up @@ -2033,17 +1968,6 @@ def get_decoder(self) -> nn.Module:
"""
return self.decoder

def process_batch(self, batch: dict, device: torch.device, **kwargs):
"""Processes a batch of sequences to model input format.

Args:
batch (dict): A dictionary containing input sequences.

Returns:
Dict: A dictionary containing required args for forward pass.
"""
input_ids = self.tokenize(batch["sequences"])["input_ids"].to(device=device)
return {"input_ids": input_ids}

def tokenize(
self,
Expand Down Expand Up @@ -2230,17 +2154,6 @@ def get_decoder(self) -> nn.Module:
"""
return self.decoder

def process_batch(self, batch: dict, device: torch.device, **kwargs):
"""Processes a batch of sequences to model input format.

Args:
batch (dict): A dictionary containing input sequences.

Returns:
Dict: A dictionary containing required args for forward pass.
"""
input_ids = self.tokenize(batch["sequences"])["input_ids"].to(device=device)
return {"input_ids": input_ids}

def tokenize(
self,
Expand Down Expand Up @@ -2494,24 +2407,6 @@ def get_decoder(self) -> nn.Module:
"""
return self.decoder

def process_batch(self, batch: dict, device: torch.device, **kwargs):
"""Processes a batch of sequences to model input format.

Args:
batch (dict): A dictionary containing input sequences.
device (torch.device): Device to move the data to.

Returns:
Dict: A dictionary containing required args for forward pass.
"""
seq_tokenized = self.tokenize(batch["sequences"])
for k, v in seq_tokenized.items():
if v is not None:
if torch.is_tensor(v):
seq_tokenized[k] = v.to(dtype=torch.long, device=device)
else:
seq_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)
return seq_tokenized

def tokenize(self, sequences: list[str], **kwargs) -> dict:
"""Tokenizes a list of sequences
Expand Down
11 changes: 10 additions & 1 deletion modelgenerator/backbones/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,16 @@ def process_batch(
Returns:
Dict: A dictionary containing required args for forward pass.
"""
raise NotImplementedError
seq_tokenized = self.tokenize(
batch["sequences"], padding=True, add_special_tokens=add_special_tokens, **kwargs
)
for k, v in seq_tokenized.items():
if v is not None:
if torch.is_tensor(v):
seq_tokenized[k] = v.to(dtype=torch.long, device=device)
else:
seq_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)
return seq_tokenized

def required_data_columns(self) -> List[str]:
"""List of required data columns for the model.
Expand Down
21 changes: 21 additions & 0 deletions modelgenerator/cell/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,29 @@
import bionty as bt
import numpy as np
import pandas as pd
import logging
from lightning.pytorch.utilities import rank_zero_info

logger = logging.getLogger(__name__)


def _ensure_unique_obs_names(adata: ad.AnnData) -> ad.AnnData:
"""Ensures that the observation names in the AnnData object are unique.

Args:
adata (ad.AnnData): The input AnnData object.

Returns:
ad.AnnData: The AnnData object with unique observation names.
"""
if not adata.obs_names.is_unique:
logger.warning(
"Duplicate AnnData obs_names detected. "
"Automatically applying obs_names_make_unique()."
)
adata.obs_names_make_unique()
return adata


def build_map(gene_symbols):
# Get map of symbols to Ensembl IDs:
Expand Down
12 changes: 11 additions & 1 deletion modelgenerator/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,17 +1556,21 @@ def provided_columns(self) -> List[str]:
if self.filter_columns is not None:
return ["sequences"] + self.filter_columns
adata = ad.read_h5ad(os.path.join(self.path, self.trainfile), backed="r")
adata = cell_utils._ensure_unique_obs_names(adata)
return ["sequences"] + list(adata.obs.columns)

def setup(self, stage: Optional[str] = None):
"""Set up the data module by loading the whole datasets and splitting them into training, validation, and test sets."""
adata_train = ad.read_h5ad(os.path.join(self.path, self.trainfile))
adata_train = cell_utils._ensure_unique_obs_names(adata_train)
adata_train = cell_utils.map_gene_symbols(adata_train, symbol_field="gene_symbols")
adata_train = cell_utils.align_genes(adata_train, self.backbone_gene_list)
adata_val = ad.read_h5ad(os.path.join(self.path, self.valfile))
adata_val = cell_utils._ensure_unique_obs_names(adata_val)
adata_val = cell_utils.map_gene_symbols(adata_val, symbol_field="gene_symbols")
adata_val = cell_utils.align_genes(adata_val, self.backbone_gene_list)
adata_test = ad.read_h5ad(os.path.join(self.path, self.testfile))
adata_test = cell_utils._ensure_unique_obs_names(adata_test)
adata_test = cell_utils.map_gene_symbols(adata_test, symbol_field="gene_symbols")
adata_test = cell_utils.align_genes(adata_test, self.backbone_gene_list)

Expand Down Expand Up @@ -1840,6 +1844,7 @@ def provided_columns(self) -> List[str]:
def setup(self, stage: Optional[str] = None):
"""Set up the data module by loading the whole datasets and splitting them into training, validation, and test sets."""
adata = ad.read_h5ad(os.path.join(self.path, self.trainfile))
adata = cell_utils._ensure_unique_obs_names(adata)
adata = cell_utils.align_genes(adata, self.backbone_gene_list, ensembl_field="feature_id")

adata_train = adata[adata.obs[self.split_column] == "train"]
Expand Down Expand Up @@ -1913,6 +1918,7 @@ def __init__(

# Load metadata in backed mode
self.adata = ad.read_h5ad(self.file_path, backed="r")
self.adata = cell_utils._ensure_unique_obs_names(self.adata)
self._process_metadata()
self.neighbor_indices = self._precompute_neighbors()
self.length = len(self.adata.obs)
Expand Down Expand Up @@ -2176,13 +2182,17 @@ def provided_columns(self) -> List[str]:
if self.filter_columns is not None:
return ["sequences"] + self.filter_columns
adata = ad.read_h5ad(os.path.join(self.path, self.file), backed="r")
adata = cell_utils._ensure_unique_obs_names(adata)
return ["sequences"] + list(adata.obs.columns)

def setup(self, stage: Optional[str] = None):
"""Set up the data module by loading the whole datasets and splitting them into training, validation, and test sets."""
rank_zero_info("***")
rank_zero_info(f"loading {self.file}")
adata = ad.read_h5ad(os.path.join(self.path, self.file))
# Check for duplicate obs_names
if not adata.obs_names.is_unique:
rank_zero_info("⚠ Duplicate AnnData obs_names detected. Automatically applying obs_names_make_unique().")
adata.obs_names_make_unique()
adata = cell_utils.map_gene_symbols(adata, symbol_field="index")
adata = cell_utils.align_genes(adata, self.backbone_gene_list)
rank_zero_info(f"loaded {adata.shape[0]} cells")
Expand Down
9 changes: 8 additions & 1 deletion modelgenerator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,25 @@

class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.init_args.batch_size", "model.init_args.batch_size")
parser.link_arguments(
"data.init_args.batch_size",
"model.init_args.batch_size",
apply_on="instantiate",
)
parser.link_arguments(
"model.init_args.backbone.class_path",
"trainer.strategy.init_args.auto_wrap_policy.init_args.backbone_classes",
apply_on="instantiate",
)
parser.link_arguments(
"model.init_args.backbone.class_path",
"data.init_args.backbone_class_path",
apply_on="instantiate",
)
parser.link_arguments(
"model.init_args.backbone.init_args.use_peft",
"trainer.strategy.init_args.auto_wrap_policy.init_args.use_peft",
apply_on="instantiate",
)
parser.link_arguments(
"data",
Expand Down
1 change: 1 addition & 0 deletions modelgenerator/prot_inv_fold/pif_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
self.accuracies_str_enc = []
self.acc_metric = MyAccuracy()

@once_only
def configure_model(self) -> None:
self.lm = self.backbone_fn(None, None)
self.tokenizer = self.lm.tokenizer
Expand Down
1 change: 1 addition & 0 deletions modelgenerator/rna_inv_fold/rif_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(

print(self)

@once_only
def configure_model(self) -> None:
self.lm = self.backbone_fn(None, None)
self.lm.setup()
Expand Down
1 change: 1 addition & 0 deletions modelgenerator/rna_ss/rna_ss_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
self.THRESHOLD_TUNE_METRIC = "f1"
self.THRESHOLD_CANDIDATES = [i / 100 for i in range(1, 30, 1)]

@once_only
def configure_model(self) -> None:
self.backbone.setup()
if self.use_legacy_adapter:
Expand Down
Loading