Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion bluemath_tk/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,14 @@ def wrapper(
raise ValueError("Number of iterations must be integer and > 0")
if not isinstance(normalize_data, bool):
raise TypeError("Normalize data must be a boolean")
return func(self, data, directional_variables, num_iteration)
return func(
self,
data,
directional_variables,
custom_scale_factor,
num_iteration,
normalize_data,
)

return wrapper

Expand Down
9 changes: 8 additions & 1 deletion bluemath_tk/core/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ def normalize(
... }
... )
>>> normalized_data, scale_factor = normalize(data=df)

>>> import numpy as np
>>> import xarray as xr
>>> from bluemath_tk.core.data import normalize
>>> from bluemath_tk.core.operations import normalize
>>> ds = xr.Dataset(
... {
... "Hs": (("time",), np.random.rand(1000) * 7),
Expand All @@ -85,6 +86,7 @@ def normalize(
vars_to_normalize = list(data.data_vars)
else:
raise TypeError("Data must be a pandas DataFrame or an xarray Dataset")

normalized_data = data.copy() # Copy data to avoid bad memory replacements
scale_factor = (
custom_scale_factor.copy()
Expand Down Expand Up @@ -122,6 +124,7 @@ def normalize(
normalized_data[data_var] = (normalized_data[data_var] - data_var_min) / (
data_var_max - data_var_min
)

return normalized_data, scale_factor


Expand Down Expand Up @@ -173,6 +176,7 @@ def denormalize(
... "Dir": [0, 360],
... }
>>> denormalized_data = denormalize(normalized_data=df, scale_factor=scale_factor)

>>> import numpy as np
>>> import xarray as xr
>>> from bluemath_tk.core.operations import denormalize
Expand Down Expand Up @@ -204,6 +208,7 @@ def denormalize(
data[data_var] * (scale_factor[data_var][1] - scale_factor[data_var][0])
+ scale_factor[data_var][0]
)

return data


Expand Down Expand Up @@ -263,6 +268,7 @@ def standarize(
},
coords=data.coords,
)

return standarized_data, scaler


Expand Down Expand Up @@ -308,6 +314,7 @@ def destandarize(
},
coords=standarized_data.coords,
)

return data


Expand Down
83 changes: 83 additions & 0 deletions bluemath_tk/core/plotting/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import List, Tuple

import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from .base_plotting import DefaultStaticPlotting
from .colors import default_colors


def plot_scatters_in_triangle(
dataframes: List[pd.DataFrame],
data_colors: List[str] = default_colors,
**kwargs,
) -> Tuple[Figure, Axes]:
"""
Plot a scatter plot of the dataframes with axes in a triangle.

Parameters
----------
dataframes : List[pd.DataFrame]
List of dataframes to plot.
data_colors : List[str], optional
List of colors for the dataframes.
**kwargs : dict, optional
Keyword arguments for the scatter plot. Will be passed to the
DefaultStaticPlotting.plot_scatter method, which is the same
as the one in matplotlib.pyplot.scatter.
For example, to change the marker size, you can use:
``plot_scatters_in_triangle(dataframes, s=10)``

Returns
-------
fig : Figure
Figure object.
axes : Axes
Axes object.
"""

# Get the number and names of variables from the first dataframe
variables_names = list(dataframes[0].columns)
num_variables = len(variables_names)

# Check variables names are in all dataframes
for df in dataframes:
if not all(v in df.columns for v in variables_names):
raise ValueError(
f"Variables {variables_names} are not in dataframe {df.columns}."
)

# Create figure and axes
default_static_plot = DefaultStaticPlotting()
fig, axes = default_static_plot.get_subplots(
nrows=num_variables - 1,
ncols=num_variables - 1,
sharex=False,
sharey=False,
)
if isinstance(axes, Axes):
axes = np.array([[axes]])

for c1, v1 in enumerate(variables_names[1:]):
for c2, v2 in enumerate(variables_names[:-1]):
for idf, df in enumerate(dataframes):
default_static_plot.plot_scatter(
ax=axes[c2, c1],
x=df[v1],
y=df[v2],
c=data_colors[idf],
alpha=0.6,
**kwargs,
)
if c1 == c2:
axes[c2, c1].set_xlabel(variables_names[c1 + 1])
axes[c2, c1].set_ylabel(variables_names[c2])
elif c1 > c2:
axes[c2, c1].xaxis.set_ticklabels([])
axes[c2, c1].yaxis.set_ticklabels([])
else:
fig.delaxes(axes[c2, c1])

return fig, axes
128 changes: 32 additions & 96 deletions bluemath_tk/datamining/_base_datamining.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,18 @@
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from ..core.models import BlueMathModel
from ..core.plotting.base_plotting import DefaultStaticPlotting
from ..core.plotting.scatter import plot_scatters_in_triangle


class BaseSampling(BlueMathModel):
"""
Base class for all sampling BlueMath models.
This class provides the basic structure for all sampling models.

Methods
-------
generate : pd.DataFrame
Generates samples.
plot_generated_data : Tuple[plt.figure, plt.axes]
Plots the generated data on a scatter plot matrix.
"""

@abstractmethod
Expand Down Expand Up @@ -52,7 +46,7 @@ def plot_generated_data(
self,
data_color: str = "blue",
**kwargs,
) -> Tuple[plt.figure, plt.axes]:
) -> Tuple[Figure, Axes]:
"""
Plots the generated data on a scatter plot matrix.

Expand All @@ -65,9 +59,9 @@ def plot_generated_data(

Returns
-------
plt.figure
Figure
The figure object containing the plot.
plt.axes
Axes
Array of axes objects for the subplots.

Raises
Expand All @@ -76,40 +70,14 @@ def plot_generated_data(
If the data is empty.
"""

if not self.data.empty:
variables_names = list(self.data.columns)
num_variables = len(variables_names)
else:
if self.data.empty:
raise ValueError("Data must be a non-empty DataFrame with columns to plot.")

# Create figure and axes
default_static_plot = DefaultStaticPlotting()
fig, axes = default_static_plot.get_subplots(
nrows=num_variables - 1,
ncols=num_variables - 1,
sharex=False,
sharey=False,
fig, axes = plot_scatters_in_triangle(
dataframes=[self.data],
data_colors=[data_color],
**kwargs,
)
if isinstance(axes, Axes):
axes = np.array([[axes]])

for c1, v1 in enumerate(variables_names[1:]):
for c2, v2 in enumerate(variables_names[:-1]):
default_static_plot.plot_scatter(
ax=axes[c2, c1],
x=self.data[v1],
y=self.data[v2],
c=data_color,
**kwargs,
)
if c1 == c2:
axes[c2, c1].set_xlabel(variables_names[c1 + 1])
axes[c2, c1].set_ylabel(variables_names[c2])
elif c1 > c2:
axes[c2, c1].xaxis.set_ticklabels([])
axes[c2, c1].yaxis.set_ticklabels([])
else:
fig.delaxes(axes[c2, c1])

return fig, axes

Expand Down Expand Up @@ -162,7 +130,7 @@ def fit(
self.custom_scale_factor = custom_scale_factor.copy()
else:
self.logger.info(
"Normalization is disabled. Using default scale factor (0, 1) for all fitting variables."
"Normalization is disabled. Set normalize_data to True to enable normalization."
)
self.custom_scale_factor = {
fitting_variable: (0, 1) for fitting_variable in self.fitting_variables
Expand Down Expand Up @@ -198,7 +166,7 @@ def plot_selected_centroids(
centroids_color: str = "red",
plot_text: bool = False,
**kwargs,
) -> Tuple[plt.figure, plt.axes]:
) -> Tuple[Figure, Axes]:
"""
Plots data and selected centroids on a scatter plot matrix.

Expand All @@ -215,9 +183,9 @@ def plot_selected_centroids(

Returns
-------
plt.figure
Figure
The figure object containing the plot.
plt.axes
Axes
Array of axes objects for the subplots.

Raises
Expand All @@ -231,59 +199,27 @@ def plot_selected_centroids(
and list(self.data.columns) != []
):
variables_names = list(self.data.columns)
num_variables = len(variables_names)
else:
raise ValueError(
"Data and centroids must have the same number of columns > 0."
)

# Create figure and axes
default_static_plot = DefaultStaticPlotting()
fig, axes = default_static_plot.get_subplots(
nrows=num_variables - 1,
ncols=num_variables - 1,
sharex=False,
sharey=False,
fig, axes = plot_scatters_in_triangle(
dataframes=[self.data, self.centroids],
data_colors=[data_color, centroids_color],
**kwargs,
)
if isinstance(axes, Axes):
axes = np.array([[axes]])

for c1, v1 in enumerate(variables_names[1:]):
for c2, v2 in enumerate(variables_names[:-1]):
default_static_plot.plot_scatter(
ax=axes[c2, c1],
x=self.data[v1],
y=self.data[v2],
c=data_color,
alpha=0.6,
**kwargs,
)
if self.centroids is not None:
default_static_plot.plot_scatter(
ax=axes[c2, c1],
x=self.centroids[v1],
y=self.centroids[v2],
c=centroids_color,
alpha=0.9,
**kwargs,
)
if plot_text:
for i in range(self.centroids.shape[0]):
axes[c2, c1].text(
self.centroids[v1][i],
self.centroids[v2][i],
str(i + 1),
fontsize=12,
fontweight="bold",
)
if c1 == c2:
axes[c2, c1].set_xlabel(variables_names[c1 + 1])
axes[c2, c1].set_ylabel(variables_names[c2])
elif c1 > c2:
axes[c2, c1].xaxis.set_ticklabels([])
axes[c2, c1].yaxis.set_ticklabels([])
else:
fig.delaxes(axes[c2, c1])
if plot_text:
for c1, v1 in enumerate(variables_names[1:]):
for c2, v2 in enumerate(variables_names[:-1]):
for i in range(self.centroids.shape[0]):
axes[c2, c1].text(
self.centroids[v1][i],
self.centroids[v2][i],
str(i + 1),
fontsize=12,
fontweight="bold",
)

return fig, axes

Expand All @@ -292,7 +228,7 @@ def plot_data_as_clusters(
data: pd.DataFrame,
nearest_centroids: np.ndarray,
**kwargs,
) -> Tuple[plt.figure, plt.axes]:
) -> Tuple[Figure, Axes]:
"""
Plots data as nearest clusters.

Expand All @@ -307,9 +243,9 @@ def plot_data_as_clusters(

Returns
-------
plt.figure
Figure
The figure object containing the plot.
plt.axes
Axes
The axes object for the plot.
"""

Expand Down
Loading