Skip to content

Commit 8c0d295

Browse files
timtreisclaude
andcommitted
Extract outline rendering and color key building into shared helpers
- _render_ds_outlines: consolidates the 40-line outline aggregation + shading + rendering block (outer + inner) into a single loop - _build_color_key: extracts the identical color key construction from both shapes and points datashader paths The shapes and points datashader pipelines now read nearly identically, differing only in parameter values. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent db51af4 commit 8c0d295

1 file changed

Lines changed: 72 additions & 73 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 72 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,23 @@ def _datashader_shade_continuous(
369369
return ds_result, continuous_nan_shaded, aggregate_with_reduction
370370

371371

372+
def _build_color_key(
373+
transformed_element: Any,
374+
col_for_color: str | None,
375+
color_by_categorical: bool,
376+
color_vector: Any,
377+
na_color_hex: str,
378+
) -> dict[str, str] | None:
379+
"""Build a datashader color key mapping categories to hex colors.
380+
381+
Returns None when not coloring by a categorical column.
382+
"""
383+
if not color_by_categorical or col_for_color is None:
384+
return None
385+
cat_series = _coerce_categorical_source(transformed_element[col_for_color])
386+
return _build_datashader_color_key(cat_series, color_vector, na_color_hex)
387+
388+
372389
def _datashader_shade_categorical(
373390
agg: Any,
374391
color_key: dict[str, str] | None,
@@ -392,6 +409,44 @@ def _datashader_shade_categorical(
392409
)
393410

394411

412+
def _render_ds_outlines(
413+
cvs: Any,
414+
transformed_element: Any,
415+
render_params: ShapesRenderParams,
416+
fig_params: FigParams,
417+
ax: matplotlib.axes.SubplotBase,
418+
factor: float,
419+
extent: list[float],
420+
) -> None:
421+
"""Aggregate, shade, and render shape outlines (outer and inner) with datashader."""
422+
ds_lw_factor = fig_params.fig.dpi / 72
423+
assert len(render_params.outline_alpha) == 2 # noqa: S101
424+
425+
for idx, (outline_color_obj, linewidth) in enumerate(
426+
[
427+
(render_params.outline_params.outer_outline_color, render_params.outline_params.outer_outline_linewidth),
428+
(render_params.outline_params.inner_outline_color, render_params.outline_params.inner_outline_linewidth),
429+
]
430+
):
431+
alpha = render_params.outline_alpha[idx]
432+
if alpha <= 0:
433+
continue
434+
agg_outline = cvs.line(
435+
transformed_element,
436+
geometry="geometry",
437+
line_width=linewidth * ds_lw_factor,
438+
)
439+
if isinstance(outline_color_obj, Color):
440+
shaded = ds.tf.shade(
441+
agg_outline,
442+
cmap=outline_color_obj.get_hex(),
443+
min_alpha=_convert_alpha_to_datashader_range(alpha),
444+
how="linear",
445+
)
446+
rgba, trans = _create_image_from_datashader_result(shaded, factor, ax)
447+
_ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, alpha=alpha, extent=extent)
448+
449+
395450
def _render_datashader_result(
396451
ax: matplotlib.axes.SubplotBase,
397452
ds_result: Any,
@@ -685,31 +740,15 @@ def _render_shapes(
685740
"shapes",
686741
)
687742

688-
# render outlines if needed
689-
# outline_linewidth is in points (1pt = 1/72 inch); datashader line_width is in canvas pixels
690-
ds_lw_factor = fig_params.fig.dpi / 72
691-
assert len(render_params.outline_alpha) == 2 # shut up mypy
692-
if render_params.outline_alpha[0] > 0:
693-
agg_outlines = cvs.line(
694-
transformed_element,
695-
geometry="geometry",
696-
line_width=render_params.outline_params.outer_outline_linewidth * ds_lw_factor,
697-
)
698-
if render_params.outline_alpha[1] > 0:
699-
agg_inner_outlines = cvs.line(
700-
transformed_element,
701-
geometry="geometry",
702-
line_width=render_params.outline_params.inner_outline_linewidth * ds_lw_factor,
703-
)
704-
705743
agg, ds_span = _apply_datashader_norm(agg, norm)
706-
707-
color_key: dict[str, str] | None = None
708-
if color_by_categorical and col_for_color is not None:
709-
cat_series = _coerce_categorical_source(transformed_element[col_for_color])
710-
color_key = _build_datashader_color_key(
711-
cat_series, color_vector, render_params.cmap_params.na_color.get_hex()
712-
)
744+
na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
745+
color_key = _build_color_key(
746+
transformed_element,
747+
col_for_color,
748+
color_by_categorical,
749+
color_vector,
750+
na_color_hex,
751+
)
713752

714753
continuous_nan_shaded = None
715754
if color_by_categorical or col_for_color is None:
@@ -720,7 +759,6 @@ def _render_shapes(
720759
render_params.fill_alpha,
721760
)
722761
else:
723-
na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
724762
ds_result, continuous_nan_shaded, aggregate_with_reduction = _datashader_shade_continuous(
725763
agg,
726764
ds_span,
@@ -732,46 +770,7 @@ def _render_shapes(
732770
na_color_hex,
733771
)
734772

735-
# shade outlines if needed
736-
if render_params.outline_alpha[0] > 0 and isinstance(render_params.outline_params.outer_outline_color, Color):
737-
outline_color = render_params.outline_params.outer_outline_color.get_hex()
738-
ds_outlines = ds.tf.shade(
739-
agg_outlines,
740-
cmap=outline_color,
741-
min_alpha=_convert_alpha_to_datashader_range(render_params.outline_alpha[0]),
742-
how="linear",
743-
)
744-
# inner outlines
745-
if render_params.outline_alpha[1] > 0 and isinstance(render_params.outline_params.inner_outline_color, Color):
746-
outline_color = render_params.outline_params.inner_outline_color.get_hex()
747-
ds_inner_outlines = ds.tf.shade(
748-
agg_inner_outlines,
749-
cmap=outline_color,
750-
min_alpha=_convert_alpha_to_datashader_range(render_params.outline_alpha[1]),
751-
how="linear",
752-
)
753-
754-
# render outline image(s)
755-
if render_params.outline_alpha[0] > 0:
756-
rgba_image, trans_data = _create_image_from_datashader_result(ds_outlines, factor, ax)
757-
_ax_show_and_transform(
758-
rgba_image,
759-
trans_data,
760-
ax,
761-
zorder=render_params.zorder,
762-
alpha=render_params.outline_alpha[0],
763-
extent=x_ext + y_ext,
764-
)
765-
if render_params.outline_alpha[1] > 0:
766-
rgba_image, trans_data = _create_image_from_datashader_result(ds_inner_outlines, factor, ax)
767-
_ax_show_and_transform(
768-
rgba_image,
769-
trans_data,
770-
ax,
771-
zorder=render_params.zorder,
772-
alpha=render_params.outline_alpha[1],
773-
extent=x_ext + y_ext,
774-
)
773+
_render_ds_outlines(cvs, transformed_element, render_params, fig_params, ax, factor, x_ext + y_ext)
775774

776775
_cax = _render_datashader_result(
777776
ax,
@@ -1133,13 +1132,14 @@ def _render_points(
11331132
)
11341133

11351134
agg, ds_span = _apply_datashader_norm(agg, norm)
1136-
1137-
color_key: dict[str, str] | None = None
1138-
if color_by_categorical and col_for_color is not None:
1139-
cat_series = _coerce_categorical_source(transformed_element[col_for_color])
1140-
color_key = _build_datashader_color_key(
1141-
cat_series, color_vector, render_params.cmap_params.na_color.get_hex()
1142-
)
1135+
na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
1136+
color_key = _build_color_key(
1137+
transformed_element,
1138+
col_for_color,
1139+
color_by_categorical,
1140+
color_vector,
1141+
na_color_hex,
1142+
)
11431143

11441144
if (
11451145
color_vector is not None
@@ -1159,7 +1159,6 @@ def _render_points(
11591159
spread_px=px,
11601160
)
11611161
else:
1162-
na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
11631162
ds_result, continuous_nan_shaded, aggregate_with_reduction = _datashader_shade_continuous(
11641163
agg,
11651164
ds_span,

0 commit comments

Comments
 (0)