Skip to content

Commit 47dfee8

Browse files
authored
Add mapping utilities for xenium analysis
1 parent 70c9a17 commit 47dfee8

1 file changed

Lines changed: 240 additions & 0 deletions

File tree

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import warnings
2+
3+
# Suppress noisy cell-type-mapper warnings
4+
warnings.filterwarnings("ignore", message=".*not listed in marker lookup.*")
5+
warnings.filterwarnings("ignore", message=".*had too few markers in query set.*")
6+
7+
from cell_type_mapper.cli.transcribe_to_obs import TranscribeToObsRunner
8+
from cell_type_mapper.cli.validate_h5ad import ValidateH5adRunner
9+
10+
import json
11+
import spatialdata as sd
12+
import pandas as pd
13+
import numpy as np
14+
import anndata as ad
15+
import tempfile
16+
import matplotlib.pyplot as plt
17+
import seaborn as sns
18+
from matplotlib.patches import Rectangle
19+
from pathlib import Path
20+
from tqdm import tqdm
21+
22+
def combine_sections_adatas(sections_paths):
    """Load each section's AnnData table and stack them into one AnnData.

    Args:
        sections_paths: Iterable of section store locations (``str`` or
            ``pathlib.Path``). Each location must contain a ``tables/table``
            zarr group, and the second ``_``-separated field of its file stem
            must be the integer section number.

    Returns:
        One ``AnnData`` holding every section concatenated along axis 0.
        ``obs`` gains ``section`` and ``original_cell_id``; ``cell_id`` is
        rewritten to ``"<original cell id>_<section>"`` and used as the obs
        index. ``var`` gains ``gene_symbol`` (the var index after
        concatenation) and is re-indexed by its ``gene_ids`` column.

    Raises:
        ValueError: If ``sections_paths`` yields no sections.
    """
    section_tables = []

    for section_path in tqdm(sections_paths, desc='Loading section adatas'):
        # Accept plain strings as well as Path objects.
        section_path = Path(section_path)
        adata = ad.read_zarr(section_path / 'tables' / 'table')

        # Make cell ids unique across sections by suffixing the section number.
        section_num = int(section_path.stem.split('_')[1])
        adata.obs['section'] = section_num
        adata.obs['original_cell_id'] = adata.obs['cell_id']
        adata.obs['cell_id'] = [
            f"{cell_id}_{section_num}" for cell_id in adata.obs['original_cell_id']
        ]
        adata.obs.set_index('cell_id', inplace=True, drop=False)

        # Empty uns/obsm before the section is kept for concatenation.
        adata.uns = {}
        adata.obsm = {}
        section_tables.append(adata)

    if not section_tables:
        raise ValueError("sections_paths is empty; there are no section tables to combine")

    # Stack all sections; var columns are kept only where identical across
    # sections (merge="same").
    combined = ad.concat(
        section_tables,
        axis=0,
        join="outer",
        merge="same",
    )
    combined.var['gene_symbol'] = combined.var.index
    combined.var.set_index(combined.var['gene_ids'], inplace=True, drop=False)

    return combined
55+
56+
57+
def get_drop_nodes(v1_types_config, abc_cache, save_plot=False, output_folder=None):
    """Build the list of taxonomy nodes to exclude from mapping.

    Two sources are combined: nodes listed explicitly in the config, and
    nodes at the configured hierarchy level that are not sufficiently
    represented among the V1 MERFISH cells.

    Args:
        v1_types_config: Mapping read with ``.get``. Recognised keys:
            ``drop_nodes_dict`` ({level: [node, ...]}), ``v1_types_path``
            (CSV of V1 MERFISH cells), ``h_level`` (default ``'subclass'``),
            ``min_cells`` (default ``0``) and ``drop_layers`` (values of
            ``parcellation_substructure`` to exclude).
        abc_cache: Cache object forwarded to ``get_v1_merfish_cells`` and
            ``get_nodes_to_drop``.
        save_plot: When true, save a cell-count heatmap into ``output_folder``.
        output_folder: Directory (``str`` or ``Path``) for the heatmap SVG.
            Required when ``save_plot`` is true.

    Returns:
        List of ``(hierarchy_level, node)`` tuples.

    Raises:
        ValueError: If ``save_plot`` is true but ``output_folder`` is None.
    """
    # Fail before any loading instead of with a TypeError on `None / str`.
    if save_plot and output_folder is None:
        raise ValueError("output_folder must be provided when save_plot is True")

    nodes_to_drop = []

    # Nodes named explicitly in the config.
    drop_nodes_dict = v1_types_config.get('drop_nodes_dict', None)
    if drop_nodes_dict:
        for level, level_nodes in drop_nodes_dict.items():
            nodes_to_drop.extend((level, node) for node in level_nodes)
        print(f"Dropping {len(nodes_to_drop)} nodes based on drop_nodes_dict.")

    print("Filtering to only include V1 cell type nodes...")
    v1_types_path = v1_types_config.get('v1_types_path', '/root/capsule/code/v1_merfish_cells.csv')
    # get_v1_merfish_cells calls .exists() on its df_path, so hand it a Path
    # rather than the raw config string.
    v1_types_file = Path(v1_types_path)
    h_level = v1_types_config.get('h_level', 'subclass')
    min_cells = v1_types_config.get('min_cells', 0)
    if v1_types_file.exists():
        print(f"Loading V1 MERFISH cell types from {v1_types_path}...")
        v1_merfish_cells = pd.read_csv(v1_types_file)
    else:
        print(f"V1 MERFISH cell types file not found at {v1_types_path}. Attempting to generate it from ABC cache...")
        v1_merfish_cells = get_v1_merfish_cells(abc_cache, df_path=v1_types_file)
    if save_plot:
        plot_cell_counts_heatmap(
            v1_merfish_cells,
            min_cells=min_cells,
            save_path=Path(output_folder) / 'v1_merfish_cell_counts_heatmap.svg',
        )

    # Exclude cells from the configured layers before counting.
    drop_layers = v1_types_config.get('drop_layers', None)
    if drop_layers:
        v1_merfish_cells = v1_merfish_cells.loc[
            ~v1_merfish_cells['parcellation_substructure'].isin(drop_layers)
        ]

    # Nodes at h_level without enough V1 cells.
    v1_nodes_to_drop = get_nodes_to_drop(v1_merfish_cells, abc_cache, h_level=h_level, min_cells=min_cells)
    print(f"Dropping {len(v1_nodes_to_drop)} {h_level} nodes not present in V1 MERFISH data with at least {min_cells if min_cells>0 else 1} cell(s).")
    nodes_to_drop.extend(v1_nodes_to_drop)

    return nodes_to_drop
91+
92+
def get_abc_paths(abc_cache):
    """Locate the three MapMyCells reference files.

    The files are requested from ``abc_cache.get_data_path``. If any of the
    three lookups raises, all three results are replaced by hard-coded paths
    under ``/root/capsule/data/abc_atlas/mapmycells`` and a message naming
    the error is printed.

    Args:
        abc_cache: Object exposing ``get_data_path(directory=..., file_name=...)``.

    Returns:
        Tuple ``(precomputed_stats_path, mouse_markers_path,
        gene_mapper_db_path)``. Values are whatever ``get_data_path`` returns
        on success, or ``str`` paths when the fallback is used.
    """
    try:
        precomputed_stats_path = abc_cache.get_data_path(
            directory='WMB-10X',
            file_name='precomputed_stats_ABC_revision_230821'
        )
        mouse_markers_path = abc_cache.get_data_path(
            directory='WMB-10X',
            file_name='mouse_markers_230821'
        )
        gene_mapper_db_path = abc_cache.get_data_path(
            directory='mmc-gene-mapper',
            file_name='mmc_gene_mapper.2025-08-04'
        )
    except Exception as e:
        # Deliberate best-effort fallback, but report it so a stale or missing
        # local copy is not used without the user knowing why.
        print(f"Could not resolve MapMyCells files through abc_cache ({e!r}); falling back to hard-coded local paths.")
        precomputed_stats_path = '/root/capsule/data/abc_atlas/mapmycells/WMB-10X/20240831/precomputed_stats_ABC_revision_230821.h5'
        mouse_markers_path = '/root/capsule/data/abc_atlas/mapmycells/WMB-10X/20240831/mouse_markers_230821.json'
        gene_mapper_db_path = '/root/capsule/data/abc_atlas/mapmycells/mmc-gene-mapper/20250630/mmc_gene_mapper.2025-08-04.db'
    return precomputed_stats_path, mouse_markers_path, gene_mapper_db_path
111+
112+
113+
def get_v1_merfish_cells(abc_cache=None, df_path=None):
    """Return the MERFISH cell metadata restricted to VISp.

    If ``df_path`` points to an existing CSV it is loaded (first column as
    index). Otherwise the table is built from ``abc_cache`` by keeping rows
    whose ``parcellation_structure`` equals ``'VISp'``, indexed by
    ``cell_label``, and written to ``df_path`` when one was given.

    Args:
        abc_cache: Object exposing ``get_metadata_dataframe(directory=...,
            file_name=...)``. Only needed when the CSV does not exist.
        df_path: Optional CSV location (``str`` or ``Path``) used both as the
            cache to read and as the destination when regenerating.

    Returns:
        ``pandas.DataFrame`` of V1 cells.

    Raises:
        ValueError: If the CSV is unavailable and ``abc_cache`` is None.
    """
    # Normalise so that string paths work with .exists(); falsy means "no path".
    csv_path = Path(df_path) if df_path else None

    if csv_path is not None and csv_path.exists():
        return pd.read_csv(csv_path, index_col=0)

    print('V1 cell df not found, generating new one...')
    if abc_cache is None:
        raise ValueError("abc_cache must be provided if path to df does not exist")
    merfish_ccf_metadata = abc_cache.get_metadata_dataframe(
        directory='MERFISH-C57BL6J-638850-CCF',
        file_name='cell_metadata_with_parcellation_annotation'
    ).set_index('cell_label')
    v1_merfish_cells = merfish_ccf_metadata.loc[merfish_ccf_metadata['parcellation_structure'] == 'VISp']

    # Only persist when a destination exists; to_csv(None) would just build a
    # string and discard it while the log claimed a save happened.
    if csv_path is not None:
        print(f"Saving df to: {df_path}")
        v1_merfish_cells.to_csv(csv_path)
    return v1_merfish_cells
130+
131+
def get_nodes_to_drop(cells_df, abc_cache, h_level='subclass', min_cells=0):
    """List taxonomy nodes at ``h_level`` that lack enough cells in ``cells_df``.

    Args:
        cells_df: DataFrame with one row per cell and a column named
            ``h_level`` holding the term name each cell belongs to.
        abc_cache: Object exposing ``get_metadata_dataframe(directory=...,
            file_name=...)``; the returned table must have the columns
            ``cluster_annotation_term_set_name``,
            ``cluster_annotation_term_name`` and
            ``cluster_annotation_term_label``.
        h_level: Taxonomy level to evaluate.
        min_cells: Minimum number of cells a term needs to be kept. Values
            below 1 are treated as 1, so a term with no cells is never kept.

    Returns:
        List of ``(h_level, term_label)`` tuples, sorted by label, for every
        term at ``h_level`` that is not kept.
    """
    taxonomy_df = abc_cache.get_metadata_dataframe(
        directory='WMB-taxonomy',
        file_name='cluster_to_cluster_annotation_membership'
    )
    # Only consider terms of the requested level, so a term with the same
    # name at another level can never supply the label.
    level_terms = taxonomy_df.loc[taxonomy_df['cluster_annotation_term_set_name'] == h_level]

    # Names of the terms that meet the cell-count threshold.
    cell_counts = cells_df.groupby(h_level).size()
    kept_names = cell_counts[cell_counts >= max(min_cells, 1)].index

    # Translate kept names to labels in one pass; names missing from the
    # taxonomy simply match nothing.
    kept_labels = level_terms.loc[
        level_terms['cluster_annotation_term_name'].isin(kept_names),
        'cluster_annotation_term_label',
    ].unique()

    # Everything at this level that was not kept is dropped.
    labels_to_drop = np.setdiff1d(level_terms['cluster_annotation_term_label'].unique(), kept_labels)
    return [(h_level, label) for label in labels_to_drop]
150+
151+
def validate_input_adata(h5ad_path, output_dir, mouse_markers_path, db_path, round_to_int=False, layer='X', output_json=None):
    """Run ``ValidateH5adRunner`` on an h5ad file and choose which file to map.

    Args:
        h5ad_path: Input h5ad file.
        output_dir: Directory handed to the validator as ``output_dir``.
        mouse_markers_path: JSON marker lookup; inspected to decide which
            path to return.
        db_path: Gene-mapper database path, passed as ``gene_mapping.db_path``.
        round_to_int: Forwarded to the validator.
        layer: Forwarded to the validator.
        output_json: Where the validator writes its result JSON. A temporary
            ``.json`` file is created when omitted.

    Returns:
        The ``valid_h5ad_path`` reported by the validator if any marker in
        the lookup contains ``'ENSMUSG'``; otherwise ``h5ad_path`` unchanged.

    Raises:
        Exception: Any error from validation or from reading the JSON files
            is printed and re-raised.
    """
    if output_json is None:
        # Reserve a temp file name and close the handle straight away;
        # mkstemp()[1] would leave its OS-level descriptor open.
        with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp_file:
            output_json = tmp_file.name
    validate_config = {
        'h5ad_path': str(h5ad_path),
        'round_to_int': round_to_int,
        'layer': layer,
        'output_dir': str(output_dir),
        'output_json': output_json,
        'gene_mapping': {'db_path': db_path}
    }
    try:
        print("Starting validation of input data...")
        ValidateH5adRunner(args=[], input_data=validate_config).run()

        with open(validate_config['output_json'], 'rb') as results_file:
            validated_path = json.load(results_file)['valid_h5ad_path']
        with open(mouse_markers_path, 'rb') as markers_file:
            markers_json = json.load(markers_file)

        # NOTE(review): the last two keys are skipped, presumably because they
        # are not marker lists - confirm against the marker file format.
        marker_keys = list(markers_json)[:-2]
        use_validated_h5ad = any(
            'ENSMUSG' in marker
            for key in marker_keys
            for marker in markers_json[key]
        )
        if use_validated_h5ad:
            print(f"Using validated path: {validated_path}")
            return validated_path
        print(f'Using original path: {h5ad_path}')
        return h5ad_path
    except Exception as e:
        print(f"Validation failed: {e}")
        # Bare raise keeps the original traceback intact.
        raise
181+
182+
def format_mapping_outputs(extended_results_path, mapped_adata_path, mapping_params, h5ad_path=None):
    """Write mapping results into a new h5ad via ``TranscribeToObsRunner``.

    Args:
        extended_results_path: JSON file with the extended mapping results.
        mapped_adata_path: Destination of the new h5ad file.
        mapping_params: Mapping containing a ``'clobber'`` entry.
        h5ad_path: Query h5ad file. When None, it is read from
            ``config.query_path`` inside the extended results JSON.

    Returns:
        True when the runner finishes; False if anything raised (the error
        is printed, not propagated).
    """
    try:
        query_path = h5ad_path
        if query_path is None:
            # Fall back to the query path recorded in the results file.
            with open(extended_results_path, 'r') as results_file:
                query_path = json.load(results_file)['config']['query_path']

        transcribe_config = {
            'h5ad_path': str(query_path),
            'result_path': str(extended_results_path),
            'new_h5ad_path': str(mapped_adata_path),
            'clobber': bool(mapping_params['clobber']),
        }
        runner = TranscribeToObsRunner(args=[], input_data=transcribe_config)
        runner.run()
    except Exception as e:
        print(f"Formatting mapping outputs failed: {e}")
        return False
    return True
199+
200+
201+
def plot_cell_counts_heatmap(v1_merfish_cells, min_cells=5, save_path=None):
    """Plot cell counts per subclass and parcellation substructure as a heatmap.

    Args:
        v1_merfish_cells: DataFrame with ``subclass`` and
            ``parcellation_substructure`` columns, or a ``str``/``Path`` to a
            CSV that is read with ``pandas.read_csv``.
        min_cells: Subclass columns whose total count is below this value are
            covered with a translucent gray overlay.
        save_path: If given, the figure is also written there as SVG.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if isinstance(v1_merfish_cells, (str, Path)):
        v1_merfish_cells = pd.read_csv(v1_merfish_cells)

    # Count cells per (subclass, substructure), then spread into a
    # substructure x subclass table with missing combinations as 0.
    counts_long = (
        v1_merfish_cells
        .groupby(['subclass', 'parcellation_substructure'])
        .size()
        .reset_index(name='cell_count')
    )
    counts_wide = counts_long.pivot(
        index='parcellation_substructure',
        columns='subclass',
        values='cell_count',
    ).fillna(0)

    # Append a bottom row holding each subclass's total over all substructures.
    subclass_totals = counts_wide.sum(axis=0)
    subclass_totals.name = 'Total (All Layers)'
    heatmap_table = pd.concat([counts_wide, subclass_totals.to_frame().T])

    plt.figure(figsize=(20, 6))
    ax = sns.heatmap(
        heatmap_table,
        annot=True,
        fmt='.0f',
        cmap='viridis',
        cbar_kws={'label': 'Cell Count'},
    )
    plt.title('Cell Count Heatmap: Subclass vs Parcellation Substructure')
    plt.xlabel('Subclass')
    plt.ylabel('Parcellation Substructure')

    # White rule separating the totals row from the per-substructure rows.
    ax.axhline(y=len(counts_wide), color='white', linewidth=3)

    # Gray out every column whose total falls short of min_cells.
    n_rows = len(heatmap_table)
    for col_idx, total_count in enumerate(subclass_totals):
        if total_count < min_cells:
            overlay = Rectangle(
                (col_idx, 0), 1, n_rows,
                linewidth=0, edgecolor='None', facecolor='gray',
                linestyle='-', alpha=0.65,
            )
            ax.add_patch(overlay)

    plt.tight_layout()

    if save_path is not None:
        plt.savefig(str(save_path), format='svg', bbox_inches='tight')

    plt.show()

0 commit comments

Comments
 (0)