Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions cytetype/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
aggregate_cluster_metadata,
extract_visualization_coordinates,
)
from .preprocessing.validation import materialize_canonical_gene_symbols_column
from .preprocessing.validation import (
materialize_canonical_gene_symbols_column,
_generate_unique_na_label,
)
from .core.payload import build_annotation_payload, save_query_to_file
from .core.artifacts import (
_is_integer_valued,
Expand Down Expand Up @@ -87,6 +90,7 @@ def __init__(
max_metadata_categories: int = 500,
api_url: str = "https://prod.cytetype.nygen.io",
auth_token: str | None = None,
label_na: bool = False,
) -> None:
"""Initialize CyteType with AnnData object and perform data preparation.

Expand Down Expand Up @@ -125,6 +129,11 @@ def __init__(
deployment. Defaults to "https://prod.cytetype.nygen.io".
auth_token (str | None, optional): Bearer token for API authentication. If provided,
will be included in the Authorization header as "Bearer {auth_token}". Defaults to None.
label_na (bool, optional): If True, cells with NaN values in the
``group_key`` column are assigned an ``'Unknown'`` cluster label
(or ``'Unknown 2'``, etc. if that label already exists). The original
AnnData object is not modified. If False (default), a ``ValueError``
is raised instead.

Raises:
KeyError: If the required keys are missing in `adata.obs` or `adata.uns`
Expand Down Expand Up @@ -152,8 +161,40 @@ def __init__(
self._original_gene_symbols_column = self.gene_symbols_column

self.coordinates_key = validate_adata(
adata, group_key, rank_key, self.gene_symbols_column, coordinates_key
adata, group_key, rank_key, self.gene_symbols_column, coordinates_key,
label_na=label_na,
)

if label_na:
nan_mask = adata.obs[group_key].isna()
if nan_mask.any():
n_nan = int(nan_mask.sum())
pct = round(100 * n_nan / adata.n_obs, 1)
existing_labels = set(
str(v) for v in adata.obs[group_key].dropna().unique()
)
na_label = _generate_unique_na_label(existing_labels)
logger.warning(
f"⚠️ Relabeling {n_nan} cells ({pct}%) with NaN values "
f"in '{group_key}' as '{na_label}'."
)
adata = anndata.AnnData(
X=adata.X,
obs=adata.obs.copy(),
var=adata.var,
uns=adata.uns,
obsm=adata.obsm,
varm=adata.varm,
layers=adata.layers,
obsp=adata.obsp,
varp=adata.varp,
)
col = adata.obs[group_key]
if hasattr(col, "cat"):
col = col.cat.add_categories(na_label)
adata.obs[group_key] = col.fillna(na_label)
self.adata = adata

(
self.gene_symbols_column,
self._original_gene_symbols_column,
Expand Down
28 changes: 28 additions & 0 deletions cytetype/preprocessing/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,43 @@ def _ur_sort_key(ur: float) -> float:
return None


def _generate_unique_na_label(existing_labels: set[str]) -> str:
label = "Unknown"
if label not in existing_labels:
return label
n = 2
while f"{label} {n}" in existing_labels:
n += 1
return f"{label} {n}"


def validate_adata(
adata: anndata.AnnData,
cell_group_key: str,
rank_genes_key: str,
gene_symbols_col: str | None,
coordinates_key: str,
label_na: bool = False,
) -> str | None:
if cell_group_key not in adata.obs:
raise KeyError(f"Cell group key '{cell_group_key}' not found in `adata.obs`.")

nan_mask = adata.obs[cell_group_key].isna()
n_nan = int(nan_mask.sum())
if n_nan > 0:
pct = round(100 * n_nan / adata.n_obs, 1)
if n_nan == adata.n_obs:
raise ValueError(
f"All {n_nan} cells have NaN values in '{cell_group_key}'. "
f"Cannot proceed with annotation."
)
if not label_na:
raise ValueError(
f"{n_nan} cells ({pct}%) have NaN values in '{cell_group_key}'. "
f"Either fix the data or set label_na=True to assign these cells "
f"an 'Unknown' cluster label."
)

Comment on lines +290 to +305

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Action required

1. Unexpected cell dropping 🐞 Bug ✓ Correctness

validate_adata now subsets adata in place to drop rows with NaN in cell_group_key, permanently
removing observations from the user-provided AnnData. This contradicts CyteType.annotate()'s
documented behavior of only adding results to the input AnnData and can silently return an object
missing cells.
Agent Prompt
### Issue description
`validate_adata()` currently drops observations with NaN `cell_group_key` values by subsetting the passed-in `AnnData` in place. This is a breaking, destructive side effect for callers (notably `CyteType`), and it contradicts the documented contract that `annotate()` only adds fields to the input AnnData.

### Issue Context
The current behavior permanently removes cells from the user’s AnnData during `CyteType.__init__`, so downstream code and the returned AnnData from `annotate()` no longer contain the same set of observations.

### Fix Focus Areas
- cytetype/preprocessing/validation.py[279-292]
- cytetype/main.py[133-176]
- cytetype/main.py[467-471]

### Suggested fix approach
Implement one of these explicit, non-surprising behaviors:
1) **Strict validation (recommended):** Do not modify `adata`; instead raise a `ValueError` when NaNs are present, instructing the user to pre-clean `adata.obs[cell_group_key]`.
2) **Opt-in dropping:** Add an explicit flag (e.g., `drop_nan_cells: bool = False`) on the public entrypoint (`CyteType.__init__` or a config object). Only when `True`, drop cells, and **update the docstring** to state that observations may be removed.
3) **Non-destructive internal copy:** If dropping is desired by default, operate on a copied AnnData for internal computation (`self.adata = adata[mask].copy()`), and clearly document this behavior (including how dropped cell indices are reported back, e.g., `adata.uns['cytetype_dropped_obs']`).

ⓘ Copy this prompt and use it to remediate the issue with your preferred AI generation tools

if adata.X is None:
raise ValueError(
"`adata.X` is required for ranking genes. Please ensure it contains log1p normalized data."
Expand Down
Loading