Skip to content

Commit 0f0db99

Browse files
timtreisclaude
andcommitted
Fix categorical colors wrongly assigned to points with non-sequential index (#358)
When points have a shuffled or non-sequential index (e.g. from .sample() or .subset()), _reparse_points sorts rows by index while adata.X retains the original order. This causes get_values to return colors in sorted order, misaligned with coordinates. Resetting the index to sequential before adata construction and reparsing ensures both share the same positional order. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 55d59b7 commit 0f0db99

2 files changed

Lines changed: 44 additions & 1 deletion

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,9 @@ def _render_points(
743743
)
744744
added_color_from_table = True
745745

746+
# Reset to sequential index so row order matches after _reparse_points round-trip (#358).
747+
points = points.reset_index(drop=True)
748+
746749
n_points = len(points)
747750
points_pd_with_color = points
748751
# When we pull colors from a table, keep the raw points (with color) for later,
@@ -758,7 +761,7 @@ def _render_points(
758761
if table_name is None:
759762
adata = AnnData(
760763
X=points[["x", "y"]].values,
761-
obs=points[coords].reset_index(),
764+
obs=points[coords],
762765
dtype=points[["x", "y"]].values.dtype,
763766
)
764767
else:

tests/pl/test_render_points.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,46 @@ def test_plot_groups_na_color_none_filters_points_datashader(self, sdata_blobs:
607607
).pl.show(ax=axs[1], title="default (filtered)")
608608

609609

610+
@pytest.mark.parametrize("method", ["matplotlib", "datashader"])
611+
def test_shuffled_index_categorical_color_alignment(method: str):
612+
"""Regression test for #358: categorical colors must follow the data, not the index order."""
613+
# Two spatially separated clusters so correct/incorrect coloring is distinguishable.
614+
n = 100
615+
rng = get_standard_RNG()
616+
x = np.concatenate([rng.uniform(0, 10, n // 2), rng.uniform(90, 100, n // 2)])
617+
y = np.concatenate([rng.uniform(0, 10, n // 2), rng.uniform(90, 100, n // 2)])
618+
df = pd.DataFrame(
619+
{
620+
"x": x,
621+
"y": y,
622+
"cluster": pd.Categorical(["A"] * (n // 2) + ["B"] * (n // 2)),
623+
}
624+
)
625+
# Shuffle rows so the index is non-sequential (simulates .sample() / .subset()).
626+
shuffled = df.sample(frac=1, random_state=42)
627+
assert shuffled.index.tolist() != list(range(n)), "sanity: index should be shuffled"
628+
629+
pts = PointsModel.parse(shuffled)
630+
sdata = SpatialData(points={"pts": pts})
631+
632+
_, ax = plt.subplots()
633+
sdata.pl.render_points("pts", color="cluster", method=method, size=20).pl.show(ax=ax)
634+
635+
# For datashader we can only check it doesn't error; for matplotlib we can
636+
# inspect the scatter colors directly.
637+
if method == "matplotlib":
638+
colls = [c for c in ax.collections if hasattr(c, "get_offsets") and len(c.get_offsets()) > 0]
639+
assert colls, "expected scatter points"
640+
offsets = colls[-1].get_offsets()
641+
colors = colls[-1].get_facecolors()
642+
left_colors = np.unique(colors[offsets[:, 0] < 50], axis=0)
643+
right_colors = np.unique(colors[offsets[:, 0] > 50], axis=0)
644+
assert len(left_colors) == 1, f"left cluster should have 1 color, got {len(left_colors)}"
645+
assert len(right_colors) == 1, f"right cluster should have 1 color, got {len(right_colors)}"
646+
assert not np.array_equal(left_colors[0], right_colors[0]), "clusters should have different colors"
647+
plt.close("all")
648+
649+
610650
def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
611651
"""When no elements match the groups, the plot should render without error."""
612652
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")

0 commit comments

Comments
 (0)