Skip to content

Commit 316e2ca

Browse files
authored
fix: allow list palette without groups (#659)
Signed-off-by: SAY-5 <say.apm35@gmail.com>
1 parent f4dd082 commit 316e2ca

2 files changed

Lines changed: 18 additions & 3 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2479,9 +2479,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
24792479
palette_group = param_dict.get("palette")
24802480
if element_type in ["shapes", "points", "labels"] and palette_group is not None and not isinstance(palette, dict):
24812481
groups = param_dict.get("groups")
2482-
if groups is None:
2483-
raise ValueError("When specifying 'palette', 'groups' must also be specified.")
2484-
if len(groups) != len(palette_group):
2482+
if groups is not None and len(groups) != len(palette_group):
24852483
raise ValueError(
24862484
f"The length of 'palette' and 'groups' must be the same, length is {len(palette_group)} and"
24872485
f"{len(groups)} respectively."

tests/pl/test_render_shapes.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,23 @@ def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
275275
"blobs_polygons", color="cluster", groups=["c2", "c1"], palette=["green", "yellow"]
276276
).pl.show()
277277

278+
def test_render_shapes_list_palette_without_groups(self, sdata_blobs: SpatialData):
279+
# Regression test for #605: a list palette should map to categories in their natural order
280+
# without requiring groups= to enumerate every category.
281+
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
282+
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
283+
sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1"
284+
sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2"
285+
sdata_blobs.shapes["blobs_polygons"]["cluster"] = sdata_blobs.shapes["blobs_polygons"]["cluster"].astype(
286+
"category"
287+
)
288+
289+
_, ax = plt.subplots()
290+
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", palette=["green", "yellow"]).pl.show(ax=ax)
291+
legend = ax.get_legend()
292+
assert legend is not None
293+
assert {t.get_text() for t in legend.get_texts()} == {"c1", "c2"}
294+
278295
def test_plot_colorbar_respects_input_limits(self, sdata_blobs: SpatialData):
279296
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
280297
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"

0 commit comments

Comments
 (0)