Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1563701
feat(_transforms): create private transforms package skeleton
d-v-b May 6, 2026
273eed9
feat(_transforms): port output_map module
d-v-b May 6, 2026
29367e2
feat(_transforms): port domain module
d-v-b May 6, 2026
2e7b646
feat(_transforms): port transform module
d-v-b May 6, 2026
34b0384
feat(_transforms): port composition module
d-v-b May 6, 2026
5a5ad23
feat(_transforms): port chunk_resolution module
d-v-b May 6, 2026
5370f91
feat(_transforms): expose package exports
d-v-b May 6, 2026
fca2926
docs(_transforms): use markdown single-backticks; fix stale lazy acce…
d-v-b May 7, 2026
cf56e5b
test(_transforms): rewrite output_map tests in parametrized style
d-v-b May 7, 2026
6e54443
test(_transforms): rewrite domain tests in parametrized style
d-v-b May 7, 2026
d91f592
refactor(_transforms): drop unused _normalize_negative_indices
d-v-b May 7, 2026
6d18502
test(_transforms): rewrite transform tests in parametrized style
d-v-b May 7, 2026
c25bf7b
test(_transforms): rewrite composition tests in parametrized style
d-v-b May 7, 2026
20df22c
test(_transforms): rewrite chunk_resolution tests in parametrized style
d-v-b May 7, 2026
2d659a1
test(_transforms): fix mypy errors in cross-file mypy run
d-v-b May 7, 2026
f7d05da
docs(_transforms): convert remaining RST code block to markdown in do…
d-v-b May 7, 2026
895b1df
test(_transforms): add error tests for previously-untested public bra…
d-v-b May 7, 2026
edae196
test(_transforms): add hypothesis property test for compose associati…
d-v-b May 7, 2026
2a10850
refactor(_transforms): make ArrayMap.input_dimensions explicit
d-v-b May 7, 2026
71819c3
Merge branch 'main' into worktree-lazy-indexing-pr1-transforms
d-v-b May 8, 2026
c4de023
refactor(_transforms): hoist itertools import in chunk_resolution
d-v-b May 8, 2026
02a5546
Merge branch 'worktree-lazy-indexing-pr1-transforms' of https://githu…
d-v-b May 8, 2026
ee91e0b
fix(_transforms): guard multi-dim ArrayMap case in oindex; address re…
d-v-b May 8, 2026
43a9b98
test(_transforms): close coverage gaps; reach 99% line+branch
d-v-b May 8, 2026
ea661d4
docs(domain): add import in docstring
d-v-b May 8, 2026
542036f
test(_transforms): reach 100% line+branch coverage
d-v-b May 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/zarr/core/_transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Composable, lazy coordinate transforms for zarr array indexing.

This package implements TensorStore-inspired index transforms. The core idea:
every indexing operation (slicing, fancy indexing, etc.) produces a coordinate
mapping from user space to storage space. These mappings compose lazily - no
I/O until you explicitly read or write.

Private package: this module is not part of the public zarr API. The leading
underscore in the package name signals this. Importers outside this package
must be limited to other private zarr modules.

Key types:

- `IndexDomain` -- a rectangular region of integer coordinates
- `IndexTransform` -- maps input coordinates to storage coordinates
- `ConstantMap`, `DimensionMap`, `ArrayMap` -- the three ways a single
output dimension can depend on the input (see `output_map.py`)
- `compose` -- chain two transforms into one
"""

from zarr.core._transforms.composition import compose
from zarr.core._transforms.domain import IndexDomain
from zarr.core._transforms.output_map import (
ArrayMap,
ConstantMap,
DimensionMap,
OutputIndexMap,
)
from zarr.core._transforms.transform import IndexTransform

__all__ = [
"ArrayMap",
"ConstantMap",
"DimensionMap",
"IndexDomain",
"IndexTransform",
"OutputIndexMap",
"compose",
]
210 changes: 210 additions & 0 deletions src/zarr/core/_transforms/chunk_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""Chunk resolution — mapping transforms to chunk-level I/O.

Given an `IndexTransform` (which coordinates a user wants to access) and a
`ChunkGrid` (how storage is divided into chunks), chunk resolution answers:

For each chunk, which storage coordinates does this transform touch,
and where do those values land in the output buffer?

The algorithm is:

1. **Enumerate candidate chunks** — determine which chunks could possibly
be touched by the transform's output coordinate ranges.

2. **Intersect** — for each candidate chunk, call
`transform.intersect(chunk_domain)` to restrict the transform to
coordinates within that chunk. If the intersection is empty, skip it.

3. **Translate** — shift the restricted transform to chunk-local coordinates
via `transform.translate(-chunk_origin)`.

4. **Yield** — produce `(chunk_coords, local_transform, surviving_indices)`
triples that the codec pipeline consumes.

`sub_transform_to_selections` bridges from the transform representation
back to the raw `(chunk_selection, out_selection, drop_axes)` tuples that
the current codec pipeline expects. This bridge will go away when the codec
pipeline accepts transforms natively.
"""

from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Any

import numpy as np

from zarr.core._transforms.domain import IndexDomain
from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap
from zarr.core._transforms.transform import IndexTransform

if TYPE_CHECKING:
from collections.abc import Iterator

from zarr.core.chunk_grids import ChunkGrid

# Result triple yielded by `iter_chunk_transforms`:
#   (chunk coordinates, chunk-local sub-transform,
#    output scatter indices for vectorized indexing, or None for basic indexing).
ChunkTransformResult = tuple[
    tuple[int, ...],
    IndexTransform,
    np.ndarray[Any, np.dtype[np.intp]] | None,
]


def iter_chunk_transforms(
    transform: IndexTransform,
    chunk_grid: ChunkGrid,
) -> Iterator[ChunkTransformResult]:
    """Resolve a composed IndexTransform against a ChunkGrid.

    Yields `(chunk_coords, sub_transform, out_indices)` triples:

    - `chunk_coords`: which chunk to access.
    - `sub_transform`: maps output buffer coords to chunk-local coords.
    - `out_indices`: for vectorized/array indexing, the output scatter
      indices (integer array). `None` for basic/slice indexing.
    """
    grids = chunk_grid._dimensions

    # For every output dimension, compute the inclusive range of chunk
    # indices the transform could possibly touch. The cartesian product of
    # these per-dimension ranges is the candidate chunk set; each candidate
    # is then confirmed (or rejected) by `transform.intersect`, which handles
    # both orthogonal and vectorized cases.
    candidate_ranges: list[range] = []
    for axis, out_map in enumerate(transform.output):
        grid = grids[axis]
        if isinstance(out_map, ConstantMap):
            # A constant map hits exactly one storage index, hence one chunk.
            only = grid.index_to_chunk(out_map.offset)
            candidate_ranges.append(range(only, only + 1))
        elif isinstance(out_map, DimensionMap):
            in_dim = out_map.input_dimension
            lo = transform.domain.inclusive_min[in_dim]
            hi = transform.domain.exclusive_max[in_dim]
            if lo >= hi:
                # Empty input domain: the whole transform selects nothing.
                return
            # DimensionMap.stride is always positive (enforced by __post_init__),
            # so the extreme storage indices come from the extreme inputs.
            lo_chunk = grid.index_to_chunk(out_map.offset + out_map.stride * lo)
            hi_chunk = grid.index_to_chunk(out_map.offset + out_map.stride * (hi - 1))
            candidate_ranges.append(range(lo_chunk, hi_chunk + 1))
        elif isinstance(out_map, ArrayMap):  # pragma: no branch - exhaustive over OutputIndexMap union
            storage_coords = out_map.offset + out_map.stride * out_map.index_array
            hits = grid.indices_to_chunks(storage_coords.ravel().astype(np.intp))
            candidate_ranges.append(range(int(hits.min()), int(hits.max()) + 1))

    for raw_coords in itertools.product(*candidate_ranges):
        coords = tuple(int(c) for c in raw_coords)

        # Storage-space bounds of this chunk, plus the shift that takes
        # storage coordinates to chunk-local coordinates.
        starts = [grids[axis].chunk_offset(c) for axis, c in enumerate(coords)]
        sizes = [grids[axis].chunk_size(c) for axis, c in enumerate(coords)]
        shifts = tuple(-start for start in starts)

        chunk_domain = IndexDomain(
            inclusive_min=tuple(starts),
            exclusive_max=tuple(start + size for start, size in zip(starts, sizes, strict=True)),
        )

        # Restrict the transform to this chunk; skip chunks it never touches.
        hit = transform.intersect(chunk_domain)
        if hit is None:
            continue
        restricted, surviving = hit

        yield (coords, restricted.translate(shifts), surviving)


def sub_transform_to_selections(
    sub_transform: IndexTransform,
    out_indices: np.ndarray[Any, np.dtype[np.intp]] | None = None,
) -> tuple[
    tuple[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...],
    tuple[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...],
    tuple[int, ...],
]:
    """Convert a chunk-local sub-transform to raw selections for the codec pipeline.

    Parameters
    ----------
    sub_transform
        A chunk-local IndexTransform (output maps already translated to
        chunk-local coordinates).
    out_indices
        For vectorized indexing: the output scatter indices for this chunk.
        None for orthogonal/basic indexing.

    Returns
    -------
    tuple
        `(chunk_selection, out_selection, drop_axes)`
    """
    domain = sub_transform.domain

    # chunk_sel: one entry per output dimension, in chunk-local storage coords.
    chunk_sel: list[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = []
    for out_map in sub_transform.output:
        if isinstance(out_map, ConstantMap):
            chunk_sel.append(out_map.offset)
        elif isinstance(out_map, DimensionMap):
            # DimensionMap.stride is always positive (enforced by __post_init__),
            # so this slice always runs forward.
            lo = domain.inclusive_min[out_map.input_dimension]
            hi = domain.exclusive_max[out_map.input_dimension]
            chunk_sel.append(
                slice(
                    out_map.offset + out_map.stride * lo,
                    out_map.offset + out_map.stride * hi,
                    out_map.stride,
                )
            )
        elif isinstance(out_map, ArrayMap):  # pragma: no branch - exhaustive over OutputIndexMap union
            if out_map.offset == 0 and out_map.stride == 1:
                # Identity affine part: reuse the index array as-is.
                chunk_sel.append(out_map.index_array)
            else:
                storage_coords = out_map.offset + out_map.stride * out_map.index_array
                chunk_sel.append(storage_coords.astype(np.intp))

    # Vectorized indexing: two or more ArrayMaps that share an input dimension
    # are correlated and all scatter into one shared output index array.
    vectorized = False
    if out_indices is not None:
        used_dims: set[int] = set()
        for out_map in sub_transform.output:
            if not isinstance(out_map, ArrayMap):
                continue
            if used_dims.intersection(out_map.input_dimensions):
                vectorized = True
                break
            used_dims.update(out_map.input_dimensions)

    # out_sel: one entry per non-dropped output dimension.
    out_sel: list[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = []
    if vectorized:
        assert out_indices is not None
        out_sel.append(out_indices)
    else:
        for out_map in sub_transform.output:
            if isinstance(out_map, ConstantMap):
                # Integer chunk selection drops this axis from the read, so it
                # contributes nothing to the output selection.
                continue
            if isinstance(out_map, DimensionMap):
                d = out_map.input_dimension
                out_sel.append(slice(domain.inclusive_min[d], domain.exclusive_max[d]))
            elif isinstance(
                out_map, ArrayMap
            ):  # pragma: no branch - exhaustive over OutputIndexMap union
                if out_indices is None:
                    out_sel.append(slice(0, len(out_map.index_array)))
                else:
                    # Orthogonal ArrayMap: out_indices holds the surviving positions.
                    out_sel.append(out_indices)

    # NOTE(review): drop_axes is always empty here — presumably integer chunk
    # selections drop axes implicitly; confirm against the codec pipeline.
    return tuple(chunk_sel), tuple(out_sel), ()
137 changes: 137 additions & 0 deletions src/zarr/core/_transforms/composition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from __future__ import annotations

import numpy as np

from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap
from zarr.core._transforms.transform import IndexTransform


def compose(outer: IndexTransform, inner: IndexTransform) -> IndexTransform:
    """Compose two IndexTransforms into a single transform.

    `outer` maps user coords (rank m) to intermediate coords (rank n);
    `inner` maps intermediate coords (rank n) to storage coords (rank p).
    The result maps user coords (rank m) directly to storage coords (rank p).

    Raises
    ------
    ValueError
        If `outer.output_rank` does not equal `inner.domain.ndim`.
    """
    outer_rank = outer.output_rank
    inner_rank = inner.domain.ndim
    if outer_rank != inner_rank:
        msg = (
            f"outer output rank ({outer_rank}) must match inner input rank "
            f"({inner_rank})"
        )
        raise ValueError(msg)

    # The composed domain is the outer (user-facing) domain; each inner
    # output map is rewritten in terms of the user coordinates.
    composed_output = tuple(_compose_single(outer, m) for m in inner.output)
    return IndexTransform(domain=outer.domain, output=composed_output)


def _compose_single(outer: IndexTransform, inner_map: OutputIndexMap) -> OutputIndexMap:
"""Compose a single inner output map with the full outer transform."""
if isinstance(inner_map, ConstantMap):
return ConstantMap(offset=inner_map.offset)

if isinstance(inner_map, DimensionMap):
return _compose_dimension(outer, inner_map)

if isinstance(inner_map, ArrayMap):
return _compose_array(outer, inner_map)

raise TypeError(f"Unknown output map type: {type(inner_map)}") # pragma: no cover


def _compose_dimension(outer: IndexTransform, inner_map: DimensionMap) -> OutputIndexMap:
    """Compose when the inner output map is a DimensionMap.

    The inner map computes `offset + stride * intermediate[d]`, where the
    intermediate coordinate at dimension `d` is produced by `outer.output[d]`.
    Substituting the outer map into that expression yields a map of the same
    kind as the outer map, with the affine parts folded together.
    """
    target = outer.output[inner_map.input_dimension]
    off = inner_map.offset
    st = inner_map.stride

    if isinstance(target, ConstantMap):
        # Constant in -> constant out.
        return ConstantMap(offset=off + st * target.offset)

    if isinstance(target, DimensionMap):
        # Affine-of-affine stays affine: offsets combine, strides multiply.
        return DimensionMap(
            input_dimension=target.input_dimension,
            offset=off + st * target.offset,
            stride=st * target.stride,
        )

    if isinstance(target, ArrayMap):
        # Affine-of-array: fold the inner affine into the array map's
        # offset/stride; the underlying index array is shared, not copied.
        return ArrayMap(
            index_array=target.index_array,
            input_dimensions=target.input_dimensions,
            offset=off + st * target.offset,
            stride=st * target.stride,
        )

    raise TypeError(f"Unknown output map type: {type(target)}")  # pragma: no cover


def _compose_array(outer: IndexTransform, inner_map: ArrayMap) -> OutputIndexMap:
    """Compose when inner is an ArrayMap.

    storage = offset_i + stride_i * arr_i[intermediate[input_dimensions[0]],
                                          intermediate[input_dimensions[1]], ...]

    For each axis k of arr_i, the corresponding intermediate dim is
    `inner_map.input_dimensions[k] = d`. We need to evaluate arr_i over the
    product of `outer.output[d]` for each such d.

    Supported cases:

    - All-constant outer: collapse to a single `ConstantMap`.
    - 1-D inner array whose single referenced outer output is a `ConstantMap`,
      `DimensionMap`, or `ArrayMap`: evaluate arr_i along that outer map's
      parameterization.

    Raises
    ------
    NotImplementedError
        If the inner array is multi-dimensional and the outer maps are not
        all constant.
    """
    arr_i = inner_map.index_array
    offset_i = inner_map.offset
    stride_i = inner_map.stride
    in_dims_i = inner_map.input_dimensions

    # All-constant outer: arr_i is evaluated at a single fixed point.
    if all(isinstance(m, ConstantMap) for m in outer.output):
        idx = tuple(outer.output[d].offset for d in in_dims_i)
        value = int(arr_i[idx])
        return ConstantMap(offset=offset_i + stride_i * value)

    # 1-D inner array, single referenced outer output.
    if len(in_dims_i) == 1:
        dim_i = in_dims_i[0]
        outer_map = outer.output[dim_i]

        if isinstance(outer_map, ConstantMap):
            # Fix: previously this fell through to NotImplementedError when the
            # single referenced outer output was constant but OTHER outer
            # outputs were not. arr_i is still evaluated at one fixed point,
            # exactly as in the all-constant case above.
            value = int(arr_i[outer_map.offset])
            return ConstantMap(offset=offset_i + stride_i * value)

        if isinstance(outer_map, DimensionMap):
            # Evaluate arr_i at the outer DimensionMap's range.
            input_d = outer_map.input_dimension
            input_lo = outer.domain.inclusive_min[input_d]
            input_hi = outer.domain.exclusive_max[input_d]
            user_indices = np.arange(input_lo, input_hi, dtype=np.intp)
            intermediate_vals = outer_map.offset + outer_map.stride * user_indices
            new_arr = arr_i[intermediate_vals]
            return ArrayMap(
                index_array=new_arr,
                input_dimensions=(input_d,),
                offset=offset_i,
                stride=stride_i,
            )

        if isinstance(outer_map, ArrayMap):
            # Evaluate arr_i at outer's array values; new array inherits outer's
            # parameterization.
            intermediate_vals = outer_map.offset + outer_map.stride * outer_map.index_array
            new_arr = arr_i[intermediate_vals]
            return ArrayMap(
                index_array=new_arr,
                input_dimensions=outer_map.input_dimensions,
                offset=offset_i,
                stride=stride_i,
            )

    # General multi-dim case: not yet implemented.
    raise NotImplementedError(
        "Composing a multi-dimensional inner array map with non-constant outer maps "
        "is not yet supported."
    )
Loading
Loading