Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ class AlignDataVariableDimensionsToDatasetCoords(Operation):
"""

def apply_func(self, data: xr.Dataset) -> xr.Dataset:
    """Transpose the dataset so its variables follow one dimension order.

    Args:
        data: Dataset whose data variables may be stored with differing
            dimension orders.

    Returns:
        The result of ``data.transpose`` called with the dimension names
        listed by ``data.coords.dims``.
    """
    # Use coords.dims rather than the coordinate names themselves: a
    # coordinate does not have to share its name with the dimension it
    # indexes, and transpose() is given dimension names.
    dataset_ordering = list(data.coords.dims)
    return data.transpose(*dataset_ordering)

def undo_func(self, data: xr.Dataset) -> xr.Dataset:
    """Refuse to undo the alignment, which is not supported yet.

    Args:
        data: Dataset that would be restored; currently unused.

    Raises:
        NotImplementedError: Always. The original per-variable dimension
            orderings are not recorded, so they cannot be restored.
    """
    # TODO: Record all the original orderings and transpose them back, I guess
    raise NotImplementedError("Don't yet know how to undo data variable alignment.")


class Sort(Operation):
Expand Down
51 changes: 50 additions & 1 deletion packages/pipeline/tests/operations/xarray/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
import xarray as xr

from pyearthtools.pipeline.operations.xarray import _sort as sort
from pyearthtools.pipeline.operations.xarray import _sort as sort, AlignDataVariableDimensionsToDatasetCoords

SIMPLE_DA1 = xr.DataArray(
[
Expand All @@ -37,6 +37,55 @@
# Dataset with three data variables that are all backed by SIMPLE_DA1.
SIMPLE_DS2 = xr.Dataset(
    {
        variable_name: SIMPLE_DA1
        for variable_name in ("Humidity", "Temperature", "WombatsPerKm2")
    }
)


def test_align():
    """Check that alignment gives every data variable the same dimension order."""
    aligner = AlignDataVariableDimensionsToDatasetCoords()

    def build_and_check(base):
        """Build a dataset of differently ordered views of ``base`` and verify alignment.

        Returns the unaligned dataset so the caller can reuse it.
        """
        # Three variables over the same dimensions, each in a different order.
        unaligned = xr.Dataset(
            {
                "Temperature": base.transpose("lat", "height", "lon"),
                "Humidity": base,
                "WombatsPerKm2": base.transpose("lon", "height", "lat"),
            }
        )

        # Precondition: the variables really do start out with mismatched dims.
        for other in ("Humidity", "WombatsPerKm2"):
            assert unaligned["Temperature"].dims != unaligned[other].dims

        # After the operation every variable must share Temperature's ordering.
        aligned = aligner.apply_func(unaligned)
        for other in ("Humidity", "WombatsPerKm2"):
            assert aligned["Temperature"].dims == aligned[other].dims

        return unaligned

    # Case 1: the plain fixture array.
    build_and_check(SIMPLE_DA1)

    # Case 2: coordinate names ("h", "x", "y") differ from the dimension names.
    renamed_coords = {
        "h": ("height", [10, 20]),
        "x": ("lat", [0, 1, 2]),
        "y": ("lon", [5, 6, 7]),
    }
    da_with_named_coords = xr.DataArray(
        SIMPLE_DA1.data,
        coords=renamed_coords,
        dims=["height", "lat", "lon"],
    )
    ds = build_and_check(da_with_named_coords)

    # Placeholder check for undo: it is expected to raise for now.
    with pytest.raises(NotImplementedError):
        aligner.undo_func(ds)


def test_Sort():

s = sort.Sort()
Expand Down