Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/hnoca/mapping/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(
ref_model: The reference model to map the query dataset to.
"""
# Check optional dependencies
check_deps("scvi-tools")
check_deps("scarches")
# check_deps("scvi-tools") # Comment as it is a function that fails to detect scvi-tools was installed
# check_deps("scarches") # Comment as it is a function that fails to detect scArches was installed
# Import and store as attributes so other methods can use them
import scarches
import scvi # local import assured by previous check
Expand Down
15 changes: 9 additions & 6 deletions src/hnoca/mapping/wknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ def build_nn( # noqa: D103
ref,
query=None,
k=100,
use_rapids: bool = False,
weight: Literal["unweighted", "dist", "gaussian_kernel"] = "unweighted",
sigma=None,
use_rapids: bool = True, # Ensure that RAPIDS is used
):
if query is None:
query = ref

if use_rapids:
check_deps("cuml")
# check_deps("cuml") # Comment check_deps() because the function is broken
from cuml.neighbors import NearestNeighbors

logger.info("Using cuML for neighborhood estimation on GPU.")
Expand Down Expand Up @@ -223,9 +225,9 @@ def estimate_presence_score(
ref = ref_adata.obsm[use_rep_ref_trans_prop]
ref_trans_prop = get_transition_prob_mat(ref, k=k_ref_trans_prop)

if split_by and split_by in query_adata.obs.columns:
if split_by in query_adata.obs.columns:
presence_split = [
np.array(wknn[query_adata.obs[split_by] == x, :].sum(axis=0)).flatten()
np.array(wknn[query_adata.obs[split_by].to_numpy() == x, :].sum(axis=0)).flatten() # added to_numpy() for better compatibility
for x in query_adata.obs[split_by].unique()
]
else:
Expand Down Expand Up @@ -270,13 +272,14 @@ def estimate_presence_score(
}


def transfer_labels(ref_adata: sc.AnnData, query_adata: sc.AnnData, wknn, label_key: str = "celltype"):
def transfer_labels(ref_adata: sc.AnnData, query_adata: sc.AnnData, wknn, label_key: str ="celltype"):
"""Transfer labels from reference to query data."""
scores = pd.DataFrame(
wknn @ pd.get_dummies(ref_adata.obs[label_key]),
columns=pd.get_dummies(ref_adata.obs[label_key]).columns,
index=query_adata.obs_names,
)
scores["best_score"] = scores.max(1) # change order, first find the score then the label so no string is inputted among floats
scores["best_label"] = scores.idxmax(1)
scores["best_score"] = scores.max(1)

return scores