Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3744.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for NumPy masked arrays (`numpy.ma.MaskedArray`) in `zarr.array()`. A masked array is automatically converted to a regular array with `MaskedArray.filled()`, and a `UserWarning` is emitted because the mask is not preserved. Users who need to keep the mask should store it in a separate array, or use a structured dtype that holds both the data and the mask.
6 changes: 5 additions & 1 deletion examples/custom_dtype/custom_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,11 @@ def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> ml_dtypes.
def test_custom_dtype(tmp_path: Path, zarr_format: ZarrFormat) -> None:
# create array and write values
z_w = zarr.create_array(
store=tmp_path, shape=(4,), dtype="int2", zarr_format=zarr_format, compressors=None
store=tmp_path,
shape=(4,),
dtype="int2",
zarr_format=zarr_format,
compressors=None,
)
z_w[:] = [-1, -2, 0, 1]

Expand Down
8 changes: 7 additions & 1 deletion src/zarr/abc/buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from zarr.core.buffer.core import ArrayLike, Buffer, BufferPrototype, NDArrayLike, NDBuffer
from zarr.core.buffer.core import (
ArrayLike,
Buffer,
BufferPrototype,
NDArrayLike,
NDBuffer,
)

__all__ = [
"ArrayLike",
Expand Down
9 changes: 8 additions & 1 deletion src/zarr/abc/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

from abc import abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable
from typing import (
TYPE_CHECKING,
Generic,
Protocol,
TypeGuard,
TypeVar,
runtime_checkable,
)

from typing_extensions import ReadOnly, TypedDict

Expand Down
58 changes: 49 additions & 9 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,26 @@ def _check_writable(self) -> None:
if self.read_only:
raise ValueError("store was opened in read-only mode and does not support writing")

def _ensure_buffer(self, value: Buffer | bytes) -> Buffer:
    """Return ``value`` as a ``Buffer``, wrapping raw ``bytes`` when necessary.

    Parameters
    ----------
    value : Buffer or bytes
        The value to normalize.

    Returns
    -------
    Buffer
        ``value`` itself when it is not a ``bytes`` instance, otherwise a new
        buffer created from those bytes with the default buffer prototype.
    """
    if not isinstance(value, bytes):
        return value

    # Imported lazily to avoid a circular import at module load time.
    from zarr.core.buffer import default_buffer_prototype

    # Build through the default prototype's concrete buffer class instead of
    # the ``Buffer`` base class.
    # NOTE(review): ``Buffer`` is presumably abstract and its ``from_bytes``
    # may refuse to run on the base class itself -- confirm against
    # zarr.core.buffer.core.
    return default_buffer_prototype().buffer.from_bytes(value)

@abstractmethod
def __eq__(self, value: object) -> bool:
"""Equality comparison."""
Expand Down Expand Up @@ -219,7 +239,11 @@ async def get(
...

async def _get_bytes(
self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None
self,
key: str,
*,
prototype: BufferPrototype,
byte_range: ByteRequest | None = None,
) -> bytes:
"""
Retrieve raw bytes from the store asynchronously.
Expand Down Expand Up @@ -267,7 +291,11 @@ async def _get_bytes(
return buffer.to_bytes()

def _get_bytes_sync(
self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None
self,
key: str = "",
*,
prototype: BufferPrototype,
byte_range: ByteRequest | None = None,
) -> bytes:
"""
Retrieve raw bytes from the store synchronously.
Expand Down Expand Up @@ -318,7 +346,11 @@ def _get_bytes_sync(
return sync(self._get_bytes(key, prototype=prototype, byte_range=byte_range))

async def _get_json(
self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None
self,
key: str,
*,
prototype: BufferPrototype,
byte_range: ByteRequest | None = None,
) -> Any:
"""
Retrieve and parse JSON data from the store asynchronously.
Expand Down Expand Up @@ -368,7 +400,11 @@ async def _get_json(
return json.loads(await self._get_bytes(key, prototype=prototype, byte_range=byte_range))

def _get_json_sync(
self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None
self,
key: str = "",
*,
prototype: BufferPrototype,
byte_range: ByteRequest | None = None,
) -> Any:
"""
Retrieve and parse JSON data from the store synchronously.
Expand Down Expand Up @@ -465,24 +501,28 @@ def supports_writes(self) -> bool:
...

@abstractmethod
async def set(self, key: str, value: Buffer) -> None:
async def set(self, key: str, value: Buffer | bytes) -> None:
"""Store a (key, value) pair.

Parameters
----------
key : str
value : Buffer
value : Buffer or bytes
The value to store. If bytes are provided, they will be converted
to a Buffer internally.
"""
...

async def set_if_not_exists(self, key: str, value: Buffer) -> None:
async def set_if_not_exists(self, key: str, value: Buffer | bytes) -> None:
"""
Store a key to ``value`` if the key is not already present.

Parameters
----------
key : str
value : Buffer
value : Buffer or bytes
The value to store. If bytes are provided, they will be converted
to a Buffer internally.
"""
# Note for implementers: the default implementation provided here
# is not safe for concurrent writers. There's a race condition between
Expand All @@ -491,7 +531,7 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None:
if not await self.exists(key):
await self.set(key, value)

async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
async def _set_many(self, values: Iterable[tuple[str, Buffer | bytes]]) -> None:
"""
Insert multiple (key, value) pairs into storage.
"""
Expand Down
22 changes: 19 additions & 3 deletions src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def _infer_overwrite(mode: AccessModeLiteral) -> bool:
return mode in _OVERWRITE_MODES


def _get_shape_chunks(a: ArrayLike | Any) -> tuple[tuple[int, ...] | None, tuple[int, ...] | None]:
def _get_shape_chunks(
a: ArrayLike | Any,
) -> tuple[tuple[int, ...] | None, tuple[int, ...] | None]:
"""Helper function to get the shape and chunks from an array-like object"""
shape = None
chunks = None
Expand Down Expand Up @@ -179,7 +181,9 @@ def _handle_zarr_version_or_format(
)
if zarr_version is not None:
warnings.warn(
"zarr_version is deprecated, use zarr_format", ZarrDeprecationWarning, stacklevel=2
"zarr_version is deprecated, use zarr_format",
ZarrDeprecationWarning,
stacklevel=2,
)
return zarr_version
return zarr_format
Expand Down Expand Up @@ -386,7 +390,9 @@ async def open(
is_v3_array = zarr_format == 3 and _metadata_dict.get("node_type") == "array"
if is_v3_array or zarr_format == 2:
return AsyncArray(
store_path=store_path, metadata=_metadata_dict, config=kwargs.get("config")
store_path=store_path,
metadata=_metadata_dict,
config=kwargs.get("config"),
)
except (AssertionError, FileNotFoundError, NodeTypeValidationError):
pass
Expand Down Expand Up @@ -621,6 +627,16 @@ async def array(data: npt.ArrayLike | AnyArray, **kwargs: Any) -> AnyAsyncArray:
if isinstance(data, Array):
return await from_array(data=data, **kwargs)

# Handle masked arrays by converting to filled array
if isinstance(data, np.ma.MaskedArray):
warnings.warn(
"Masked arrays are not fully supported in Zarr. The mask will not be preserved. "
"Consider using zarr's structured dtype or a separate array for the mask if you need to preserve it.",
UserWarning,
stacklevel=2,
)
data = cast(np.ndarray, data.filled())

# ensure data is array-like
if not hasattr(data, "shape") or not hasattr(data, "dtype"):
data = np.asanyarray(data)
Expand Down
7 changes: 6 additions & 1 deletion src/zarr/api/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,12 @@ def save(
"""
return sync(
async_api.save(
store, *args, zarr_version=zarr_version, zarr_format=zarr_format, path=path, **kwargs
store,
*args,
zarr_version=zarr_version,
zarr_format=zarr_format,
path=path,
**kwargs,
)
)

Expand Down
4 changes: 3 additions & 1 deletion src/zarr/codecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@
register_codec("numcodecs.fletcher32", Fletcher32, qualname="zarr.codecs.numcodecs.Fletcher32")
register_codec("numcodecs.gzip", GZip, qualname="zarr.codecs.numcodecs.GZip")
register_codec(
"numcodecs.jenkins_lookup3", JenkinsLookup3, qualname="zarr.codecs.numcodecs.JenkinsLookup3"
"numcodecs.jenkins_lookup3",
JenkinsLookup3,
qualname="zarr.codecs.numcodecs.JenkinsLookup3",
)
register_codec("numcodecs.pcodec", PCodec, qualname="zarr.codecs.numcodecs.PCodec")
register_codec("numcodecs.packbits", PackBits, qualname="zarr.codecs.numcodecs.PackBits")
Expand Down
7 changes: 6 additions & 1 deletion src/zarr/codecs/blosc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from zarr.abc.codec import BytesBytesCodec
from zarr.core.buffer.cpu import as_numpy_array_wrapper
from zarr.core.common import JSON, NamedRequiredConfig, parse_enum, parse_named_configuration
from zarr.core.common import (
JSON,
NamedRequiredConfig,
parse_enum,
parse_named_configuration,
)
from zarr.core.dtype.common import HasItemSize

if TYPE_CHECKING:
Expand Down
3 changes: 2 additions & 1 deletion src/zarr/codecs/crc32c_.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def _encode_sync(
data = chunk_bytes.as_numpy_array()
# Calculate the checksum and "cast" it to a numpy array
checksum = np.array(
[google_crc32c.value(cast("typing_extensions.Buffer", data))], dtype=np.uint32
[google_crc32c.value(cast("typing_extensions.Buffer", data))],
dtype=np.uint32,
)
# Append the checksum (as bytes) to the data
return chunk_spec.prototype.buffer.from_array_like(np.append(data, checksum.view("B")))
Expand Down
25 changes: 18 additions & 7 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ async def from_bytes(

@classmethod
def create_empty(
cls, chunks_per_shard: tuple[int, ...], buffer_prototype: BufferPrototype | None = None
cls,
chunks_per_shard: tuple[int, ...],
buffer_prototype: BufferPrototype | None = None,
) -> _ShardReader:
if buffer_prototype is None:
buffer_prototype = default_buffer_prototype()
Expand Down Expand Up @@ -297,7 +299,9 @@ def to_dict_vectorized(

@dataclass(frozen=True)
class ShardingCodec(
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
ArrayBytesCodec,
ArrayBytesCodecPartialDecodeMixin,
ArrayBytesCodecPartialEncodeMixin,
):
"""Sharding codec"""

Expand All @@ -312,7 +316,7 @@ def __init__(
chunk_shape: ShapeLike,
codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),),
index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()),
index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end,
index_location: (ShardingCodecIndexLocation | str) = ShardingCodecIndexLocation.end,
) -> None:
chunk_shape_parsed = parse_shapelike(chunk_shape)
codecs_parsed = parse_codecs(codecs)
Expand Down Expand Up @@ -585,7 +589,9 @@ async def _encode_partial_single(

indexer = list(
get_indexer(
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
selection,
shape=shard_shape,
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
)
)

Expand Down Expand Up @@ -701,7 +707,8 @@ def _shard_index_size(self, chunks_per_shard: tuple[int, ...]) -> int:
get_pipeline_class()
.from_codecs(self.index_codecs)
.compute_encoded_size(
16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard)
16 * product(chunks_per_shard),
self._get_index_chunk_spec(chunks_per_shard),
)
)

Expand Down Expand Up @@ -746,7 +753,8 @@ async def _load_shard_index_maybe(
)
else:
index_bytes = await byte_getter.get(
prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size)
prototype=numpy_buffer_prototype(),
byte_range=SuffixByteRequest(shard_index_size),
)
if index_bytes is not None:
return await self._decode_shard_index(index_bytes, chunks_per_shard)
Expand All @@ -760,7 +768,10 @@ async def _load_shard_index(
) or _ShardIndex.create_empty(chunks_per_shard)

async def _load_full_shard_maybe(
self, byte_getter: ByteGetter, prototype: BufferPrototype, chunks_per_shard: tuple[int, ...]
self,
byte_getter: ByteGetter,
prototype: BufferPrototype,
chunks_per_shard: tuple[int, ...],
) -> _ShardReader | None:
shard_bytes = await byte_getter.get(prototype=prototype)

Expand Down
5 changes: 4 additions & 1 deletion src/zarr/codecs/zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
return cls(**configuration_parsed) # type: ignore[arg-type]

def to_dict(self) -> dict[str, JSON]:
return {"name": "zstd", "configuration": {"level": self.level, "checksum": self.checksum}}
return {
"name": "zstd",
"configuration": {"level": self.level, "checksum": self.checksum},
}

@cached_property
def _zstd_codec(self) -> Zstd:
Expand Down
Loading
Loading