Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
8d2e45a
rename Squish to Squeeze
Apr 30, 2025
c77c0e5
Merge branch 'rename_squish' into develop
May 4, 2025
4bcfc6e
begin making tests
Jul 23, 2025
be44453
Merge branch 'develop' into test_xarray_reshape, so as to be properly…
Jul 23, 2025
7b178a1
bug found in data.transforms.coordinates.Flatten
Jul 24, 2025
0285681
start testing, improve docstring
Jul 28, 2025
8eab03a
Flatten tested and debugged.
Jul 28, 2025
5279684
test skip_missing
Jul 28, 2025
d3a3541
test skip_missing
Jul 28, 2025
132470d
Final tidying before pull request.
Jul 29, 2025
5891692
Merge branch 'test_data_coordinates' into develop
Aug 5, 2025
acd1870
prepare for merge from develop
Aug 5, 2025
8db1c73
copy Flatten functionality to CoordinateFlatten
Aug 11, 2025
70e746a
first test
Aug 12, 2025
c828cf0
copy old Flatten test over in preparation for reworking
Aug 12, 2025
81dfedc
get ready to update branch
Sep 2, 2025
f19fde0
Merge branch 'develop' into test_xarray_reshape
Sep 2, 2025
a133d9e
start testing, improve docstring, rebase with new changes from upstream
Jul 28, 2025
5dd13a3
Flatten tested and debugged.
Jul 28, 2025
862bec0
test skip_missing
Jul 28, 2025
ca0da82
test skip_missing
Jul 28, 2025
88553b9
Final tidying before pull request.
Jul 29, 2025
33be322
Merge branch 'develop' into test_xarray_reshape
Sep 2, 2025
b4fb81a
test CoordinateExpand and undo
Sep 3, 2025
4baf547
test CoordinateExpand and undo
Sep 3, 2025
f38731e
remove multiple coordinate option from CoordinateFlatten and Coordina…
Sep 3, 2025
3215dd8
remove data.coordinates.Flatten, data.coordinates.Expand, data.coordi…
Sep 4, 2025
399fec0
commit before pulling changes
Sep 4, 2025
d1b1c1e
start testing, improve docstring, rebase with new changes from upstream
Jul 28, 2025
fbb60b8
Flatten tested and debugged.
Jul 28, 2025
35d3d07
test skip_missing
Jul 28, 2025
d6fe731
test skip_missing
Jul 28, 2025
4826a49
Final tidying before pull request.
Jul 29, 2025
ccf6f95
Merge branch 'develop' into test_xarray_reshape
Sep 4, 2025
ad3ce75
Update documentation
Oct 1, 2025
7a5ca9d
Merge from upstream prior to pull request
Oct 2, 2025
662921d
Merge branch 'develop' into test_xarray_reshape
tennlee Oct 16, 2025
e82e05c
Ran code reformatter
tennlee Oct 16, 2025
8f84275
Remove unused imports reported by ruff check
tennlee Oct 16, 2025
559be85
resolve conflicts when pulling
Oct 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/api/data/data_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ The rest of this page contains reference information for the components of the D
| | | - [coordinates.StandardCoordinateNames](data_api.md#pyearthtools.data.transforms.coordinates.StandardCoordinateNames) |
| | | - [coordinates.Select](data_api.md#pyearthtools.data.transforms.coordinates.Select) |
| | | - [coordinates.Drop](data_api.md#pyearthtools.data.transforms.coordinates.Drop) |
| | | - [coordinates.Flatten](data_api.md#pyearthtools.data.transforms.coordinates.Flatten) |
| | | - [coordinates.Expand](data_api.md#pyearthtools.data.transforms.coordinates.Expand) |
| | | - [coordinates.SelectFlatten](data_api.md#pyearthtools.data.transforms.coordinates.SelectFlatten) |
| | | - [coordinates.Assign](data_api.md#pyearthtools.data.transforms.coordinates.Assign) |
| | | - [coordinates.Pad](data_api.md#pyearthtools.data.transforms.coordinates.Pad) |
| | | - [default.get_default_transforms](data_api.md#pyearthtools.data.transforms.default.get_default_transforms) |
Expand Down
303 changes: 152 additions & 151 deletions docs/api/pipeline/pipeline_index.md

Large diffs are not rendered by default.

79 changes: 77 additions & 2 deletions notebooks/tutorial/CNN-Model-Training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,80 @@
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"A module that was compiled using NumPy 1.x cannot be run in\n",
"NumPy 2.0.1 as it may crash. To support both 1.x and 2.x\n",
"versions of NumPy, modules must be compiled with NumPy 2.0.\n",
"Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.\n",
"\n",
"If you are a user of the module, the easiest solution will be to\n",
"downgrade to 'numpy<2' or try to upgrade the affected module.\n",
"We expect that some modules will need time to support NumPy 2.\n",
"\n",
"Traceback (most recent call last): File \"<frozen runpy>\", line 198, in _run_module_as_main\n",
" File \"<frozen runpy>\", line 88, in _run_code\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/ipykernel_launcher.py\", line 18, in <module>\n",
" app.launch_new_instance()\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n",
" app.start()\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/ipykernel/kernelapp.py\", line 739, in start\n",
" self.io_loop.start()\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/tornado/platform/asyncio.py\", line 205, in start\n",
" self.asyncio_loop.run_forever()\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/asyncio/base_events.py\", line 645, in run_forever\n",
" self._run_once()\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/asyncio/base_events.py\", line 1999, in _run_once\n",
" handle._run()\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/asyncio/events.py\", line 88, in _run\n",
" self._context.run(self._callback, *self._args)\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 545, in dispatch_queue\n",
" await self.process_one()\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 534, in process_one\n",
" await dispatch(*args)\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 437, in dispatch_shell\n",
" await result\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 362, in execute_request\n",
" await super().execute_request(stream, ident, parent)\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 778, in execute_request\n",
" reply_content = await reply_content\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 449, in do_execute\n",
" res = shell.run_cell(\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n",
" return super().run_cell(*args, **kwargs)\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3098, in run_cell\n",
" result = self._run_cell(\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3153, in _run_cell\n",
" result = runner(coro)\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/IPython/core/async_helpers.py\", line 128, in _pseudo_sync_runner\n",
" coro.send(None)\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3365, in run_cell_async\n",
" has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3610, in run_ast_nodes\n",
" if await self.run_code(code, result, async_=asy):\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3670, in run_code\n",
" exec(code_obj, self.user_global_ns, self.user_ns)\n",
" File \"/var/folders/1s/z56f8rw50755xx8fxp2969r477wmss/T/ipykernel_48096/575015278.py\", line 7, in <module>\n",
" import torch\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/torch/__init__.py\", line 1477, in <module>\n",
" from .functional import * # noqa: F403\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/torch/functional.py\", line 9, in <module>\n",
" import torch.nn.functional as F\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/torch/nn/__init__.py\", line 1, in <module>\n",
" from .modules import * # noqa: F403\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/torch/nn/modules/__init__.py\", line 35, in <module>\n",
" from .transformer import TransformerEncoder, TransformerDecoder, \\\n",
" File \"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/torch/nn/modules/transformer.py\", line 20, in <module>\n",
" device: torch.device = torch.device(torch._C._get_default_device()), # torch.device('cpu'),\n",
"/Users/masonge/Documents/PET/venv/lib/python3.12/site-packages/torch/nn/modules/transformer.py:20: UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)\n",
" device: torch.device = torch.device(torch._C._get_default_device()), # torch.device('cpu'),\n"
]
}
],
"source": [
"import sys\n",
"from pathlib import Path\n",
Expand Down Expand Up @@ -130,6 +203,7 @@
"metadata": {},
"outputs": [
{

"data": {
"text/html": [
"<div><svg style=\"position: absolute; width: 0; height: 0; overflow: hidden\">\n",
Expand Down Expand Up @@ -625,6 +699,7 @@
},
"metadata": {},
"output_type": "display_data"

}
],
"source": [
Expand Down Expand Up @@ -5976,7 +6051,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
"version": "3.12.9"
},
"nbsphinx": {
"orphan": true
Expand Down
2 changes: 1 addition & 1 deletion notebooks/tutorial/HimawariAllBands.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.12.9"
},
"nbsphinx": {
"orphan": true
Expand Down
196 changes: 1 addition & 195 deletions packages/data/src/pyearthtools/data/transforms/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
import pandas as pd


from pyearthtools.data.transforms.transform import Transform, TransformCollection
from pyearthtools.data.transforms.attributes import SetType
from pyearthtools.data.transforms.transform import Transform
from pyearthtools.data.warnings import pyearthtoolsDataWarning
from pyearthtools.data.exceptions import DataNotFoundError

Expand Down Expand Up @@ -396,199 +395,6 @@ def apply(self, dataset: xr.Dataset) -> xr.Dataset:
return dataset


def weak_cast_to_int(value):
    """Cast integer-valued inputs (e.g. ``500.0``) to ``int``; return anything else unchanged.

    Used when building flattened variable names so that e.g. a pressure level of
    ``500.0`` yields a name like ``t500`` rather than ``t500.0``.

    Args:
        value:
            Any value. If ``int(value)`` succeeds and equals ``value``, the int
            is returned; otherwise ``value`` is returned untouched.

    Returns:
        ``int(value)`` when ``value`` is integer-valued, else ``value`` unchanged.
    """
    try:
        # Only the concrete conversion failures are expected here; a broad
        # `except Exception` would hide real bugs in unusual __int__ implementations.
        if int(value) == value:
            return int(value)
    except (TypeError, ValueError, OverflowError):
        # Not convertible to int (e.g. str, None, NaN, inf) - leave as-is.
        pass
    return value


class Flatten(Transform):
    """Operation to flatten a coordinate in a dataset, putting the data at each value of the coordinate into a separate
    data variable."""

    def __init__(
        self, coordinate: Hashable | list[Hashable] | tuple[Hashable], *extra_coordinates, skip_missing: bool = False
    ):
        """

        Flatten a coordinate in an xarray Dataset, putting the data at each value of the coordinate into a separate
        data variable.

        The output data variables will be named "<old variable name><value of coordinate>". For example, if the input
        Dataset has a variable "t" and it is flattened along the coordinate "pressure_level" which has values
        [100, 200, 500], then the output Dataset will have variables called t100, t200 and t500.

        If more than one coordinate is flattened, the output data variable names will concatenate the values of each
        coordinate.

        Args:
            coordinate (Hashable | list[Hashable] | tuple[Hashable] | None):
                Coordinates to flatten, either str or list of candidates.
            *extra_coordinates (optional):
                Arguments form of `coordinate`.
            skip_missing (bool, optional):
                Whether to skip data that does not have any of the listed coordinates. If True, will return such data
                unchanged. Defaults to False.

        Raises:
            ValueError:
                If invalid number of coordinates found

        """
        super().__init__()
        self.record_initialisation()

        # Normalise to a flat list of candidate coordinate names.
        coordinate = coordinate if isinstance(coordinate, (list, tuple)) else [coordinate]
        self._coordinate = [*coordinate, *extra_coordinates]
        self._skip_missing = skip_missing

    def apply(self, dataset: xr.Dataset) -> xr.Dataset:
        # Candidate coordinates that are actually present on this dataset.
        discovered_coord = list(set(self._coordinate).intersection(set(dataset.coords)))

        if len(discovered_coord) == 0:
            if self._skip_missing:
                return dataset

            raise ValueError(
                f"{self._coordinate} could not be found in dataset with coordinates {list(dataset.coords)}.\n"
                "Set 'skip_missing' to True to skip this."
            )

        elif len(discovered_coord) > 1:
            # Flatten one coordinate at a time by chaining single-coordinate transforms.
            transforms = TransformCollection(*[Flatten(coord) for coord in discovered_coord])
            return transforms(dataset)

        discovered_coord = str(discovered_coord[0])

        coords = dataset.coords
        # Carry over every coordinate except the one being flattened.
        new_ds = xr.Dataset(coords={co: v for co, v in coords.items() if not co == discovered_coord})
        # Record the coordinate's on-disk dtype so `Expand` can restore its encoding later.
        new_ds.attrs.update(
            {f"{discovered_coord}-dtype": str(dataset[discovered_coord].encoding.get("dtype", "int32"))}
        )

        for var in dataset:
            if discovered_coord not in dataset[var].coords:
                # Variable does not carry this coordinate; copy it through unchanged.
                new_ds[var] = dataset[var]
                continue

            # NOTE(review): the original had a branch here guarded by
            # `if coord_size.size == 1 and False:` - the `and False` made it
            # unreachable dead code, so it (and the unused `coord_size`
            # computation feeding it) has been removed. Behaviour is unchanged.
            for coord_val in dataset[discovered_coord]:
                # Integer-valued floats become ints so names read "t500", not "t500.0".
                coord_val = weak_cast_to_int(coord_val.values.item())

                selected = dataset[var].sel(**{discovered_coord: coord_val})  # type: ignore
                selected = selected.drop_vars(discovered_coord)  # type: ignore
                # Stash the value on the variable so `Expand` can rebuild the coordinate.
                selected.attrs.update(**{discovered_coord: coord_val})

                new_ds[f"{var}{coord_val}"] = selected
        return new_ds


class Expand(Transform):
    """Inverse operation to `Flatten`"""

    def __init__(self, coordinate: Hashable | list[Hashable] | tuple[Hashable], *extra_coordinates):
        """
        Inverse operation to [Flatten][pyearthtools.data.transforms.coordinates.Flatten]

        Will find flattened variables (identified by the coordinate value that `Flatten`
        stored in each variable's attrs) and regroup them upon the extra coordinate.

        Args:
            coordinate (Hashable | list[Hashable] | tuple[Hashable]):
                Coordinate to unflatten.
            *extra_coordinates (optional):
                Argument form of `coordinate`.
        """
        super().__init__()
        self.record_initialisation()

        # Normalise a single coordinate name to a tuple.
        if not isinstance(coordinate, (list, tuple)):
            coordinate = (coordinate,)

        coordinate = (*coordinate, *extra_coordinates)
        self._coordinate = coordinate

    # @property
    # def _info_(self):
    #     return dict(coordinate=self._coordinate)

    def apply(self, dataset: xr.Dataset) -> xr.Dataset | xr.DataArray:
        # Rebuild via the constructor so the caller's object is not mutated in place.
        dataset = type(dataset)(dataset)

        for coord in self._coordinate:
            # dtype recorded by `Flatten` on the dataset attrs; fall back to "int32".
            dtype = dataset.attrs.get(f"{coord}-dtype", "int32")
            components = []
            for var in list(dataset.data_vars):
                var_data = dataset[var]
                if coord in var_data.attrs:
                    # `Flatten` appended the coordinate value to the variable name and
                    # stored it in attrs: strip it from the name and reattach it as a
                    # length-1 coordinate so the pieces can be recombined.
                    value = var_data.attrs.pop(coord)
                    var_data = (
                        var_data.to_dataset(name=var.replace(str(value), ""))
                        .assign_coords(**{coord: [value]})
                        .set_coords(coord)
                    )
                components.append(var_data)

            # Merge the per-value pieces back along the restored coordinate.
            dataset = xr.combine_by_coords(components)  # type: ignore
            dataset = SetType(**{str(coord): dtype})(dataset)

            ## Add stored encoding if there
            if f"{coord}-dtype" in dataset.attrs:
                dtype = dataset.attrs.pop(f"{coord}-dtype")
                dataset[coord].encoding.update(dtype=dtype)

        return dataset


def SelectFlatten(
    coordinates: dict[str, tuple[Any] | Any] | None = None,
    tolerance: float = 0.01,
    **extra_coordinates,
) -> TransformCollection:
    """
    Select upon coordinates, and flatten said coordinates.

    Args:
        coordinates (dict[str, tuple[Any] | Any] | None, optional):
            Coordinates and values to select. Must be coordinates in the data.
            Defaults to None.
        tolerance (float, optional):
            Tolerance of selection. Defaults to 0.01.
        **extra_coordinates (optional):
            Keyword form of `coordinates`; takes precedence on key collision.

    Returns:
        (TransformCollection):
            TransformCollection to select and Flatten
    """

    # Merge into a fresh dict: the original code mutated the caller-supplied
    # `coordinates` mapping in place via `.update()`, which is a surprising
    # side effect for a factory function.
    merged = {**(coordinates or {}), **extra_coordinates}

    select_trans = Select(merged, ignore_missing=True, tolerance=tolerance)
    flatten_trans = Flatten(list(merged.keys()))

    return select_trans + flatten_trans


class Assign(Transform):
"""Assign coordinates to object"""

Expand Down
Loading