Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"pybiomart>=0.2.0",
"requests>=2.32.4",
"scipy>=1.15.3",
"tomotopy>=0.14.0",
"tmtoolkit>=0.12.0",
]

Expand Down
662 changes: 303 additions & 359 deletions src/pycisTopic/cli/subcommand/topic_modeling.py

Large diffs are not rendered by default.

74 changes: 26 additions & 48 deletions src/pycisTopic/topic_modeling/create_anndata.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,52 @@
from __future__ import annotations

import anndata
import pandas as pd

from pycisTopic.topic_modeling.mallet_models import LDAMallet, LDAMalletFilenames
from pycisTopic.topic_modeling.topic_models import TopicModelFilenames, load_topic_model_backend


def create_anndata_from_mallet(
def create_anndata_from_topic_model(
output_prefix: str,
n_topics: int,
cell_barcodes: list[str],
region_ids: list[str]
):
"""
Create cell-topic and region-topic AnnData objects from Mallet topic modeling results.

Parameters
----------
output_prefix
Output prefix used for running topic modeling with Mallet.
n_topics
Number of topics used in the topic model created by Mallet.
In combination with output_prefix, this makes it possible to load the correct
region topic counts and cell topic probabilities parquet files.
cell_barcodes
List containing cell names as ordered in the binary matrix columns.
region_ids
List containing region names as ordered in the binary matrix rows.

Returns
-------
None

"""
# Get distributions
print("Reading Mallet results ...")
lda_mallet_filenames = LDAMalletFilenames(
output_prefix=output_prefix, n_topics=n_topics
region_ids: list[str],
) -> None:
"""Create AnnData objects from backend-agnostic v3 topic modeling artifacts."""
filenames = TopicModelFilenames(output_prefix=output_prefix, n_topics=n_topics)
backend_cls = load_topic_model_backend(output_prefix=output_prefix, n_topics=n_topics)

print(f"Reading {backend_cls.backend} results ...")
topic_word_distrib = (
backend_cls.read_region_topic_counts_parquet_file_to_region_topic_probabilities(
region_topic_counts_parquet_filename=filenames.region_topic_counts_parquet_filename
)
)
topic_word_distrib = LDAMallet.read_region_topic_counts_parquet_file_to_region_topic_probabilities(
mallet_region_topic_counts_parquet_filename=lda_mallet_filenames.region_topic_counts_parquet_filename
)
doc_topic_distrib = LDAMallet.read_cell_topic_probabilities_parquet_file(
mallet_cell_topic_probabilities_parquet_filename=lda_mallet_filenames.cell_topic_probabilities_parquet_filename
doc_topic_distrib = backend_cls.read_cell_topic_probabilities_parquet_file(
cell_topic_probabilities_parquet_filename=filenames.cell_topic_probabilities_parquet_filename
)

cell_topic = pd.DataFrame.from_records(
doc_topic_distrib,
index=cell_barcodes,
columns=["Topic" + str(i) for i in range(1, n_topics + 1)],
columns=[f"Topic{i}" for i in range(1, n_topics + 1)],
)

region_topic = pd.DataFrame.from_records(
topic_word_distrib,
columns=region_ids,
index=["Topic" + str(i) for i in range(1, n_topics + 1)],
index=[f"Topic{i}" for i in range(1, n_topics + 1)],
).transpose()

print("Generating cell_topic AnnData object")
adata_cell_topic = anndata.AnnData(
X=cell_topic
)
print("Generating cell topic AnnData object")
adata_cell_topic = anndata.AnnData(X=cell_topic)
print(f"Done, shape is: {adata_cell_topic.shape}")

print("Generating region topic AnnData object")
adata_region_topic = anndata.AnnData(
X=region_topic
)
adata_region_topic = anndata.AnnData(X=region_topic)
print(f"Done, shape is: {adata_region_topic.shape}")

print(f"Writing to: {lda_mallet_filenames.anndata_cell_topic_filename}")
adata_cell_topic.write(lda_mallet_filenames.anndata_cell_topic_filename)
print(f"Writing to: {filenames.anndata_cell_topic_filename}")
adata_cell_topic.write(filenames.anndata_cell_topic_filename)

print(f"Writing to: {lda_mallet_filenames.anndata_region_topic_filename}")
adata_region_topic.write(lda_mallet_filenames.anndata_region_topic_filename)
print(f"Writing to: {filenames.anndata_region_topic_filename}")
adata_region_topic.write(filenames.anndata_region_topic_filename)
Loading