Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion packages/data/tests/transform/test_derive.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import pytest
import math

from pyearthtools.data.transforms.derive import evaluate
from numpy import nan, isnan
from pyearthtools.data.transforms.derive import evaluate, EquationException


@pytest.mark.parametrize(
Expand Down Expand Up @@ -46,6 +47,20 @@ def test_evaluate_only_eq(eq, result):
assert evaluate(eq) == float(result)


@pytest.mark.parametrize(
"eq",
[
("1 + (2"),
("1 + 2)"),
("1 + ((2 + 3)"),
("1 + (2 + 3))"),
],
)
def test_evaluate_mismatched_brackets(eq):
with pytest.raises(EquationException):
evaluate(eq)


@pytest.mark.parametrize(
"eq, result",
[
Expand All @@ -55,3 +70,20 @@ def test_evaluate_only_eq(eq, result):
)
def test_constants(eq, result):
assert evaluate(eq) == float(result)


@pytest.mark.parametrize(
"eq, result",
[
("2 not_nan 3", 2.0),
("nan not_nan 3", 3.0),
("nan not_nan 3 not_nan 4", 3.0),
("nan not_nan nan not_nan 4", 4.0),
],
)
def test_evaluate_only_not_nan(eq, result):
assert evaluate(eq) == result


def test_evaluate_only_not_nan_all_nan():
assert isnan(evaluate("nan not_nan nan"))
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ def __init__(self) -> None:
self.record_initialisation()

def filter(self, sample: np.ndarray):
"""Check if any of the sample is nan
"""Reject the sample if any value is nan

Args:
sample (np.ndarray):
Sample to check
Returns:
(bool):
If sample contains nan's
Raises:
(PipelineFilterException):
If sample contains one or more nan value
"""
if not bool(np.array(list(np.isnan(sample))).any()):
if bool(np.array(list(np.isnan(sample))).any()):
raise PipelineFilterException(sample, "Data contained nan's.")


Expand All @@ -76,16 +76,16 @@ def __init__(self) -> None:
self.record_initialisation()

def filter(self, sample: np.ndarray):
"""Check if all of the sample is nan
"""Reject the sample if all of its values are nan

Args:
sample (np.ndarray):
Sample to check
Returns:
(bool):
If sample contains nan's
Raises:
(PipelineFilterException):
If sample contains only nan values
"""
if not bool(np.array(list(np.isnan(sample))).all()):
if bool(np.array(list(np.isnan(sample))).all()):
raise PipelineFilterException(sample, "Data contained all nan's.")


Expand Down
96 changes: 96 additions & 0 deletions packages/pipeline/tests/operations/numpy/test_numpy_augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2025.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pyearthtools.pipeline.operations.numpy import augment

import numpy as np
import pytest


@pytest.mark.parametrize(
# The result depends on the random seed. This one has been manually
# checked to produce a certain number of rotations the first time.
"seed, rotations",
[
(42, 0),
(1, 1),
(4, 2),
(2, 3),
],
)
def test_Rotate(seed, rotations):

original = np.array([[1, 2], [4, 3]])

match rotations:
case 0:
expected = np.array([[1, 2], [4, 3]])
case 1:
expected = np.array([[4, 1], [3, 2]])
case 2:
expected = np.array([[3, 4], [2, 1]])
case 3:
expected = np.array([[2, 3], [1, 4]])

rotate = augment.Rotate(seed=seed, axis=(1, 0))

result = rotate.apply_func(original)
assert (result == expected).all()


def test_Rotate_axis_must_be_tuple():
with pytest.raises(TypeError):
augment.Rotate(axis=0)


@pytest.mark.parametrize(
"seed, should_flip",
[
(0, True),
(1, False),
],
)
def test_Flip(seed, should_flip):

original = np.array([[1, 2], [4, 3]])

flipped = np.array([[3, 4], [2, 1]])

# The result depends on the random seed. This one has been manually checked
# to produce a single rotation the first time.
expected = flipped if should_flip else original
flip = augment.Flip(seed=seed, axis=(1, 0))

result = flip.apply_func(original)
assert (result == expected).all()


@pytest.mark.parametrize(
"seed, should_flip",
[
(0, True),
(1, False),
],
)
def test_FlipAndRotate(seed, should_flip):

original = np.array([[1, 2], [4, 3]])

flip_and_rotate = augment.FlipAndRotate()

result = flip_and_rotate.apply_func(original)
# Don't worry about the number of flips and rotations, just check the
# shape and type returned
assert isinstance(result, np.ndarray)
assert result.shape == (2, 2)
35 changes: 22 additions & 13 deletions packages/pipeline/tests/operations/numpy/test_numpy_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

import numpy as np
import xarray as xr
import dask.array as da


def test_ToXarray_with_DataArray():

coords = {"x": list(range(5)), "y": list(range(5))}
data = np.ones((5, 5))
data = np.random.randn(5, 5)
sample = xr.DataArray(coords=coords, data=data)

tox = conversion.ToXarray.like(sample)
Expand All @@ -36,37 +37,45 @@ def test_ToXarray_with_DataArray():
def test_ToXarray_with_Dataset():

coords = {"x": list(range(5)), "y": list(range(5))}
data = np.ones((5, 5))
data1 = np.ones((1, 5, 5))
sample_da = xr.DataArray(coords=coords, data=data)
data_3d = np.random.randn(1, 5, 5)
data_2d = data_3d[0]
sample_da = xr.DataArray(coords=coords, data=data_2d)
sample_ds = xr.Dataset(coords=coords, data_vars={"z": sample_da})

tox = conversion.ToXarray.like(sample_ds)
result = tox.apply_func(data1)
result = tox.apply_func(data_3d)

assert (result == sample_ds).all()

as_numpy = tox.undo_func(sample_ds)
assert (as_numpy == data1).all()
assert (as_numpy == data_3d).all()


def test_drop_coords():

coords = {"x": list(range(5)), "y": list(range(5))}
data = np.ones((5, 5))
_data1 = np.ones((1, 5, 5))
sample_da = xr.DataArray(coords=coords, data=data)

data_3d = np.random.randn(1, 5, 5)
data_2d = data_3d[0]
sample_da = xr.DataArray(coords=coords, data=data_2d)
sample_ds = xr.Dataset(coords=coords, data_vars={"z": sample_da})

tox = conversion.ToXarray.like(sample_ds, drop_coords=["x"])
assert tox is not None
result = tox.apply_func(data_3d)
assert (result == sample_ds).all()

as_numpy = tox.undo_func(sample_ds)
assert (as_numpy == data_3d).all()


def test_ToDask():

data = np.ones((5, 5))
data = np.random.randn(5, 5)
expected = da.from_array(data)

tod = conversion.ToDask()
da = tod.apply_func(data)
orig = tod.undo_func(da)
result = tod.apply_func(data)
da.assert_eq(result, expected)

orig = tod.undo_func(result)
assert (orig == data).all()
57 changes: 57 additions & 0 deletions packages/pipeline/tests/operations/numpy/test_numpy_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2025.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pyearthtools.pipeline.operations.numpy import filters
from pyearthtools.pipeline.exceptions import PipelineFilterException

import numpy as np
import pytest


def test_DropAnyNan_false():

original = np.array([[1, 2], [4, 3]])

drop = filters.DropAnyNan()
# No return value, just check no exception is raised
drop.filter(original)


def test_DropAnyNan_true():

original = np.array([[1, 2], [4, np.nan]])

drop = filters.DropAnyNan()

with pytest.raises(PipelineFilterException):
result = drop.filter(original)


def test_DropAllNan_false():

original = np.array([[1, 2], [np.nan, 3]])

drop = filters.DropAllNan()
# No return value, just check no exception is raised
drop.filter(original)


def test_DropAllNan_true():

original = np.array([[np.nan, np.nan], [np.nan, np.nan]])

drop = filters.DropAllNan()

with pytest.raises(PipelineFilterException):
result = drop.filter(original)
Loading