Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 54 additions & 51 deletions tests/test_xarray_adios2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import annotations

import os
from typing import Any

import adios2py
import numpy as np
import pytest
Expand Down Expand Up @@ -70,11 +67,14 @@ def test_filename_4(tmp_path):
return filename


def _open_dataset(filename: os.PathLike[Any], *, decode: bool = False) -> xr.Dataset:
ds = xr.open_dataset(filename)
if decode:
ds = _decode_dataset(ds)
return ds
@pytest.fixture
def ds_pfd_raw() -> xr.Dataset:
return xr.open_dataset(pscpy.sample_dir / "pfd.000000400.bp")


@pytest.fixture
def ds_pfd_moments_raw() -> xr.Dataset:
return xr.open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp")


def _decode_dataset(ds: xr.Dataset) -> xr.Dataset:
Expand All @@ -86,77 +86,80 @@ def _decode_dataset(ds: xr.Dataset) -> xr.Dataset:
)


def test_open_dataset():
ds_decoded = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp", decode=True)
assert "jx_ec" in ds_decoded
assert ds_decoded.coords.keys() == set({"x", "y", "z"})
assert ds_decoded.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408
@pytest.fixture
def ds_pfd_decoded(ds_pfd_raw) -> xr.Dataset:
return _decode_dataset(ds_pfd_raw)


@pytest.fixture
def ds_pfd_moments_decoded(ds_pfd_moments_raw) -> xr.Dataset:
return _decode_dataset(ds_pfd_moments_raw)


def test_open_dataset(ds_pfd_decoded):
assert "jx_ec" in ds_pfd_decoded
assert ds_pfd_decoded.coords.keys() == set({"x", "y", "z"})
assert ds_pfd_decoded.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408
assert np.allclose(
ds_decoded.jx_ec.z.data, np.linspace(-25.6, 25.6, 512, endpoint=False).data
ds_pfd_decoded.jx_ec.z.data, np.linspace(-25.6, 25.6, 512, endpoint=False).data
)


def test_component():
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
assert np.all(ds_raw.jeh.isel(dim_1_9=0).data == ds_decoded.jx_ec.data)
def test_component(ds_pfd_raw, ds_pfd_decoded):
assert np.all(ds_pfd_raw.jeh.isel(dim_1_9=0).data == ds_pfd_decoded.jx_ec.data)


def test_selection():
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
assert np.all(
ds_raw.jeh.isel(dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)).data
== ds_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data
)
def test_selection(ds_pfd_raw, ds_pfd_decoded):
data_raw = ds_pfd_raw.jeh.isel(
dim_1_9=0, dim_3_128=slice(0, 10), dim_2_512=slice(0, 40)
).data
data_decoded = ds_pfd_decoded.jx_ec.isel(y=slice(0, 10), z=slice(0, 40)).data
assert np.all(data_raw == data_decoded)


def test_nbytes():
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
assert ds_decoded.nbytes == ds_decoded.nbytes
def _get_nbytes(ds: xr.Dataset) -> int:
return sum(arr.nbytes for arr in ds.data_vars.values())


def test_missing_length():
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
def test_nbytes(ds_pfd_raw, ds_pfd_decoded):
assert _get_nbytes(ds_pfd_raw) == _get_nbytes(ds_pfd_decoded)


def test_missing_length(ds_pfd_raw):
with pytest.raises(ValueError, match=r".*length.*"):
pscpy.decode_psc(
ds_raw,
ds_pfd_raw,
species_names=["e", "i"],
corner=[0, -6.4, -25.6],
)


def test_missing_corner():
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
def test_missing_corner(ds_pfd_raw):
with pytest.raises(ValueError, match=r".*corner.*"):
pscpy.decode_psc(
ds_raw,
ds_pfd_raw,
species_names=["e", "i"],
length=[1, 12.8, 51.2],
)


def test_computed():
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
ds_raw = ds_raw.assign(jx=ds_raw.jeh.isel(dim_1_9=0))
assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data)
def test_computed(ds_pfd_raw, ds_pfd_decoded):
ds_pfd_raw = ds_pfd_raw.assign(jx=ds_pfd_raw.jeh.isel(dim_1_9=0))
assert np.all(ds_pfd_raw.jx.data == ds_pfd_decoded.jx_ec.data)


def test_computed_via_lambda():
ds_raw = _open_dataset(pscpy.sample_dir / "pfd.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
ds_raw = ds_raw.assign(jx=lambda ds: ds.jeh.isel(dim_1_9=0))
assert np.all(ds_raw.jx.data == ds_decoded.jx_ec.data)
def test_computed_via_lambda(ds_pfd_raw, ds_pfd_decoded):
ds_pfd_raw = ds_pfd_raw.assign(jx=lambda ds: ds.jeh.isel(dim_1_9=0))
assert np.all(ds_pfd_raw.jx.data == ds_pfd_decoded.jx_ec.data)


def test_pfd_moments():
ds_raw = _open_dataset(pscpy.sample_dir / "pfd_moments.000000400.bp")
ds_decoded = _decode_dataset(ds_raw)
assert "all_1st" in ds_raw
assert "rho_i" in ds_decoded
assert np.all(ds_decoded.rho_i.data == ds_raw.all_1st.isel(dim_1_26=13).data)
def test_pfd_moments(ds_pfd_moments_raw, ds_pfd_moments_decoded):
assert "all_1st" in ds_pfd_moments_raw
assert "rho_i" in ds_pfd_moments_decoded
assert np.all(
ds_pfd_moments_decoded.rho_i.data
== ds_pfd_moments_raw.all_1st.isel(dim_1_26=13).data
)


def test_open_dataset_steps(test_filename):
Expand Down