Skip to content

Commit 2329522

Browse files
authored
Fix datashader resolution collapse on cropped coordinate offsets (#669)
1 parent 59f1782 commit 2329522

14 files changed

Lines changed: 239 additions & 74 deletions

src/spatialdata_plot/pl/_datashader.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ def _render_ds_image(
306306
shaded: Any,
307307
factor: float,
308308
zorder: int,
309-
extent: list[float] | None,
309+
x_min: float = 0.0,
310+
y_min: float = 0.0,
310311
nan_result: Any | None = None,
311312
) -> Any:
312313
"""Render a shaded datashader image onto matplotlib axes, with optional NaN overlay.
@@ -316,10 +317,10 @@ def _render_ds_image(
316317
it again would apply transparency twice (see #367).
317318
"""
318319
if nan_result is not None:
319-
rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax)
320-
_ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder, extent=extent)
321-
rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax)
322-
return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder, extent=extent)
320+
rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax, x_min, y_min)
321+
_ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder)
322+
rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax, x_min, y_min)
323+
return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder)
323324

324325

325326
def _render_ds_outlines(
@@ -329,7 +330,8 @@ def _render_ds_outlines(
329330
fig_params: FigParams,
330331
ax: matplotlib.axes.SubplotBase,
331332
factor: float,
332-
extent: list[float],
333+
x_min: float = 0.0,
334+
y_min: float = 0.0,
333335
) -> None:
334336
"""Aggregate, shade, and render shape outlines (outer and inner) with datashader."""
335337
ds_lw_factor = fig_params.fig.dpi / 72
@@ -357,8 +359,8 @@ def _render_ds_outlines(
357359
how="linear",
358360
)
359361
shaded = _apply_user_alpha(shaded, alpha)
360-
rgba, trans = _create_image_from_datashader_result(shaded, factor, ax)
361-
_ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, extent=extent)
362+
rgba, trans = _create_image_from_datashader_result(shaded, factor, ax, x_min, y_min)
363+
_ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder)
362364

363365

364366
def _build_ds_colorbar(

src/spatialdata_plot/pl/render.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,13 @@ def _render_shapes(
526526
)
527527
)
528528

529+
if len(transformed_element) == 0:
530+
# Nothing to rasterize (e.g., a bounding_box_query that matched no
531+
# shapes). Skip the datashader pipeline.
532+
return
533+
529534
plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
530-
transformed_element, "global", ax, fig_params
535+
transformed_element, "global", fig_params
531536
)
532537

533538
cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)
@@ -608,15 +613,17 @@ def _render_shapes(
608613
fig_params,
609614
ax,
610615
factor,
611-
x_ext + y_ext,
616+
x_min=x_ext[0],
617+
y_min=y_ext[0],
612618
)
613619

614620
_cax = _render_ds_image(
615621
ax,
616622
shaded,
617623
factor,
618624
render_params.zorder,
619-
x_ext + y_ext,
625+
x_min=x_ext[0],
626+
y_min=y_ext[0],
620627
nan_result=nan_shaded,
621628
)
622629

@@ -939,8 +946,14 @@ def _render_points(
939946
transformations={coordinate_system: Identity()},
940947
).compute()
941948

949+
if len(transformed_element) == 0:
950+
# Nothing to rasterize (e.g., a bounding_box_query that matched no
951+
# points). Skip the datashader pipeline; rendering proceeds with
952+
# any other elements on the axes.
953+
return
954+
942955
plot_width, plot_height, x_ext, y_ext, factor = _datashader_canvas_from_dataframe(
943-
transformed_element, ax, fig_params
956+
transformed_element, fig_params
944957
)
945958

946959
# use datashader for the visualization of points
@@ -1036,7 +1049,8 @@ def _render_points(
10361049
shaded,
10371050
factor,
10381051
render_params.zorder,
1039-
x_ext + y_ext,
1052+
x_min=x_ext[0],
1053+
y_min=y_ext[0],
10401054
nan_result=nan_shaded,
10411055
)
10421056

src/spatialdata_plot/pl/utils.py

Lines changed: 37 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@
6666
from spatialdata._types import ArrayLike
6767
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, get_table_keys
6868
from spatialdata.transformations.operations import get_transformation
69-
from spatialdata.transformations.transformations import Scale
69+
from spatialdata.transformations.transformations import Scale, Translation
70+
from spatialdata.transformations.transformations import Sequence as TransformSequence
7071
from xarray import DataArray, DataTree
7172

7273
from spatialdata_plot._logging import logger
@@ -2029,7 +2030,12 @@ def _rasterize_if_necessary(
20292030

20302031
if do_rasterization:
20312032
logger.info("Rasterizing image for faster rendering.")
2032-
target_unit_to_pixels = min(target_y_dims / y_dims, target_x_dims / x_dims)
2033+
# ``rasterize`` interprets ``target_unit_to_pixels`` in world units, not
2034+
# intrinsic pixels. Dividing by world extent keeps the result correct
2035+
# for any transformation (translation, scale, etc.).
2036+
world_x = float(extent["x"][1]) - float(extent["x"][0])
2037+
world_y = float(extent["y"][1]) - float(extent["y"][0])
2038+
target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x)
20332039
image = rasterize(
20342040
image,
20352041
("y", "x"),
@@ -3378,42 +3384,20 @@ def _ax_show_and_transform(
33783384
alpha: float | None = None,
33793385
cmap: ListedColormap | LinearSegmentedColormap | None = None,
33803386
zorder: int = 0,
3381-
extent: list[float] | None = None,
33823387
norm: Normalize | None = None,
33833388
) -> matplotlib.image.AxesImage:
3384-
# default extent in mpl:
3385-
image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5]
3386-
if extent is not None:
3387-
# make sure extent is [x_min, x_max, y_min, y_max]
3388-
if extent[3] < extent[2]:
3389-
extent[2], extent[3] = extent[3], extent[2]
3390-
if extent[0] < 0:
3391-
x_factor = array.shape[1] / (extent[1] - extent[0])
3392-
image_extent[0] = image_extent[0] + (extent[0] * x_factor)
3393-
image_extent[1] = image_extent[1] + (extent[0] * x_factor)
3394-
if extent[2] < 0:
3395-
y_factor = array.shape[0] / (extent[3] - extent[2])
3396-
image_extent[2] = image_extent[2] + (extent[2] * y_factor)
3397-
image_extent[3] = image_extent[3] + (extent[2] * y_factor)
3398-
3389+
# ``extent`` uses mpl's pixel-grid convention; world placement happens via
3390+
# ``set_transform(trans_data)`` afterwards.
3391+
image_extent = (-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5)
3392+
# ``alpha`` is applied only when no cmap is set, so RGBA arrays already
3393+
# carrying per-pixel alpha (e.g. datashader output) are not double-attenuated.
3394+
imshow_kwargs: dict[str, Any] = {"zorder": zorder, "extent": image_extent, "norm": norm}
33993395
if not cmap and alpha is not None:
3400-
im = ax.imshow(
3401-
array,
3402-
alpha=alpha,
3403-
zorder=zorder,
3404-
extent=tuple(image_extent),
3405-
norm=norm,
3406-
)
3407-
im.set_transform(trans_data)
3396+
imshow_kwargs["alpha"] = alpha
34083397
else:
3409-
im = ax.imshow(
3410-
array,
3411-
cmap=cmap,
3412-
zorder=zorder,
3413-
extent=tuple(image_extent),
3414-
norm=norm,
3415-
)
3416-
im.set_transform(trans_data)
3398+
imshow_kwargs["cmap"] = cmap
3399+
im = ax.imshow(array, **imshow_kwargs)
3400+
im.set_transform(trans_data)
34173401
return im
34183402

34193403

@@ -3442,30 +3426,12 @@ def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = No
34423426
def _compute_datashader_canvas_params(
34433427
x_ext: list[Any],
34443428
y_ext: list[Any],
3445-
ax: Axes,
34463429
fig_params: FigParams,
34473430
) -> tuple[Any, Any, list[Any], list[Any], Any]:
34483431
"""Compute datashader canvas dimensions from spatial extents.
34493432
34503433
Shared logic used by both the dask-based and pandas-based entry points.
34513434
"""
3452-
previous_xlim = ax.get_xlim()
3453-
previous_ylim = ax.get_ylim()
3454-
# increase range if sth larger was rendered on the axis before
3455-
if _mpl_ax_contains_elements(ax):
3456-
x_ext = [min(x_ext[0], previous_xlim[0]), max(x_ext[1], previous_xlim[1])]
3457-
y_ext = (
3458-
[
3459-
min(y_ext[0], previous_ylim[1]),
3460-
max(y_ext[1], previous_ylim[0]),
3461-
]
3462-
if ax.yaxis_inverted()
3463-
else [
3464-
min(y_ext[0], previous_ylim[0]),
3465-
max(y_ext[1], previous_ylim[1]),
3466-
]
3467-
)
3468-
34693435
# Compute canvas size in pixels, capped at the figure's display resolution.
34703436
# Using np.max ensures the canvas never exceeds display pixels on either axis,
34713437
# preventing pixel-based operations (spread, line_width) from being downscaled
@@ -3485,42 +3451,52 @@ def _compute_datashader_canvas_params(
34853451
def _get_extent_and_range_for_datashader_canvas(
34863452
spatial_element: SpatialElement,
34873453
coordinate_system: str,
3488-
ax: Axes,
34893454
fig_params: FigParams,
34903455
) -> tuple[Any, Any, list[Any], list[Any], Any]:
34913456
extent = get_extent(spatial_element, coordinate_system=coordinate_system)
3492-
x_ext = [min(0, extent["x"][0]), extent["x"][1]]
3493-
y_ext = [min(0, extent["y"][0]), extent["y"][1]]
3494-
return _compute_datashader_canvas_params(x_ext, y_ext, ax, fig_params)
3457+
x_ext = [float(extent["x"][0]), float(extent["x"][1])]
3458+
y_ext = [float(extent["y"][0]), float(extent["y"][1])]
3459+
return _compute_datashader_canvas_params(x_ext, y_ext, fig_params)
34953460

34963461

34973462
def _datashader_canvas_from_dataframe(
34983463
df: pd.DataFrame,
3499-
ax: Axes,
35003464
fig_params: FigParams,
35013465
) -> tuple[Any, Any, list[Any], list[Any], Any]:
35023466
"""Compute datashader canvas params directly from a pandas DataFrame.
35033467
35043468
Avoids the overhead of ``get_extent()`` (which requires a dask-backed
35053469
SpatialElement) by reading min/max from the already-materialised data.
35063470
"""
3507-
x_ext = [min(0, float(df["x"].min())), float(df["x"].max())]
3508-
y_ext = [min(0, float(df["y"].min())), float(df["y"].max())]
3509-
return _compute_datashader_canvas_params(x_ext, y_ext, ax, fig_params)
3471+
if len(df) == 0:
3472+
# Empty input (e.g., a bounding_box_query with no overlap) — caller
3473+
# should short-circuit; return zero-sized canvas params as a sentinel.
3474+
return 0, 0, [0.0, 0.0], [0.0, 0.0], 1.0
3475+
x_ext = [float(df["x"].min()), float(df["x"].max())]
3476+
y_ext = [float(df["y"].min()), float(df["y"].max())]
3477+
return _compute_datashader_canvas_params(x_ext, y_ext, fig_params)
35103478

35113479

35123480
def _create_image_from_datashader_result(
35133481
ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]],
35143482
factor: float,
35153483
ax: Axes,
3484+
x_min: float = 0.0,
3485+
y_min: float = 0.0,
35163486
) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]:
35173487
# create SpatialImage from datashader output to get it back to original size
35183488
rgba_image_data = ds_result.copy() if isinstance(ds_result, np.ndarray) else ds_result.to_numpy().base
35193489
rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1))
3490+
transformation: Scale | TransformSequence = Scale([1, factor, factor], ("c", "y", "x"))
3491+
if x_min != 0.0 or y_min != 0.0:
3492+
# Canvas pixel (0, 0) corresponds to world (x_min, y_min). Without this
3493+
# translation the rgba would render at the world origin instead of at
3494+
# the element's actual position.
3495+
transformation = TransformSequence([transformation, Translation([x_min, y_min], ("x", "y"))])
35203496
rgba_image = Image2DModel.parse(
35213497
rgba_image_data,
35223498
dims=("c", "y", "x"),
3523-
transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))},
3499+
transformations={"global": transformation},
35243500
)
35253501

35263502
_, trans_data = _prepare_transformation(rgba_image, "global", ax)
6.36 KB
Loading
-2.88 KB
Loading
-1.05 KB
Loading
-560 Bytes
Loading
-697 Bytes
Loading
-629 Bytes
Loading
-541 Bytes
Loading

0 commit comments

Comments
 (0)