Skip to content

Commit ec09641

Browse files
timtreisclaude
andcommitted
Fix categorical color mapping crash when categories differ across coordinate systems
When plotting shapes colored by a categorical column across multiple coordinate systems, the color_source_vector could carry unused categories from other coordinate systems. This caused a length mismatch between categories and stored colors in adata.uns, leading to a ValueError in strict zip calls. The fix removes unused categories early in both _generate_base_categorial_color_mapping and _extract_colors_from_table_uns, and maps colors by the category's position in the full table (not the subset) to ensure consistent coloring across coordinate systems. Fixes #425 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a2bb56b commit ec09641

1 file changed

Lines changed: 32 additions & 6 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,10 +1213,24 @@ def _generate_base_categorial_color_mapping(
12131213
cmap_params: CmapParams | None = None,
12141214
) -> Mapping[str, str]:
12151215
if adata is not None and cluster_key in adata.uns and f"{cluster_key}_colors" in adata.uns:
1216-
colors = adata.uns[f"{cluster_key}_colors"]
1217-
categories = color_source_vector.categories.tolist() + ["NaN"]
1216+
all_colors = adata.uns[f"{cluster_key}_colors"]
1217+
1218+
# When plotting per-coordinate-system, the color_source_vector may carry
1219+
# categories from other coordinate systems that aren't present in the
1220+
# current subset. Drop them so that categories and colors stay aligned.
1221+
color_source_vector = color_source_vector.remove_unused_categories()
1222+
1223+
# The stored colors in .uns correspond 1-to-1 to the *full* set of
1224+
# categories in adata.obs[cluster_key]. Subset to the categories that
1225+
# are still present after removing unused ones.
1226+
if cluster_key in adata.obs and hasattr(adata.obs[cluster_key], "cat"):
1227+
all_cats = adata.obs[cluster_key].cat.categories.tolist()
1228+
keep_idx = [i for i, c in enumerate(all_cats) if c in color_source_vector.categories]
1229+
colors = [to_hex(to_rgba(all_colors[i])[:3]) for i in keep_idx]
1230+
else:
1231+
colors = [to_hex(to_rgba(c)[:3]) for c in all_colors]
12181232

1219-
colors = [to_hex(to_rgba(color)[:3]) for color in colors]
1233+
categories = color_source_vector.categories.tolist() + ["NaN"]
12201234

12211235
if len(categories) > len(colors):
12221236
return dict(zip(categories, colors + [na_color.get_hex_with_alpha()], strict=True))
@@ -1331,6 +1345,9 @@ def _extract_colors_from_table_uns(
13311345

13321346
# Extract colors and categories
13331347
stored_colors = adata.uns[color_key]
1348+
# Drop categories not present in the current subset (e.g. when plotting
1349+
# per-coordinate-system) so that positional color lookups stay aligned.
1350+
color_source_vector = color_source_vector.remove_unused_categories()
13341351
categories = color_source_vector.categories.tolist()
13351352

13361353
# Validate na_color format and convert to hex string
@@ -1378,9 +1395,18 @@ def _to_hex_no_alpha(color_value: Any) -> str | None:
13781395
logger.warning(f"Unsupported color storage for '{color_key}'. Expected sequence or mapping.")
13791396
return None
13801397

1381-
for i, category in enumerate(categories):
1382-
if i < len(hex_colors) and hex_colors[i] is not None:
1383-
hex_color = hex_colors[i]
1398+
# Map by the category's position in the *full* table, not in the
1399+
# (possibly subset) color_source_vector, so colors stay consistent
1400+
# across coordinate systems.
1401+
all_cats = (
1402+
adata.obs[col_to_colorby].cat.categories.tolist()
1403+
if col_to_colorby in adata.obs and hasattr(adata.obs[col_to_colorby], "cat")
1404+
else categories
1405+
)
1406+
for category in categories:
1407+
idx = all_cats.index(category) if category in all_cats else None
1408+
if idx is not None and idx < len(hex_colors) and hex_colors[idx] is not None:
1409+
hex_color = hex_colors[idx]
13841410
assert hex_color is not None # type narrowing for mypy
13851411
color_mapping[category] = hex_color
13861412
else:

0 commit comments

Comments
 (0)