Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 93 additions & 42 deletions ezyrb/plugin/scaler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" Module for Scaler plugin """
"""Module for Scaler plugin"""

from .plugin import Plugin

Expand All @@ -19,9 +19,9 @@ class DatabaseScaler(Plugin):
applied at the full order ('full') or at the reduced one ('reduced').
:param {'parameters', 'snapshots'} params: define if the rescaling has to
be applied to the parameters or to the snapshots.

:Example:

>>> from ezyrb import ReducedOrderModel as ROM
>>> from ezyrb import POD, RBF, Database
>>> from ezyrb.plugin import DatabaseScaler
Expand All @@ -33,10 +33,11 @@ class DatabaseScaler(Plugin):
>>> rom = ROM(db, pod, rbf, plugins=[scaler])
>>> rom.fit()
"""

def __init__(self, scaler, mode, target) -> None:
"""
Initialize the DatabaseScaler plugin.

:param scaler: Scaler object with fit, transform, and inverse_transform methods.
:param str mode: 'full' or 'reduced' - where to apply the scaling.
:param str target: 'parameters' or 'snapshots' - what to scale.
Expand All @@ -46,7 +47,7 @@ def __init__(self, scaler, mode, target) -> None:
self.scaler = scaler
self.mode = mode
self.target = target

@property
def target(self):
"""
Expand All @@ -58,7 +59,7 @@ def target(self):

@target.setter
def target(self, new_target):
if new_target not in ['snapshots', 'parameters']:
if new_target not in ["snapshots", "parameters"]:
raise ValueError

self._target = new_target
Expand All @@ -74,102 +75,152 @@ def mode(self):

@mode.setter
def mode(self, new_mode):
if new_mode not in ['full', 'reduced']:
if new_mode not in ["full", "reduced"]:
raise ValueError

self._mode = new_mode

def _select_matrix(self, db):
"""
Helper function to select the proper matrix to rescale.

:param Database db: The database object.
:return: The selected matrix (parameters or snapshots).
"""
return getattr(db, f'{self.target}_matrix')
return getattr(db, f"{self.target}_matrix")

def rom_preprocessing(self, rom):
# =========================================================================
# MODE = 'FULL' - Scaling applied at full order (before reduction or after prediction)
# =========================================================================

def fit_before_reduction(self, rom):
"""
Apply scaling to the reduced database before ROM processing.

Apply scaling before POD reduction when mode='full'.
Scales the full-order database before reduction.

:param ReducedOrderModel rom: The ROM instance.
"""
if self.mode != 'reduced':
if self.mode != "full":
return

db = rom._reduced_database
db = rom.train_full_database

self.scaler.fit(self._select_matrix(db))

if self.target == 'parameters':
if self.target == "parameters":
new_db = type(db)(
self.scaler.transform(self._select_matrix(db)),
db.snapshots_matrix
db.snapshots_matrix,
)
else:
new_db = type(db)(
db.parameters_matrix,
self.scaler.transform(self._select_matrix(db)),
)

rom._reduced_database = new_db
rom.train_full_database = new_db

def fom_preprocessing(self, rom):
if self.mode != 'full':
return
def predict_postprocessing(self, rom):
"""
Inverse transform scaled data after prediction when mode='full'.
Restores original scale to the full-order predicted database.

db = rom._full_database
:param ReducedOrderModel rom: The ROM instance.
"""
if self.mode != "full":
return

self.scaler.fit(self._select_matrix(db))
db = rom.predicted_full_database

if self.target == 'parameters':
if self.target == "parameters":
new_db = type(db)(
self.scaler.transform(self._select_matrix(db)),
db.snapshots_matrix
self.scaler.inverse_transform(self._select_matrix(db)),
db.snapshots_matrix,
)
else:
new_db = type(db)(
db.parameters_matrix,
self.scaler.transform(self._select_matrix(db)),
self.scaler.inverse_transform(self._select_matrix(db)),
)

rom._full_database = new_db
rom.predicted_full_database = new_db

# =========================================================================
# MODE = 'REDUCED' - Scaling applied at reduced order (before/after approximation)
# =========================================================================

def fom_postprocessing(self, rom):
def fit_before_approximation(self, rom):
"""
Apply scaling before approximation training when mode='reduced'.
Scales the reduced database before approximation training.

if self.mode != 'full':
:param ReducedOrderModel rom: The ROM instance.
"""
if self.mode != "reduced":
return

db = rom._full_database
db = rom.train_reduced_database

if self.target == 'parameters':
self.scaler.fit(self._select_matrix(db))

if self.target == "parameters":
new_db = type(db)(
self.scaler.inverse_transform(self._select_matrix(db)),
db.snapshots_matrix
self.scaler.transform(self._select_matrix(db)),
db.snapshots_matrix,
)
else:
new_db = type(db)(
db.parameters_matrix,
self.scaler.inverse_transform(self._select_matrix(db)),
self.scaler.transform(self._select_matrix(db)),
)

rom._full_database = new_db
rom.train_reduced_database = new_db

def rom_postprocessing(self, rom):
if self.mode != 'reduced':
def predict_after_approximation(self, rom):
"""
Inverse transform scaled data after approximation when mode='reduced'.
Restores original scale to the reduced predicted database.

:param ReducedOrderModel rom: The ROM instance.
"""
if self.mode != "reduced":
return

db = rom._reduced_database
db = rom.predict_reduced_database

if self.target == 'parameters':
if self.target == "parameters":
new_db = type(db)(
self.scaler.inverse_transform(self._select_matrix(db)),
db.snapshots_matrix
db.snapshots_matrix,
)
else:
new_db = type(db)(
db.parameters_matrix,
self.scaler.inverse_transform(self._select_matrix(db)),
)

rom._reduced_database = new_db

rom.predict_reduced_database = new_db

# =========================================================================
# PREDICT - Scaling input parameters before approximation (both modes)
# =========================================================================

def predict_before_approximation(self, rom):
"""
Transform (scale) input parameters before approximation if target='parameters'.
This ensures parameters are scaled to match the training data.
Applied during prediction for both 'full' and 'reduced' modes.

:param ReducedOrderModel rom: The ROM instance.
"""
if self.target != "parameters":
return

db = rom.predict_reduced_database
transformed_params = self.scaler.transform(self._select_matrix(db))

# During prediction, snapshots are None (not yet predicted)
# Database constructor handles None snapshots: creates [None] * len(parameters)
new_db = type(db)(transformed_params, None)

rom.predict_reduced_database = new_db
120 changes: 102 additions & 18 deletions tests/test_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,124 @@

from sklearn.preprocessing import StandardScaler, MinMaxScaler

snapshots = np.load('tests/test_datasets/p_snapshots.npy').T
pred_sol_tst = np.load('tests/test_datasets/p_predsol.npy').T
pred_sol_gpr = np.load('tests/test_datasets/p_predsol_gpr.npy').T
param = np.array([[-.5, -.5], [.5, -.5], [.5, .5], [-.5, .5]])
snapshots = np.load("tests/test_datasets/p_snapshots.npy").T
pred_sol_tst = np.load("tests/test_datasets/p_predsol.npy").T
pred_sol_gpr = np.load("tests/test_datasets/p_predsol_gpr.npy").T
param = np.array([[-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]])


def test_constructor():
pod = POD()
import torch

rbf = RBF()
#rbf = ANN([10, 10], function=torch.nn.Softplus(), stop_training=[1000])
# rbf = ANN([10, 10], function=torch.nn.Softplus(), stop_training=[1000])
db = Database(param, snapshots.T)
# rom = ROM(db, pod, rbf, plugins=[DatabaseScaler(StandardScaler(), 'full', 'snapshots')])
rom = ROM(db, pod, rbf, plugins=[
DatabaseScaler(StandardScaler(), 'reduced', 'parameters'),
DatabaseScaler(StandardScaler(), 'reduced', 'snapshots')
])
rom = ROM(
db,
pod,
rbf,
plugins=[
DatabaseScaler(StandardScaler(), "reduced", "parameters"),
DatabaseScaler(StandardScaler(), "reduced", "snapshots"),
],
)
rom.fit()
assert rom is not None


def test_scaler_reduced_snapshots():
"""Test that StandardScaler on reduced snapshots produces mean=0 and std=1"""
pod = POD()
rbf = RBF()
db = Database(param, snapshots.T)
rom = ROM(
db,
pod,
rbf,
plugins=[DatabaseScaler(StandardScaler(), "reduced", "snapshots")],
)
rom.fit()

# Check that the scaled reduced snapshots have mean ≈ 0 and std ≈ 1
scaled_snapshots = rom.train_reduced_database.snapshots_matrix
np.testing.assert_allclose(np.mean(scaled_snapshots, axis=0), 0, atol=1e-7)
np.testing.assert_allclose(np.std(scaled_snapshots, axis=0), 1, atol=1e-7)


def test_scaler_reduced_parameters():
"""Test that StandardScaler on reduced parameters produces mean=0 and std=1"""
pod = POD()
rbf = RBF()
db = Database(param, snapshots.T)
rom = ROM(
db,
pod,
rbf,
plugins=[DatabaseScaler(StandardScaler(), "reduced", "parameters")],
)
rom.fit()

# Check that the scaled reduced parameters have mean ≈ 0 and std ≈ 1
scaled_params = rom.train_reduced_database.parameters_matrix
np.testing.assert_allclose(np.mean(scaled_params, axis=0), 0, atol=1e-7)
np.testing.assert_allclose(np.std(scaled_params, axis=0), 1, atol=1e-7)


def test_scaler_full_snapshots():
"""Test that StandardScaler on full snapshots produces mean=0 and std=1"""
pod = POD()
rbf = RBF()
db = Database(param, snapshots.T)
rom = ROM(
db,
pod,
rbf,
plugins=[DatabaseScaler(StandardScaler(), "full", "snapshots")],
)
rom.fit()



# Check that the scaled full snapshots have mean ≈ 0 and std ≈ 1
scaled_snapshots = rom.train_full_database.snapshots_matrix
np.testing.assert_allclose(np.mean(scaled_snapshots, axis=0), 0, atol=2e-6)
np.testing.assert_allclose(np.std(scaled_snapshots, axis=0), 1, atol=2e-6)


def test_scaler_full_parameters():
"""Test that StandardScaler on full parameters produces mean=0 and std=1"""
pod = POD()
rbf = RBF()
db = Database(param, snapshots.T)
rom = ROM(
db,
pod,
rbf,
plugins=[DatabaseScaler(StandardScaler(), "full", "parameters")],
)
rom.fit()

# Check that the scaled full parameters have mean ≈ 0 and std ≈ 1
scaled_params = rom.train_full_database.parameters_matrix
np.testing.assert_allclose(np.mean(scaled_params, axis=0), 0, atol=2e-6)
np.testing.assert_allclose(np.std(scaled_params, axis=0), 1, atol=2e-6)


def test_values():
pod = POD()
rbf = RBF()
db = Database(param, snapshots.T)
rom = ROM(db, pod, rbf, plugins=[
DatabaseScaler(StandardScaler(), 'reduced', 'snapshots'),
DatabaseScaler(StandardScaler(), 'full', 'parameters')
])
rom = ROM(
db,
pod,
rbf,
plugins=[
DatabaseScaler(StandardScaler(), "reduced", "snapshots"),
DatabaseScaler(StandardScaler(), "full", "parameters"),
],
)
rom.fit()
test_param = param[2]
truth_sol = db.snapshots_matrix[2]
predicted_sol = rom.predict(test_param)[0]
np.testing.assert_allclose(predicted_sol, truth_sol,
rtol=1e-5, atol=1e-5)

np.testing.assert_allclose(predicted_sol, truth_sol, rtol=1e-5, atol=1e-5)