|
| 1 | +import warnings |
| 2 | + |
| 3 | +# Suppress noisy cell-type-mapper warnings |
| 4 | +warnings.filterwarnings("ignore", message=".*not listed in marker lookup.*") |
| 5 | +warnings.filterwarnings("ignore", message=".*had too few markers in query set.*") |
| 6 | + |
| 7 | +from cell_type_mapper.cli.transcribe_to_obs import TranscribeToObsRunner |
| 8 | +from cell_type_mapper.cli.validate_h5ad import ValidateH5adRunner |
| 9 | + |
| 10 | +import json |
| 11 | +import spatialdata as sd |
| 12 | +import pandas as pd |
| 13 | +import numpy as np |
| 14 | +import anndata as ad |
| 15 | +import tempfile |
| 16 | +import matplotlib.pyplot as plt |
| 17 | +import seaborn as sns |
| 18 | +from matplotlib.patches import Rectangle |
| 19 | +from pathlib import Path |
| 20 | +from tqdm import tqdm |
| 21 | + |
def combine_sections_adatas(sections_paths):
    """Load each section's table and stack all sections into one AnnData.

    For every path in ``sections_paths`` the AnnData stored under
    ``<path>/tables/table`` is read, tagged with its section number (the
    second ``_``-separated token of the path stem) and given
    section-qualified cell ids, so ids stay unique once sections are stacked.

    Args:
        sections_paths: Iterable of path objects, one per section. Each must
            support the ``/`` operator and expose ``.stem``.

    Returns:
        The concatenated AnnData. Its ``var`` is indexed by ``gene_ids`` and
        the previous ``var`` index is preserved in ``var['gene_symbol']``.
    """
    per_section = []

    for section_path in tqdm(sections_paths, desc='Loading section adatas'):
        section_adata = ad.read_zarr(section_path / 'tables' / 'table')

        # Section number is encoded in the directory name: "<prefix>_<num>..."
        section_id = int(section_path.stem.split('_')[1])
        section_adata.obs['section'] = section_id

        # Keep the raw id, then make the working id unique across sections
        section_adata.obs['original_cell_id'] = section_adata.obs['cell_id']
        section_adata.obs['cell_id'] = [
            f"{raw_id}_{section_id}"
            for raw_id in section_adata.obs['original_cell_id']
        ]
        section_adata.obs.set_index('cell_id', inplace=True, drop=False)

        # Drop the bulky containers before accumulating
        section_adata.uns = {}
        section_adata.obsm = {}
        per_section.append(section_adata)

    # Stack along observations; only var columns identical everywhere survive
    combined = ad.concat(per_section, axis=0, join="outer", merge="same")

    # Re-key genes by id while remembering the symbol used as index so far
    combined.var['gene_symbol'] = combined.var.index
    combined.var.set_index(combined.var['gene_ids'], inplace=True, drop=False)

    return combined
| 55 | + |
| 56 | + |
def get_drop_nodes(v1_types_config, abc_cache, save_plot=False, output_folder=None):
    """Collect the taxonomy nodes that should be excluded from mapping.

    Two sources are combined:

    1. Nodes listed explicitly under ``v1_types_config['drop_nodes_dict']``
       (a mapping of hierarchy level -> iterable of node labels).
    2. Nodes at ``v1_types_config['h_level']`` that ``get_nodes_to_drop``
       reports as missing or under-represented in the V1 MERFISH cell table.

    Args:
        v1_types_config: Dict-like config. Keys read here: ``drop_nodes_dict``,
            ``v1_types_path``, ``h_level`` (default ``'subclass'``),
            ``min_cells`` (default ``0``) and ``drop_layers``.
        abc_cache: ABC atlas cache handed through to the helper functions.
        save_plot: When true, a cell-count heatmap is written to
            ``output_folder``.
        output_folder: Directory (``str`` or ``Path``) for the heatmap.
            Required when ``save_plot`` is true.

    Returns:
        List of ``(hierarchy_level, node_label)`` tuples.

    Raises:
        ValueError: If ``save_plot`` is true but ``output_folder`` is None.
    """
    # Fail fast instead of crashing on `None / str` after all the loading work
    if save_plot and output_folder is None:
        raise ValueError("output_folder must be provided when save_plot is True")

    nodes_to_drop = []

    # Nodes named explicitly in the config
    drop_nodes_dict = v1_types_config.get('drop_nodes_dict', None)
    if drop_nodes_dict:
        for level, labels in drop_nodes_dict.items():
            nodes_to_drop.extend((level, cl) for cl in labels)
        print(f"Dropping {len(nodes_to_drop)} nodes based on drop_nodes_dict.")

    print("Filtering to only include V1 cell type nodes...")
    # Normalise to Path once: get_v1_merfish_cells calls `.exists()` on the
    # value it receives, which a plain str (the default here) does not provide.
    v1_types_path = Path(v1_types_config.get('v1_types_path', '/root/capsule/code/v1_merfish_cells.csv'))
    h_level = v1_types_config.get('h_level', 'subclass')
    min_cells = v1_types_config.get('min_cells', 0)
    if v1_types_path.exists():
        print(f"Loading V1 MERFISH cell types from {v1_types_path}...")
        v1_merfish_cells = pd.read_csv(v1_types_path)
    else:
        print(f"V1 MERFISH cell types file not found at {v1_types_path}. Attempting to generate it from ABC cache...")
        v1_merfish_cells = get_v1_merfish_cells(abc_cache, df_path=v1_types_path)

    # NOTE(review): the source's indentation was lost; plotting is assumed to
    # apply whether the table was loaded or generated - confirm.
    if save_plot:
        plot_cell_counts_heatmap(
            v1_merfish_cells,
            min_cells=min_cells,
            save_path=Path(output_folder) / 'v1_merfish_cell_counts_heatmap.svg',
        )

    # Optionally discard cells from the configured layers before counting
    drop_layers = v1_types_config.get('drop_layers', None)
    if drop_layers:
        in_dropped_layer = v1_merfish_cells['parcellation_substructure'].isin(drop_layers)
        v1_merfish_cells = v1_merfish_cells.loc[~in_dropped_layer]

    # Nodes at h_level without enough V1 cells
    v1_nodes_to_drop = get_nodes_to_drop(v1_merfish_cells, abc_cache, h_level=h_level, min_cells=min_cells)
    print(f"Dropping {len(v1_nodes_to_drop)} {h_level} nodes not present in V1 MERFISH data with at least {min_cells if min_cells>0 else 1} cell(s).")
    nodes_to_drop.extend(v1_nodes_to_drop)

    return nodes_to_drop
| 91 | + |
def get_abc_paths(abc_cache):
    """Resolve the three reference files needed for mapping.

    The paths are requested from ``abc_cache.get_data_path``. If any of the
    three lookups raises, all three paths are replaced by hard-coded capsule
    defaults and the failure is reported on stdout.

    Args:
        abc_cache: Object exposing ``get_data_path(directory=..., file_name=...)``.

    Returns:
        Tuple ``(precomputed_stats_path, mouse_markers_path, gene_mapper_db_path)``.
    """
    try:
        precomputed_stats_path = abc_cache.get_data_path(
            directory='WMB-10X',
            file_name='precomputed_stats_ABC_revision_230821'
        )
        mouse_markers_path = abc_cache.get_data_path(
            directory='WMB-10X',
            file_name='mouse_markers_230821'
        )
        gene_mapper_db_path = abc_cache.get_data_path(
            directory='mmc-gene-mapper',
            file_name='mmc_gene_mapper.2025-08-04'
        )
    except Exception as err:
        # Deliberate best-effort fallback, but say so: a silent switch to the
        # defaults makes a misconfigured cache very hard to diagnose.
        print(f"Could not resolve paths from ABC cache ({err!r}); falling back to default capsule paths.")
        precomputed_stats_path = '/root/capsule/data/abc_atlas/mapmycells/WMB-10X/20240831/precomputed_stats_ABC_revision_230821.h5'
        mouse_markers_path = '/root/capsule/data/abc_atlas/mapmycells/WMB-10X/20240831/mouse_markers_230821.json'
        gene_mapper_db_path = '/root/capsule/data/abc_atlas/mapmycells/mmc-gene-mapper/20250630/mmc_gene_mapper.2025-08-04.db'
    return precomputed_stats_path, mouse_markers_path, gene_mapper_db_path
| 111 | + |
| 112 | + |
def get_v1_merfish_cells(abc_cache=None, df_path=None):
    """Return the table of MERFISH cells whose parcellation structure is VISp.

    If ``df_path`` points to an existing CSV it is loaded (first column as
    index). Otherwise the table is built from the cache's
    ``cell_metadata_with_parcellation_annotation`` metadata and, when
    ``df_path`` was given, written there.

    Args:
        abc_cache: Object exposing ``get_metadata_dataframe``; only needed
            when the CSV does not exist yet.
        df_path: Optional ``str`` or ``Path`` of the cached CSV.

    Returns:
        DataFrame of VISp cells. When built from the cache it is indexed by
        ``cell_label``.

    Raises:
        ValueError: If the CSV is unavailable and ``abc_cache`` is None.
    """
    # Accept plain strings as well as Path objects (str has no `.exists()`)
    csv_path = Path(df_path) if df_path else None

    if csv_path is not None and csv_path.exists():
        return pd.read_csv(csv_path, index_col=0)

    # Get MERFISH CCF metadata
    print('V1 cell df not found, generating new one...')
    if abc_cache is None:
        raise ValueError("abc_cache must be provided if path to df does not exist")
    merfish_ccf_metadata = abc_cache.get_metadata_dataframe(
        directory='MERFISH-C57BL6J-638850-CCF',
        file_name='cell_metadata_with_parcellation_annotation'
    ).set_index('cell_label')
    v1_merfish_cells = merfish_ccf_metadata.loc[merfish_ccf_metadata['parcellation_structure'] == 'VISp']

    # Persist only when a destination was given; `to_csv(None)` would just
    # build and discard the whole CSV as a string.
    if csv_path is not None:
        print(f"Saving df to: {csv_path}")
        v1_merfish_cells.to_csv(csv_path)
    return v1_merfish_cells
| 130 | + |
def get_nodes_to_drop(cells_df, abc_cache, h_level='subclass', min_cells=0):
    """List taxonomy nodes at ``h_level`` that lack enough cells in ``cells_df``.

    Args:
        cells_df: DataFrame with one row per cell and a column named
            ``h_level`` holding the term *name* of each cell.
        abc_cache: Object exposing ``get_metadata_dataframe``; used to load
            the ``cluster_to_cluster_annotation_membership`` table.
        h_level: Taxonomy level (term set name) to filter on.
        min_cells: Minimum number of cells a term needs in order to be kept.
            Values below 1 are treated as 1, i.e. a term must be observed.

    Returns:
        List of ``(h_level, term_label)`` tuples for every label of that level
        that is not backed by at least the required number of cells.
    """
    # Load the taxonomy
    taxonomy_df = abc_cache.get_metadata_dataframe(
        directory='WMB-taxonomy',
        file_name='cluster_to_cluster_annotation_membership'
    )
    # Only rows of the requested level take part in the name -> label lookup,
    # so a term name reused at another level cannot resolve to a foreign label.
    level_terms = taxonomy_df.loc[taxonomy_df['cluster_annotation_term_set_name'] == h_level]

    # Count cells per term name at this level
    cells_per_term = cells_df.groupby(h_level).size()

    # Names with enough cells to keep; a term always needs at least one cell
    # (categorical columns can otherwise report unobserved groups of size 0)
    required_cells = max(min_cells, 1)
    kept_names = cells_per_term[cells_per_term >= required_cells].index

    # Translate kept names to labels; names unknown to the taxonomy match nothing
    kept_labels = level_terms.loc[
        level_terms['cluster_annotation_term_name'].isin(kept_names),
        'cluster_annotation_term_label',
    ].unique()

    # Everything else at this level is dropped
    all_labels = level_terms['cluster_annotation_term_label'].unique()
    clusters_to_drop = np.setdiff1d(all_labels, kept_labels)
    return [(h_level, cl) for cl in clusters_to_drop]
| 150 | + |
def validate_input_adata(h5ad_path, output_dir, mouse_markers_path, db_path, round_to_int=False, layer='X', output_json=None):
    """Validate an h5ad file and decide which file to feed to the mapper.

    ``ValidateH5adRunner`` is run on ``h5ad_path``; the path of the validated
    file is read from the ``valid_h5ad_path`` entry of the runner's output
    JSON. The validated file is used only if the marker lookup contains
    ``ENSMUSG`` identifiers; otherwise the original file is used.

    Args:
        h5ad_path: Input h5ad file.
        output_dir: Directory passed to the runner as ``output_dir``.
        mouse_markers_path: JSON marker lookup file.
        db_path: Gene mapper database, passed as ``gene_mapping.db_path``.
        round_to_int: Forwarded to the runner.
        layer: Forwarded to the runner.
        output_json: Where the runner writes its result JSON. A temporary
            file is created when omitted.

    Returns:
        The validated path (as read from the output JSON) or ``h5ad_path``
        unchanged.

    Raises:
        Exception: Any error raised during validation is reported on stdout
            and re-raised unchanged.
    """
    if output_json is None:
        # Only the name is needed; the context manager closes the handle so no
        # file descriptor is leaked (mkstemp would leave one open).
        with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp_json:
            output_json = tmp_json.name
    validate_config = {
        'h5ad_path': str(h5ad_path),
        'round_to_int': round_to_int,
        'layer': layer,
        'output_dir': str(output_dir),
        # Coerced to str like the other path arguments
        'output_json': str(output_json),
        'gene_mapping': {'db_path': str(db_path)}
    }
    try:
        print("Starting validation of input data...")
        validation_runner = ValidateH5adRunner(args=[], input_data=validate_config)
        validation_runner.run()
        with open(validate_config['output_json'], 'rb') as result_file:
            validated_path = json.load(result_file)['valid_h5ad_path']
        with open(mouse_markers_path, 'rb') as markers_file:
            markers_json = json.load(markers_file)
        # The last two entries are skipped - presumably non-marker bookkeeping
        # entries; TODO confirm against the marker file layout.
        marker_keys = list(markers_json.keys())[:-2]
        use_validated_h5ad = any(
            'ENSMUSG' in marker
            for key in marker_keys
            for marker in markers_json[key]
        )
        if use_validated_h5ad:
            print(f"Using validated path: {validated_path}")
            return validated_path
        print(f'Using original path: {h5ad_path}')
        return h5ad_path
    except Exception as e:
        print(f"Validation failed: {e}")
        # Bare raise keeps the original traceback intact
        raise
| 181 | + |
def format_mapping_outputs(extended_results_path, mapped_adata_path, mapping_params, h5ad_path=None):
    """Write mapping results into the obs of a new h5ad via TranscribeToObsRunner.

    Args:
        extended_results_path: JSON file with the extended mapping results.
        mapped_adata_path: Destination of the new h5ad file.
        mapping_params: Mapping whose ``'clobber'`` entry is forwarded
            (as a bool) to the runner.
        h5ad_path: Source h5ad. When None it is taken from
            ``config.query_path`` inside the extended results JSON.

    Returns:
        True on success; False if anything raised (the error is printed).
    """
    try:
        source_h5ad = h5ad_path
        if source_h5ad is None:
            # Fall back to the query file recorded by the mapping run itself
            with open(extended_results_path, 'r') as results_file:
                source_h5ad = json.load(results_file)['config']['query_path']

        runner_config = dict(
            h5ad_path=str(source_h5ad),
            result_path=str(extended_results_path),
            new_h5ad_path=str(mapped_adata_path),
            clobber=bool(mapping_params['clobber']),
        )
        TranscribeToObsRunner(args=[], input_data=runner_config).run()
    except Exception as e:
        print(f"Formatting mapping outputs failed: {e}")
        return False
    return True
| 199 | + |
| 200 | + |
def plot_cell_counts_heatmap(v1_merfish_cells, min_cells=5, save_path=None):
    """Draw a heatmap of cell counts per subclass and parcellation substructure.

    A final row with per-subclass totals is appended, and every subclass
    column whose total is below ``min_cells`` is covered with a translucent
    gray overlay.

    Args:
        v1_merfish_cells: DataFrame with ``subclass`` and
            ``parcellation_substructure`` columns, or a ``str``/``Path`` to a
            CSV holding such a table.
        min_cells: Totals below this value get the gray overlay.
        save_path: Optional destination; the figure is written there as SVG.
    """
    if isinstance(v1_merfish_cells, (str, Path)):
        v1_merfish_cells = pd.read_csv(v1_merfish_cells)

    # Count cells per (subclass, substructure) and lay them out as a grid
    counts_long = (
        v1_merfish_cells
        .groupby(['subclass', 'parcellation_substructure'])
        .size()
        .reset_index(name='cell_count')
    )
    counts_grid = counts_long.pivot(
        index='parcellation_substructure',
        columns='subclass',
        values='cell_count',
    ).fillna(0)

    # Per-subclass totals become an extra bottom row
    totals = counts_grid.sum(axis=0)
    totals.name = 'Total (All Layers)'
    grid_with_totals = pd.concat([counts_grid, totals.to_frame().T])
    n_layer_rows = counts_grid.shape[0]
    n_all_rows = grid_with_totals.shape[0]

    plt.figure(figsize=(20, 6))
    ax = sns.heatmap(
        grid_with_totals,
        annot=True,
        fmt='.0f',
        cmap='viridis',
        cbar_kws={'label': 'Cell Count'},
    )
    plt.title('Cell Count Heatmap: Subclass vs Parcellation Substructure')
    plt.xlabel('Subclass')
    plt.ylabel('Parcellation Substructure')

    # White rule separating the totals row from the per-layer rows
    ax.axhline(y=n_layer_rows, color='white', linewidth=3)

    # Gray out whole columns whose total is under the threshold
    for column_pos, column_total in enumerate(totals):
        if not column_total < min_cells:
            continue
        ax.add_patch(
            Rectangle(
                (column_pos, 0), 1, n_all_rows,
                linewidth=0, edgecolor='None', facecolor='gray',
                linestyle='-', alpha=0.65,
            )
        )

    plt.tight_layout()

    if save_path is not None:
        plt.savefig(str(save_path), format='svg', bbox_inches='tight')

    plt.show()
0 commit comments