45 changes: 44 additions & 1 deletion bluemath_tk/core/decorators.py
@@ -1,5 +1,5 @@
import functools
from typing import List
from typing import Any, Dict, List

import pandas as pd
import xarray as xr
@@ -120,6 +120,7 @@ def wrapper(
directional_variables: List[str] = [],
custom_scale_factor: dict = {},
min_number_of_points: int = None,
normalize_data: bool = True,
):
if data is None:
raise ValueError("Data cannot be None")
@@ -132,6 +133,8 @@ def wrapper(
if min_number_of_points is not None:
if not isinstance(min_number_of_points, int) or min_number_of_points <= 0:
raise ValueError("Minimum number of points must be integer and > 0")
if not isinstance(normalize_data, bool):
raise TypeError("Normalize data must be a boolean")
        return func(
            self,
            data,
            directional_variables,
            custom_scale_factor,
            min_number_of_points,
            normalize_data,
        )
@@ -331,3 +334,43 @@ def wrapper(
)

return wrapper


def validate_data_xwt(func):
"""
Decorator to validate data in XWT class fit method.

Parameters
----------
func : callable
The function to be decorated

Returns
-------
callable
The decorated function
"""

@functools.wraps(func)
def wrapper(
self,
data: xr.Dataset,
fit_params: Dict[str, Dict[str, Any]] = {},
):
if not isinstance(data, xr.Dataset):
raise TypeError("Data must be an xarray Dataset")
if "time" not in data.dims:
raise ValueError(
'Time dimension with name "time" not found in data, rename and re-fit'
) # TODO: check time is actually datetime
if not isinstance(fit_params, dict):
raise TypeError("Fit params must be a dict")
if "pca" not in fit_params:
raise ValueError("Fit params must contain PCA parameters")
return func(
self,
data,
fit_params,
)

return wrapper
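
A minimal sketch of how the new validate_data_xwt decorator might be consumed; the XWT class below is hypothetical and only the decorator contract shown above comes from this diff:

import numpy as np
import xarray as xr

from bluemath_tk.core.decorators import validate_data_xwt


class XWT:
    # Hypothetical consumer class, for illustration only.
    @validate_data_xwt
    def fit(self, data: xr.Dataset, fit_params: dict = {}):
        # By this point, data is guaranteed to be an xr.Dataset with a
        # "time" dimension, and fit_params is a dict with a "pca" entry.
        self.pca_params = fit_params["pca"]


xwt = XWT()
ds = xr.Dataset({"slp": ("time", np.zeros(10))})
xwt.fit(ds, fit_params={"pca": {"n_components": 3}})  # passes validation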
5 changes: 4 additions & 1 deletion bluemath_tk/core/logging.py
@@ -11,6 +11,7 @@ def get_file_logger(
logs_path: str = None,
level: Union[int, str] = "INFO",
console: bool = True,
console_level: Union[int, str] = "WARNING",
) -> logging.Logger:
"""
Creates and returns a logger that writes log messages to a file.
@@ -25,7 +26,8 @@
The logging level. Default is "INFO".
console : bool
Whether to add or not console / terminal logs. Default is True.

console_level : Union[int, str], optional
The logging level for console / terminal logs. Default is "WARNING".
Returns
-------
logging.Logger
@@ -78,6 +80,7 @@ def get_file_logger(
# Also output logs in the console if requested
if console:
console_handler = logging.StreamHandler()
console_handler.setLevel(console_level)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

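
A hedged usage sketch of the new console_level argument; the leading "example" name argument is an assumption, since the hunks above do not show get_file_logger's first parameter:

from bluemath_tk.core.logging import get_file_logger

# The file handler records INFO and above, while the console handler
# now only echoes WARNING and above.
logger = get_file_logger(
    "example",
    logs_path="example.log",
    level="INFO",
    console=True,
    console_level="WARNING",
)
logger.info("written to the file only")
logger.warning("written to the file and echoed to the console")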
47 changes: 47 additions & 0 deletions bluemath_tk/core/pipeline.py
@@ -0,0 +1,47 @@
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import pandas as pd
import xarray as xr


class BlueMathPipeline:
"""
A lightweight pipeline that chains BlueMath model steps and fits each one in sequence.
"""

def __init__(self, steps: List[Tuple[str, Any]]):
"""
Initialize the BlueMathPipeline.

Parameters
----------
steps : List[Tuple[str, Any]]
A list of tuples where each tuple contains the name of the step and
the model instance.
"""

self.steps = steps

def fit(
self,
data: Union[np.ndarray, pd.DataFrame, xr.Dataset],
fit_params: Dict[str, Dict[str, Any]] = {},
):
"""
Fit the pipeline models.

Parameters
----------
data : Union[np.ndarray, pd.DataFrame, xr.Dataset]
The input data to fit the models.
fit_params : Dict[str, Dict[str, Any]], optional
A dictionary of parameters to pass to the fit method of each model.
The keys should be the names of the steps, and the values should be
dictionaries of parameters for the corresponding model's fit method.
"""

for name, model in self.steps:
params = fit_params.get(name, {})
if hasattr(model, "fit"):
model.fit(data, **params)
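
A short usage sketch for the new class; pairing it with KMA is an assumption for illustration, since the pipeline itself only requires each step to expose a fit method:

import pandas as pd

from bluemath_tk.core.pipeline import BlueMathPipeline
from bluemath_tk.datamining.kma import KMA

data = pd.DataFrame({"hs": [1.0, 2.0, 3.0, 4.0], "tp": [8.0, 9.0, 10.0, 11.0]})
pipeline = BlueMathPipeline(steps=[("kma", KMA(num_clusters=2, seed=42))])

# Keys in fit_params match step names; each value is unpacked as keyword
# arguments into that step's fit call.
pipeline.fit(data, fit_params={"kma": {"normalize_data": True}})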
70 changes: 39 additions & 31 deletions bluemath_tk/datamining/kma.py
@@ -30,12 +30,6 @@ class KMA(BaseClustering):
The number of clusters to use in the K-Means algorithm.
seed : int
The random seed to use as initial datapoint.
data : pd.DataFrame
The input data.
normalized_data : pd.DataFrame
The normalized input data.
data_to_fit : pd.DataFrame
The data to fit the K-Means algorithm.
data_variables : List[str]
A list with all data variables.
directional_variables : List[str]
@@ -53,16 +47,6 @@
centroid_real_indices : np.array
The real indices of the selected centroids.

Methods
-------
fit -> None
Fit the K-Means algorithm to the provided data.
predict -> Tuple[np.ndarray, pd.DataFrame]
Predict the nearest centroid for the provided data.
fit_predict -> Tuple[np.ndarray, pd.DataFrame]
Fit the K-Means algorithm to the provided data and predict the nearest
centroid for each data point.

Notes
-----
- The K-Means algorithm is used to cluster data points into k clusters.
@@ -119,7 +103,8 @@ def __init__(
"""

super().__init__()
self.set_logger_name(name=self.__class__.__name__)
self.set_logger_name(name=self.__class__.__name__, console=False)

if num_clusters > 0:
self.num_clusters = int(num_clusters)
else:
@@ -130,7 +115,6 @@ def __init__(
self.seed = int(seed)
else:
raise ValueError("Variable seed must be >= 0")
# TODO: check random_state and n_init
self._kma = KMeans(
n_clusters=self.num_clusters,
random_state=self.seed,
@@ -139,6 +123,7 @@
f"KMA object created with {self.num_clusters} clusters and seed {self.seed}."
"To customize kma, do self.kma = dict(n_clusters=..., random_state=..., etc)"
)

self._data: pd.DataFrame = pd.DataFrame()
self._normalized_data: pd.DataFrame = pd.DataFrame()
self._data_to_fit: pd.DataFrame = pd.DataFrame()
@@ -163,19 +148,31 @@ def kma(self) -> KMeans:
return self._kma

@kma.setter
def kma(self, **kwargs) -> None:
def kma(self, kwargs) -> None:
self._kma = KMeans(**kwargs)

@property
def data(self) -> pd.DataFrame:
"""
Returns the original data used for clustering.
"""

return self._data

@property
def normalized_data(self) -> pd.DataFrame:
"""
Returns the normalized data used for clustering.
"""

return self._normalized_data

@property
def data_to_fit(self) -> pd.DataFrame:
"""
Returns the data used for fitting the K-Means algorithm.
"""

return self._data_to_fit

@validate_data_kma
@@ -185,6 +182,7 @@ def fit(
directional_variables: List[str] = [],
custom_scale_factor: dict = {},
min_number_of_points: int = None,
normalize_data: bool = True,
) -> None:
"""
Fit the K-Means algorithm to the provided data.
Expand All @@ -206,6 +204,8 @@ def fit(
min_number_of_points : int, optional
The minimum number of points to consider a cluster.
Default is None.
normalize_data : bool, optional
A flag to normalize the data. Default is True.
"""

self._data = data.copy()
@@ -214,17 +214,22 @@
u_comp, v_comp = self.get_uv_components(
x_deg=self.data[directional_variable].values
)
self.data[f"{directional_variable}_u"] = u_comp
self.data[f"{directional_variable}_v"] = v_comp
self._data[f"{directional_variable}_u"] = u_comp
self._data[f"{directional_variable}_v"] = v_comp
self.data_variables = list(self.data.columns)
self.custom_scale_factor = custom_scale_factor.copy()

# Get just the data to be used in the fitting
self._data_to_fit = self.data.copy()
for directional_variable in self.directional_variables:
self.data_to_fit.drop(columns=[directional_variable], inplace=True)
self.fitting_variables = list(self.data_to_fit.columns)

if normalize_data:
self.custom_scale_factor = custom_scale_factor.copy()
else:
self.custom_scale_factor = {
fitting_variable: (0, 1) for fitting_variable in self.fitting_variables
}
# Normalize data using custom min max scaler
self._normalized_data, self.scale_factor = self.normalize(
data=self.data_to_fit, custom_scale_factor=self.custom_scale_factor
@@ -235,10 +240,7 @@
stable_kma_child = False
number_of_tries = 0
while not stable_kma_child:
kma_child = KMeans(
n_clusters=self.num_clusters,
random_state=self.seed,
)
kma_child = KMeans(n_clusters=self.num_clusters)
predicted_labels = kma_child.fit_predict(self.normalized_data)
_unique_labels, counts = np.unique(predicted_labels, return_counts=True)
if np.all(counts >= min_number_of_points):
@@ -272,7 +274,7 @@ def fit(
# Set the fitted flag to True
self.is_fitted = True

def predict(self, data: pd.DataFrame) -> Tuple[np.ndarray, pd.DataFrame]:
def predict(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Predict the nearest centroid for the provided data.

Expand All @@ -283,7 +285,7 @@ def predict(self, data: pd.DataFrame) -> Tuple[np.ndarray, pd.DataFrame]:

Returns
-------
Tuple[np.ndarray, pd.DataFrame]
Tuple[pd.DataFrame, pd.DataFrame]
A tuple containing the nearest centroid index for each data point,
and the nearest centroids.
"""
@@ -304,15 +306,18 @@ def predict(self, data: pd.DataFrame) -> Tuple[np.ndarray, pd.DataFrame]:
)
y = self.kma.predict(X=normalized_data)

return y, self.centroids.iloc[y]
return pd.DataFrame(
y, columns=["kma_bmus"], index=data.index
), self.centroids.iloc[y]

def fit_predict(
self,
data: pd.DataFrame,
directional_variables: List[str] = [],
custom_scale_factor: dict = {},
min_number_of_points: int = None,
) -> Tuple[np.ndarray, pd.DataFrame]:
normalize_data: bool = True,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Fit the K-Means algorithm to the provided data and predict the nearest centroid
for each data point.
@@ -330,10 +335,12 @@ def fit_predict(
min_number_of_points : int, optional
The minimum number of points to consider a cluster.
Default is None.
normalize_data : bool, optional
A flag to normalize the data. Default is True.

Returns
-------
Tuple[np.ndarray, pd.DataFrame]
Tuple[pd.DataFrame, pd.DataFrame]
A tuple containing the nearest centroid index for each data point,
and the nearest centroids.
"""
@@ -343,6 +350,7 @@
directional_variables=directional_variables,
custom_scale_factor=custom_scale_factor,
min_number_of_points=min_number_of_points,
normalize_data=normalize_data,
)

return self.predict(data=data)
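
A usage sketch of the updated KMA API, with invented sample values: predict now returns a DataFrame of best-matching units instead of a bare array, and normalize_data=False pins every scale factor to (0, 1), so the data is fitted as-is:

import pandas as pd

from bluemath_tk.datamining.kma import KMA

df = pd.DataFrame(
    {
        "hs": [1.2, 2.3, 0.8, 3.1, 1.9, 2.7],
        "dir": [15.0, 350.0, 180.0, 90.0, 270.0, 45.0],
    }
)
kma = KMA(num_clusters=2, seed=42)

# Directional variables are expanded into u/v components before fitting.
bmus_df, nearest_centroids = kma.fit_predict(
    df, directional_variables=["dir"], normalize_data=False
)
print(bmus_df["kma_bmus"].value_counts())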