Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ jobs:
source /usr/share/miniconda/etc/profile.d/conda.sh
conda activate bluemath-tk
python -m unittest discover tests/datamining/
python -m unittest discover tests/downloaders/
python -m unittest discover tests/distributions/
python -m unittest discover tests/interpolation/
python -m unittest discover tests/wrappers/
56 changes: 56 additions & 0 deletions bluemath_tk/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,59 @@ def wrapper(
)

return wrapper


def validate_data_calval(func):
    """
    Decorator to validate data in CalVal class fit method.

    The wrapper checks every argument, normalises the coordinates and then
    forwards all arguments positionally to ``func``.

    Parameters
    ----------
    func : callable
        The function to be decorated. It is invoked as
        ``func(self, data, data_longitude, data_latitude, data_to_calibrate,
        max_time_diff)``.

    Returns
    -------
    callable
        The decorated function

    Raises
    ------
    TypeError
        Raised by the wrapper if ``data`` or ``data_to_calibrate`` is not a
        pandas DataFrame, or if ``data_longitude`` / ``data_latitude`` is not
        a real number (booleans are rejected).
    ValueError
        Raised by the wrapper if ``data_to_calibrate`` lacks one of the
        ``LONGITUDE``, ``LATITUDE`` or ``Hs_CAL`` columns, or if
        ``max_time_diff`` is not a strictly positive integer (booleans are
        rejected).
    """

    # Function-scope import: the module-level import block is not touched, so
    # the decorator brings its own dependency into scope.
    import numbers

    def _is_real_number(value) -> bool:
        # bool is a subclass of int (and hence a numbers.Real), but a boolean
        # coordinate is always a caller mistake, so it is excluded explicitly.
        return isinstance(value, numbers.Real) and not isinstance(value, bool)

    @functools.wraps(func)
    def wrapper(
        self,
        data: pd.DataFrame,
        data_longitude: float,
        data_latitude: float,
        data_to_calibrate: pd.DataFrame,
        max_time_diff: int = 2,
    ):
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Data must be a pandas DataFrame")

        # Integer coordinates (e.g. ``10``) are valid wherever a float is
        # expected; they are converted so ``func`` always receives floats.
        if not _is_real_number(data_longitude):
            raise TypeError("Longitude must be a float")
        # Wrap the longitude into the [0, 360) range.
        data_longitude = float(data_longitude) % 360
        if not _is_real_number(data_latitude):
            raise TypeError("Latitude must be a float")
        data_latitude = float(data_latitude)

        if not isinstance(data_to_calibrate, pd.DataFrame):
            raise TypeError("Data to calibrate must be a pandas DataFrame")
        for column in ("LONGITUDE", "LATITUDE", "Hs_CAL"):
            if column not in data_to_calibrate.columns:
                raise ValueError(
                    f"Data to calibrate must contain a column named '{column}'"
                )

        # ``True`` would otherwise pass as the integer 1.
        if (
            isinstance(max_time_diff, bool)
            or not isinstance(max_time_diff, int)
            or max_time_diff <= 0
        ):
            raise ValueError("Maximum time difference must be an integer and > 0")

        return func(
            self,
            data,
            data_longitude,
            data_latitude,
            data_to_calibrate,
            max_time_diff,
        )

    return wrapper
225 changes: 209 additions & 16 deletions bluemath_tk/core/plotting/scatter.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,236 @@
from typing import List, Tuple
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from scipy.stats import gaussian_kde, probplot
from sklearn.metrics import mean_squared_error

from .base_plotting import DefaultStaticPlotting
from .colors import default_colors


def rmse(pred: np.ndarray, tar: np.ndarray) -> float:
    """
    Compute the Root Mean Square Error of predictions against targets.

    Parameters
    ----------
    pred : np.ndarray
        Predicted values.
    tar : np.ndarray
        Target (actual) values, same length as ``pred``.

    Returns
    -------
    float
        Square root of the mean of the squared element-wise differences.

    Raises
    ------
    ValueError
        If ``pred`` and ``tar`` differ in length.
    """

    n_pred, n_tar = len(pred), len(tar)
    if n_pred != n_tar:
        raise ValueError("pred and tar must have the same length")

    squared_errors = (pred - tar) ** 2
    mean_squared = squared_errors.mean()

    return np.sqrt(mean_squared)


def bias(pred: np.ndarray, tar: np.ndarray) -> float:
    """
    Calculate the bias between predicted and target values.

    Parameters
    ----------
    pred : np.ndarray
        Array of predicted values.
    tar : np.ndarray
        Array of target/actual values.

    Returns
    -------
    float
        The bias value (mean difference between predictions and targets).
        Empty inputs yield ``nan`` (NumPy emits a RuntimeWarning).

    Raises
    ------
    ValueError
        If ``pred`` and ``tar`` do not have the same length.
    """

    if len(pred) != len(tar):
        raise ValueError("pred and tar must have the same length")

    # np.mean reduces in native code; the builtin ``sum`` would iterate the
    # array element by element in Python.
    return np.mean(pred - tar)


def si(pred: np.ndarray, tar: np.ndarray) -> float:
    """
    Calculate the Scatter Index between predicted and target values.

    The index is computed as
    ``sqrt(sum(((pred - mean(pred)) - (tar - mean(tar))) ** 2) / sum(tar ** 2))``.

    Parameters
    ----------
    pred : np.ndarray
        Array of predicted values.
    tar : np.ndarray
        Array of target/actual values.

    Returns
    -------
    float
        The Scatter Index value.

    Raises
    ------
    ValueError
        If ``pred`` and ``tar`` do not have the same length.
    """

    if len(pred) != len(tar):
        raise ValueError("pred and tar must have the same length")

    # Difference of the mean-removed series, so a constant offset between
    # predictions and targets does not contribute to the index.
    centered_diff = (pred - pred.mean()) - (tar - tar.mean())

    # np.sum reduces in native code; the builtin ``sum`` would iterate the
    # arrays element by element in Python.
    return np.sqrt(np.sum(centered_diff**2) / np.sum(tar**2))


def density_scatter(
    x: np.ndarray, y: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Estimate the point density of a 2D sample with a gaussian KDE.

    The points are returned ordered by increasing density.

    Parameters
    ----------
    x : np.ndarray
        X coordinates of the points.
    y : np.ndarray
        Y coordinates of the points, same length as ``x``.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray]
        A tuple containing:
        - x values reordered by increasing density
        - y values reordered by increasing density
        - Density values, in increasing order, one per point

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in length.
    """

    if len(x) != len(y):
        raise ValueError("x and y must have the same length")

    # Evaluate the KDE fitted on the sample at the sample points themselves.
    samples = np.vstack([x, y])
    kde = gaussian_kde(samples)
    density = kde(samples)

    order = density.argsort()

    return x[order], y[order], density[order]


def validation_scatter(
    axs: Axes,
    x: np.ndarray,
    y: np.ndarray,
    xlabel: str,
    ylabel: str,
    title: str,
) -> None:
    """
    Draw a validation panel: density scatter, 1:1 line, Q-Q points and errors.

    Parameters
    ----------
    axs : Axes
        Matplotlib axes that receive all the artists.
    x : np.ndarray
        Values plotted along the X-axis.
    y : np.ndarray
        Values plotted along the Y-axis.
    xlabel : str
        Text for the X-axis label.
    ylabel : str
        Text for the Y-axis label.
    title : str
        Text for the axes title.
    """

    # Points come back ordered by increasing density.
    xs_by_density, ys_by_density, density = density_scatter(x, y)
    axs.scatter(xs_by_density, ys_by_density, c=density, s=5, cmap="rainbow")

    # Labels and title
    axs.set_title(title)
    axs.set_xlabel(xlabel)
    axs.set_ylabel(ylabel)

    # Square axes from 0 to a shared upper limit, plus the 1:1 reference line
    upper = np.ceil(max(max(x) + 0.1, max(y) + 0.1))
    span = [0, upper]
    axs.set_xlim(0, upper)
    axs.set_ylim(0, upper)
    axs.plot(span, span, "-r")
    axs.set_xticks(np.linspace(0, upper, 5))
    axs.set_yticks(np.linspace(0, upper, 5))
    axs.set_aspect("equal")

    # Q-Q points: ordered x values against ordered y values, taken from the
    # second element of probplot's first return value.
    x_ordered = probplot(x, dist="norm")[0][1]
    y_ordered = probplot(y, dist="norm")[0][1]
    axs.plot(
        x_ordered, y_ordered, "o", markersize=0.5, color="k", label="Q-Q plot"
    )

    # Error diagnostics shown in a text box in the upper-left corner
    stats_lines = [
        r"RMSE = %.2f" % (rmse(xs_by_density, ys_by_density),),
        r"mse = %.2f" % (mean_squared_error(xs_by_density, ys_by_density),),
        r"BIAS = %.2f" % (bias(xs_by_density, ys_by_density),),
        r"SI = %.2f" % (si(xs_by_density, ys_by_density),),
    ]
    box_style = {
        "boxstyle": "round",
        "facecolor": "w",
        "edgecolor": "grey",
        "linewidth": 0.8,
        "alpha": 0.5,
    }
    axs.text(
        0.05,
        0.95,
        "\n".join(stats_lines),
        transform=axs.transAxes,
        fontsize=9,
        verticalalignment="top",
        bbox=box_style,
    )


def plot_scatters_in_triangle(
dataframes: List[pd.DataFrame],
data_colors: List[str] = default_colors,
data_colors: Optional[List[str]] = None,
**kwargs,
) -> Tuple[Figure, Axes]:
) -> Tuple[Figure, np.ndarray]:
"""
Plot a scatter plot of the dataframes with axes in a triangle.
Plot scatter plots of the dataframes with axes in a triangle arrangement.

Parameters
----------
dataframes : List[pd.DataFrame]
List of dataframes to plot.
data_colors : List[str], optional
List of colors for the dataframes.
List of dataframes to plot. Each dataframe should contain the same columns.
data_colors : Optional[List[str]], optional
List of colors for the dataframes. If None, uses default_colors.
**kwargs : dict, optional
Keyword arguments for the scatter plot. Will be passed to the
DefaultStaticPlotting.plot_scatter method, which is the same
as the one in matplotlib.pyplot.scatter.
For example, to change the marker size, you can use:
``plot_scatters_in_triangle(dataframes, s=10)``
Additional keyword arguments for the scatter plot. These will be passed to
matplotlib.pyplot.scatter. Common parameters include:
- s : float, marker size
- alpha : float, transparency
- marker : str, marker style

Returns
-------
fig : Figure
Figure object.
axes : Axes
Axes object.
Tuple[Figure, np.ndarray]
A tuple containing:
- Figure object
- 2D array of Axes objects

Raises
------
ValueError
If the variables in the first dataframe are not present in all other dataframes.
"""

if data_colors is None:
data_colors = default_colors

# Get the number and names of variables from the first dataframe
variables_names = list(dataframes[0].columns)
num_variables = len(variables_names)
Expand Down
Loading