Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def open_file(file: str | tuple | dict):
data = open_files(file)
if isinstance(data, (xr.Dataset, xr.DataArray)):
data = get_default_transforms()(data)
data = pyearthtools.data.transforms.coordinates.drop("time", ignore_missing=True)(data)
data = pyearthtools.data.transforms.coordinates.Drop("time", ignore_missing=True)(data)
return data


Expand All @@ -61,7 +61,7 @@ def under_func(*args, **kwargs):
return under_func


class normaliser:
class Normaliser:
def __init__(
self,
index: pyearthtools.data.AdvancedTimeIndex,
Expand Down Expand Up @@ -214,7 +214,7 @@ def get_aggregation(
# )

aggregated_data = get_and_print(
lambda: pyearthtools.data.transforms.aggregation.over(method, dims)(
lambda: pyearthtools.data.transforms.aggregation.over(method=method, dimension=dims)(
self.index.series(
**retrieval_args,
transforms=transforms,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@

xr.set_options(keep_attrs=True)

from pyearthtools.data.transforms.normalisation.default import normaliser, open_file
from pyearthtools.data.transforms.normalisation.default import Normaliser, open_file
from pyearthtools.data.transforms.transform import FunctionTransform, Transform


class Normalise(normaliser):
class Normalise(Normaliser):
"""
Normalise incoming data.

Either call this class, or get attribute for specific normalisation strategy
"""

@functools.wraps(normaliser.__init__)
@functools.wraps(Normaliser.__init__)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
import numpy as np
import xarray as xr

from pyearthtools.data.transforms.normalisation.default import normaliser, open_file
from pyearthtools.data.transforms.normalisation.default import Normaliser, open_file
from pyearthtools.data.transforms.transform import FunctionTransform, Transform

xr.set_options(keep_attrs=True)


class Unnormalise(normaliser):
class Unnormalise(Normaliser):
"""Unnormalise Incoming Data"""

@functools.wraps(normaliser)
@functools.wraps(Normaliser)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions packages/data/src/pyearthtools/data/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ def apply(self, dataset: XR_TYPES | tuple[XR_TYPES] | list[XR_TYPES] | dict[str,
def __call__(self, dataset: XR_TYPES | tuple[XR_TYPES] | list[XR_TYPES] | dict[str, XR_TYPES]) -> XR_TYPES | Any:

# Do not try to transform empty datasets
if not dataset:
return dataset
if dataset is None:
return None

for transform in self._transforms:
dataset = transform(dataset)
Expand Down
86 changes: 86 additions & 0 deletions packages/data/tests/data/transform/normalisation/test_default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pyearthtools.data.transforms.normalisation
from pyearthtools.data.transforms.normalisation import default
from pyearthtools.data.time import Petdt
import pyearthtools.data.indexes
import xarray as xr
import numpy as np
import pytest

sample_da = xr.DataArray(coords={"latitude": [1,2,3,4],
"longitude": [1,2,3],
"time": ["2023-02"]
},
data=np.ones((4,3,1)))

sample_ds = xr.Dataset(coords={"latitude": [1,2,3,4], "longitude": [1,2,3], "time": ["2023-02"]},
data_vars={"temperature": sample_da})


def test_open_file(monkeypatch):

monkeypatch.setattr(pyearthtools.data.transforms.normalisation.default,
'open_files',
lambda x: sample_da)

result = default.open_file("pretend_filename.nc")
assert result is not None


def test_Normaliser(monkeypatch):

monkeypatch.setattr("pyearthtools.data.indexes.AdvancedTimeIndex.__abstractmethods__", set())

data_interval = "day"
ati = pyearthtools.data.indexes.AdvancedTimeIndex(data_interval)
monkeypatch.setattr(ati, "get", lambda x: sample_da)
start = Petdt("2023-02")
end = Petdt("2023-03")

n = default.Normaliser(ati, start, end, "month")
n.check_init_args()

result = n.get_average("temperature")
assert result == 1

r_mean, r_std = n.get_deviation("temperature")
assert r_mean == 1
assert r_std == 0

r_anomaly = n.get_anomaly("temperature")
assert r_anomaly is not None

# FIXME: Need to update the whole test creation to be a time-aware dataset
# r_range = n.get_range("temperature")
# assert r_range["temperature"]["max"] == 1
# assert r_range["temperature"]["min"] == 1

result = n.none
assert result is not None

def test_Normaliser_errors(monkeypatch):

monkeypatch.setattr("pyearthtools.data.indexes.AdvancedTimeIndex.__abstractmethods__", set())

data_interval = "day"
ati = pyearthtools.data.indexes.AdvancedTimeIndex(data_interval)
monkeypatch.setattr(ati, "get", lambda x: sample_da)
start = Petdt("2023-02")
end = Petdt("2023-03")

n = default.Normaliser(ati, start, end, "month")

with pytest.raises(NotImplementedError):
n.function()


not_implemented = [n.log, n.anomaly, n.deviation, n.deviation_spatial, n.range]
for ni in not_implemented:
with pytest.raises(NotImplementedError):
ni()







Loading