Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 88 additions & 54 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,6 @@ def _getitem(
"""

obj = accessor._obj
all_bounds = obj.cf.bounds if isinstance(obj, Dataset) else {}
kind = str(type(obj).__name__)
scalar_key = isinstance(key, Hashable)

Expand All @@ -1096,6 +1095,7 @@ def drop_bounds(names):
# with a scalar key. Hopefully these will soon get decoded to IntervalIndex
# and we can move on...
if not isinstance(obj, DataArray) and scalar_key:
all_bounds = obj.cf.bounds
bounds = set()
for name in names:
bounds.update(all_bounds.get(name, []))
Expand Down Expand Up @@ -1127,60 +1127,94 @@ def check_results(names, key):

custom_criteria = ChainMap(*OPTIONS["custom_criteria"])

varnames: list[Hashable] = []
coords: list[Hashable] = []
successful = dict.fromkeys(key_iter, False)
for k in key_iter:
if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES:
names = _get_all(obj, k)
names = drop_bounds(names)
check_results(names, k)
successful[k] = bool(names)
coords.extend(names)
elif "measures" not in skip and k in measures:
measure = _get_all(obj, k)
check_results(measure, k)
successful[k] = bool(measure)
if measure:
varnames.extend(measure)
elif "grid_mapping_names" not in skip and k in grid_mapping_names:
grid_mapping = _get_all(obj, k)
check_results(grid_mapping, k)
successful[k] = bool(grid_mapping)
if grid_mapping:
varnames.extend(grid_mapping)
elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES):
geometries = _get_all(obj, k)
if geometries and k in _GEOMETRY_TYPES:
new = itertools.chain(
_parse_related_geometry_vars(
ChainMap(obj[g].attrs, obj[g].encoding)
# Fast path: every key is a direct variable name and matches no CF
# special key, so the dispatch loop below would simply fall through to
# the ``successful[k] is False`` clause that treats ``k`` as a variable
# name. Skip the per-key ``_get_all`` mapper fan-out entirely. Defer to
# the slow path if any *other* variable advertises a key as its
# standard_name -- the slow path returns those vars rather than the
# direct lookup.
reserved: set[Hashable] = set(_AXIS_NAMES).union(
_COORD_NAMES,
_GEOMETRY_TYPES,
("geometry",),
measures,
grid_mapping_names,
custom_criteria,
cf_role_criteria,
)
fast_path = (
isinstance(obj, Dataset)
and not skip
and all(
k in obj._variables
and k not in reserved
and accessor.standard_names.get(k, [k]) == [k]
for k in key_iter
)
)

varnames: list[Hashable]
coords: list[Hashable]
if fast_path:
varnames = list(key_iter)
coords = []
successful = dict.fromkeys(key_iter, True)
else:
varnames = []
coords = []
successful = dict.fromkeys(key_iter, False)
for k in key_iter:
if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES:
names = _get_all(obj, k)
names = drop_bounds(names)
check_results(names, k)
successful[k] = bool(names)
coords.extend(names)
elif "measures" not in skip and k in measures:
measure = _get_all(obj, k)
check_results(measure, k)
successful[k] = bool(measure)
if measure:
varnames.extend(measure)
elif "grid_mapping_names" not in skip and k in grid_mapping_names:
grid_mapping = _get_all(obj, k)
check_results(grid_mapping, k)
successful[k] = bool(grid_mapping)
if grid_mapping:
varnames.extend(grid_mapping)
elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES):
geometries = _get_all(obj, k)
if geometries and k in _GEOMETRY_TYPES:
new = itertools.chain(
_parse_related_geometry_vars(
ChainMap(obj[g].attrs, obj[g].encoding)
)
for g in geometries
)
for g in geometries
)
geometries.extend(*new)
if len(geometries) > 1 and scalar_key:
raise ValueError(
f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead."
)
successful[k] = bool(geometries)
if geometries:
varnames.extend(geometries)
elif k in custom_criteria or k in cf_role_criteria:
names = _get_all(obj, k)
check_results(names, k)
successful[k] = bool(names)
varnames.extend(names)
else:
stdnames = set(_get_with_standard_name(obj, k))
objcoords = set(obj.coords)
stdnames = drop_bounds(stdnames)
if "coords" in skip:
stdnames -= objcoords
check_results(stdnames, k)
successful[k] = bool(stdnames)
varnames.extend(stdnames - objcoords)
coords.extend(stdnames & objcoords)
geometries.extend(*new)
if len(geometries) > 1 and scalar_key:
raise ValueError(
f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead."
)
successful[k] = bool(geometries)
if geometries:
varnames.extend(geometries)
elif k in custom_criteria or k in cf_role_criteria:
names = _get_all(obj, k)
check_results(names, k)
successful[k] = bool(names)
varnames.extend(names)
else:
stdnames = set(_get_with_standard_name(obj, k))
objcoords = set(obj.coords)
stdnames = drop_bounds(stdnames)
if "coords" in skip:
stdnames -= objcoords
check_results(stdnames, k)
successful[k] = bool(stdnames)
varnames.extend(stdnames - objcoords)
coords.extend(stdnames & objcoords)

# these are not special names but could be variable names in underlying object
# we allow this so that we can return variables with appropriate CF auxiliary variables
Expand Down
19 changes: 14 additions & 5 deletions cf_xarray/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import functools
import inspect
import os
import warnings
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from typing import Any
from xml.etree import ElementTree

import numpy as np
from xarray import DataArray
from xarray.core.utils import Frozen

try:
import cftime
Expand Down Expand Up @@ -64,25 +66,32 @@ def _is_datetime_like(da: DataArray) -> bool:
return False


@functools.lru_cache(maxsize=256)
def parse_cell_methods_attr(attr: str) -> Mapping[str, str]:
    """
    Parse cell_methods attributes (format is 'measure: name').

    The result is memoized per attribute string and returned as a read-only
    ``Frozen`` mapping. ``cell_measures`` strings are typically shared across
    many variables in a dataset (e.g., every CMIP variable carries the same
    ``area: areacella volume: ...``), so parsing once and reusing avoids
    repeated string splitting in the hot path of ``_get_measure``.

    Parameters
    ----------
    attr : str
        String to parse

    Returns
    -------
    Read-only mapping from measure to name.

    Raises
    ------
    ValueError
        If ``attr`` does not break down into an even number of tokens, i.e.
        some key has no value (or vice versa). Exceptions are not stored by
        ``lru_cache``, so a malformed string raises on every call.
    """
    # Split on ":" first, then on whitespace, so that
    # "area: areacella volume: volcello" flattens to alternating
    # key / value tokens: ["area", "areacella", "volume", "volcello"].
    strings = [s for scolons in attr.split(":") for s in scolons.split()]
    if len(strings) % 2 != 0:
        raise ValueError(f"attrs['cell_measures'] = {attr!r} is malformed.")

    # The cached object is handed to every caller that passes the same
    # string, so it is wrapped read-only to keep one caller from mutating
    # the result seen by the others.
    return Frozen(dict(zip(strings[::2], strings[1::2], strict=False)))


Expand Down
Loading