25 changes: 21 additions & 4 deletions .github/workflows/python-tests.yml
@@ -9,6 +9,27 @@ on:
- synchronize

jobs:
# Separate linting job for faster feedback
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[tests]

- name: Lint with ruff
run: |
ruff check bluemath_tk/ || true # TODO: Remove || true once docstrings are fixed

python-tests:
runs-on: ${{ matrix.os }}

@@ -32,10 +53,6 @@ jobs:
python -m pip install --upgrade pip
pip install .[all,tests]

- name: Lint
run: |
ruff check bluemath_tk/datamining/ || true # optional: don't fail lint for now

- name: Run tests
run: |
pytest -s -v tests
82 changes: 51 additions & 31 deletions bluemath_tk/core/decorators.py
@@ -1,13 +1,20 @@
"""
Validation decorators for BlueMath_tk classes.

This module provides decorators to validate input data for various
clustering, reduction, and analysis methods.
"""

import functools
from typing import Any, Dict, List
from typing import Any

import pandas as pd
import xarray as xr


def validate_data_lhs(func):
"""
Decorator to validate data in LHS class fit method.
Validate data in LHS class fit method.

Parameters
----------
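For orientation: every validator in this module wraps a class's fit method so that invalid arguments raise before the method runs. A minimal sketch of how validate_data_lhs would typically be applied, assuming a hypothetical DemoLHS class whose fit signature mirrors the wrapper shown in the hunks below:

# Hypothetical usage sketch: DemoLHS and the example values are invented;
# only the import path and the fit signature come from this module.
from bluemath_tk.core.decorators import validate_data_lhs


class DemoLHS:
    @validate_data_lhs
    def fit(
        self,
        dimensions_names: list[str],
        lower_bounds: list[float],
        upper_bounds: list[float],
        num_samples: int,
    ):
        # The decorator has already checked that the three lists are lists,
        # share the same length, and that each lower bound <= its upper bound.
        return {"dims": dimensions_names, "num_samples": num_samples}


DemoLHS().fit(["Hs", "Tp"], [0.5, 4.0], [3.0, 12.0], num_samples=25)
# Mismatched list lengths, e.g. fit(["Hs"], [0.5, 4.0], [3.0], 25), would
# raise the ValueError defined in the wrapper below.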
@@ -23,9 +30,9 @@ def validate_data_lhs(func):
@functools.wraps(func)
def wrapper(
self,
dimensions_names: List[str],
lower_bounds: List[float],
upper_bounds: List[float],
dimensions_names: list[str],
lower_bounds: list[float],
upper_bounds: list[float],
num_samples: int,
):
if not isinstance(dimensions_names, list):
@@ -38,7 +45,8 @@ def wrapper(
upper_bounds
):
raise ValueError(
"Dimensions names, lower bounds and upper bounds must have the same length"
"Dimensions names, lower bounds and upper bounds "
"must have the same length"
)
if not all(
[lower <= upper for lower, upper in zip(lower_bounds, upper_bounds)]
@@ -53,7 +61,7 @@

def validate_data_mda(func):
"""
Decorator to validate data in MDA class fit method.
Validate data in MDA class fit method.

Parameters
----------
@@ -70,7 +78,7 @@ def validate_data_mda(func):
def wrapper(
self,
data: pd.DataFrame,
directional_variables: List[str] = [],
directional_variables: list[str] = [],
custom_scale_factor: dict = {},
first_centroid_seed: int = None,
normalize_data: bool = False,
@@ -90,7 +98,8 @@ def wrapper(
or first_centroid_seed > data.shape[0]
):
raise ValueError(
"First centroid seed must be an integer >= 0 and < num of data points"
"First centroid seed must be an integer >= 0 "
"and < num of data points"
)
if not isinstance(normalize_data, bool):
raise TypeError("Normalize data must be a boolean")
@@ -108,7 +117,7 @@ def wrapper(

def validate_data_kma(func):
"""
Decorator to validate data in KMA class fit method.
Validate data in KMA class fit method.

Parameters
----------
@@ -125,12 +134,13 @@ def wrapper(
def wrapper(
self,
data: pd.DataFrame,
directional_variables: List[str] = [],
directional_variables: list[str] = [],
custom_scale_factor: dict = {},
min_number_of_points: int = None,
max_number_of_iterations: int = 10,
normalize_data: bool = False,
regression_guided: Dict[str, List] = {},
regression_guided: dict[str, list] = {},
init_mda_centroids: pd.DataFrame = None,
):
if data is None:
raise ValueError("data cannot be None")
@@ -157,7 +167,8 @@ def wrapper(
for var in regression_guided.get("vars", [])
):
raise TypeError(
"regression_guided vars must be a list of strings and must exist in data"
"regression_guided vars must be a list of strings "
"and must exist in data"
)
if not all(
isinstance(alpha, float) and alpha >= 0 and alpha <= 1
@@ -175,14 +186,15 @@ def wrapper(
max_number_of_iterations,
normalize_data,
regression_guided,
init_mda_centroids,
)

return wrapper


def validate_data_som(func):
"""
Decorator to validate data in SOM class fit method.
Validate data in SOM class fit method.

Parameters
----------
@@ -199,7 +211,7 @@ def validate_data_som(func):
def wrapper(
self,
data: pd.DataFrame,
directional_variables: List[str] = [],
directional_variables: list[str] = [],
custom_scale_factor: dict = {},
num_iteration: int = 1000,
normalize_data: bool = False,
@@ -230,7 +242,7 @@ def wrapper(

def validate_data_pca(func):
"""
Decorator to validate data in PCA class fit method.
Validate data in PCA class fit method.

Parameters
----------
@@ -247,8 +259,8 @@ def validate_data_pca(func):
def wrapper(
self,
data: xr.Dataset,
vars_to_stack: List[str],
coords_to_stack: List[str],
vars_to_stack: list[str],
coords_to_stack: list[str],
pca_dim_for_rows: str,
windows_in_pca_dim_for_rows: dict = {},
value_to_replace_nans: dict = {},
@@ -263,18 +275,22 @@ def wrapper(
for var in vars_to_stack:
if var not in data.data_vars:
raise ValueError(f"Variable {var} not found in data")
# Check that all variables in vars_to_stack have the same coordinates and dimensions
# Check that all variables in vars_to_stack have the same
# coordinates and dimensions
first_var = vars_to_stack[0]
first_var = vars_to_stack[0]
first_var_dims = list(data[first_var].dims)
first_var_coords = list(data[first_var].coords)
for var in vars_to_stack:
if list(data[var].dims) != first_var_dims:
raise ValueError(
f"All variables must have the same dimensions. Variable {var} does not match."
f"All variables must have the same dimensions. "
f"Variable {var} does not match."
)
if list(data[var].coords) != first_var_coords:
raise ValueError(
f"All variables must have the same coordinates. Variable {var} does not match."
f"All variables must have the same coordinates. "
f"Variable {var} does not match."
)
# Check that all coords_to_stack are in the data
if not isinstance(coords_to_stack, list) or len(coords_to_stack) == 0:
@@ -285,7 +301,8 @@ def wrapper(
# Check that pca_dim_for_rows is in the data, and window > 0 if provided
if not isinstance(pca_dim_for_rows, str) or pca_dim_for_rows not in data.dims:
raise ValueError(
"PCA dimension for rows must be a string and found in the data dimensions"
"PCA dimension for rows must be a string "
"and found in the data dimensions"
)
for variable, windows in windows_in_pca_dim_for_rows.items():
if not isinstance(windows, list):
@@ -314,7 +331,7 @@ def wrapper(

def validate_data_rbf(func):
"""
Decorator to validate data in RBF class fit method.
Validate data in RBF class fit method.

Parameters
----------
@@ -332,8 +349,8 @@ def wrapper(
self,
subset_data: pd.DataFrame,
target_data: pd.DataFrame,
subset_directional_variables: List[str] = [],
target_directional_variables: List[str] = [],
subset_directional_variables: list[str] = [],
target_directional_variables: list[str] = [],
subset_custom_scale_factor: dict = {},
normalize_target_data: bool = True,
target_custom_scale_factor: dict = {},
@@ -353,14 +370,16 @@ def wrapper(
for directional_variable in subset_directional_variables:
if directional_variable not in subset_data.columns:
raise ValueError(
f"Directional variable {directional_variable} not found in subset data"
f"Directional variable {directional_variable} "
f"not found in subset data"
)
if not isinstance(target_directional_variables, list):
raise TypeError("Target directional variables must be a list")
for directional_variable in target_directional_variables:
if directional_variable not in target_data.columns:
raise ValueError(
f"Directional variable {directional_variable} not found in target data"
f"Directional variable {directional_variable} "
f"not found in target data"
)
if not isinstance(subset_custom_scale_factor, dict):
raise TypeError("Subset custom scale factor must be a dict")
@@ -391,7 +410,7 @@ def wrapper(

def validate_data_xwt(func):
"""
Decorator to validate data in XWT class fit method.
Validate data in XWT class fit method.

Parameters
----------
@@ -408,7 +427,7 @@ def validate_data_xwt(func):
def wrapper(
self,
data: xr.Dataset,
fit_params: Dict[str, Dict[str, Any]] = {},
fit_params: dict[str, dict[str, Any]] = {},
variable_to_sort_bmus: str = None,
):
if not isinstance(data, xr.Dataset):
@@ -427,7 +446,8 @@ def wrapper(
or variable_to_sort_bmus not in data.data_vars
):
raise TypeError(
"variable_to_sort_bmus must be a string and must exist in data variables"
"variable_to_sort_bmus must be a string "
"and must exist in data variables"
)
return func(
self,
@@ -441,7 +461,7 @@ def wrapper(

def validate_data_calval(func):
"""
Decorator to validate data in CalVal class fit method.
Validate data in CalVal class fit method.

Parameters
----------
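Stepping back from the individual hunks, every validator in this file shares the same skeleton: a functools.wraps-ed wrapper that raises on invalid arguments and otherwise forwards the call unchanged to the wrapped fit method. A stripped-down sketch of that pattern; validate_data_example and DemoModel are hypothetical names, not part of bluemath_tk:

import functools

import pandas as pd


def validate_data_example(func):
    """Validate data in an example fit method (illustrative only)."""

    @functools.wraps(func)
    def wrapper(self, data: pd.DataFrame):
        if data is None:
            raise ValueError("data cannot be None")
        if not isinstance(data, pd.DataFrame):
            raise TypeError("data must be a pandas DataFrame")
        return func(self, data)

    return wrapper


class DemoModel:
    @validate_data_example
    def fit(self, data: pd.DataFrame):
        return data.shape


print(DemoModel().fit(pd.DataFrame({"Hs": [1.0, 2.0]})))  # prints (2, 1)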