Skip to content

Commit a782851

Browse files
timtreisclaude
andcommitted
Fix categorical colors wrongly assigned to points with non-sequential index (#358)
When points have a shuffled or non-sequential index (e.g. from .sample() or .subset()), _reparse_points sorts rows by index while adata.X retains the original order. This causes get_values to return colors in sorted order, misaligned with coordinates. Resetting the index to sequential before adata construction and reparsing ensures both share the same positional order. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 55d59b7 commit a782851

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,9 @@ def _render_points(
743743
)
744744
added_color_from_table = True
745745

746+
# Reset to sequential index so row order matches after _reparse_points round-trip (#358).
747+
points = points.reset_index(drop=True)
748+
746749
n_points = len(points)
747750
points_pd_with_color = points
748751
# When we pull colors from a table, keep the raw points (with color) for later,
@@ -758,7 +761,7 @@ def _render_points(
758761
if table_name is None:
759762
adata = AnnData(
760763
X=points[["x", "y"]].values,
761-
obs=points[coords].reset_index(),
764+
obs=points[coords],
762765
dtype=points[["x", "y"]].values.dtype,
763766
)
764767
else:

tests/pl/test_render_points.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,41 @@ def test_plot_groups_na_color_none_filters_points_datashader(self, sdata_blobs:
607607
).pl.show(ax=axs[1], title="default (filtered)")
608608

609609

610+
def test_shuffled_index_categorical_color_alignment():
611+
"""Regression test for #358: categorical colors must follow the data, not the index order."""
612+
n = 100
613+
rng = get_standard_RNG()
614+
x = np.concatenate([rng.uniform(0, 10, n // 2), rng.uniform(90, 100, n // 2)])
615+
y = np.concatenate([rng.uniform(0, 10, n // 2), rng.uniform(90, 100, n // 2)])
616+
df = pd.DataFrame(
617+
{
618+
"x": x,
619+
"y": y,
620+
"cluster": pd.Categorical(["A"] * (n // 2) + ["B"] * (n // 2)),
621+
}
622+
)
623+
pts = PointsModel.parse(df)
624+
sdata = SpatialData(points={"pts": pts})
625+
626+
# .sample() produces a non-sequential, shuffled index — the trigger for #358.
627+
sampled = sdata.points["pts"].compute().sample(frac=0.8, random_state=42)
628+
sdata.points["pts"] = PointsModel.parse(sampled)
629+
630+
_, ax = plt.subplots()
631+
sdata.pl.render_points("pts", color="cluster", method="matplotlib", size=20).pl.show(ax=ax)
632+
633+
colls = [c for c in ax.collections if hasattr(c, "get_offsets") and len(c.get_offsets()) > 0]
634+
assert colls, "expected scatter points"
635+
offsets = colls[-1].get_offsets()
636+
colors = colls[-1].get_facecolors()
637+
left_colors = np.unique(colors[offsets[:, 0] < 50], axis=0)
638+
right_colors = np.unique(colors[offsets[:, 0] > 50], axis=0)
639+
assert len(left_colors) == 1, f"left cluster should have 1 color, got {len(left_colors)}"
640+
assert len(right_colors) == 1, f"right cluster should have 1 color, got {len(right_colors)}"
641+
assert not np.array_equal(left_colors[0], right_colors[0]), "clusters should have different colors"
642+
plt.close("all")
643+
644+
610645
def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
611646
"""When no elements match the groups, the plot should render without error."""
612647
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")

0 commit comments

Comments
 (0)