Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions simpeg/directives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
PairedBetaEstimate_ByEig,
PairedBetaSchedule,
MovingAndMultiTargetStopping,
ScaleMaximumDerivatives,
)

### Deprecated class
Expand Down
36 changes: 35 additions & 1 deletion simpeg/directives/_sim_directives.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from ..regularization import BaseSimilarityMeasure
from ..regularization import BaseSimilarityMeasure, CrossGradient
from ..utils import eigenvalue_by_power_iteration
from ..optimization import IterationPrinters, StoppingCriteria
from ._directives import InversionDirective, SaveOutputEveryIteration
Expand Down Expand Up @@ -377,3 +377,37 @@ def endIter(self):
/ np.linalg.norm(self.opt.x_last),
)
self.opt.stopNextIteration = True


class ScaleMaximumDerivatives(InversionDirective):
"""
Directive for scaling the components of the regularization
based on the maximum theoretical derivatives of model gradients.
"""

def __init__(self, cross_gradient: CrossGradient, **kwargs):
if not isinstance(cross_gradient, CrossGradient):
raise TypeError("cross_gradient must be a CrossGradient regularization.")

self.cross_gradient = cross_gradient

super().__init__(**kwargs)

def endIter(self):
"""
End of iteration update.
"""
max_deriv = []
for _, wire in self.cross_gradient.wire_map.maps:
component = wire * self.opt.xc
max_deriv.append(
(component.max() - component.min())
/ self.cross_gradient.regularization_mesh.base_length**2.0
)

scale = np.min([max_deriv[0] ** 2.0, max_deriv[1] ** 2.0, np.prod(max_deriv)])
if scale == 0:
return

values = np.full(self.cross_gradient.regularization_mesh.n_cells, scale**-1)
self.cross_gradient.set_weights(max_deriv=values)
Comment thread
domfournier marked this conversation as resolved.
87 changes: 57 additions & 30 deletions simpeg/regularization/cross_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import scipy.sparse as sp

from .base import BaseSimilarityMeasure
from ..utils import validate_type, coterminal
from ..utils import validate_type, coterminal, sdiag


###############################################################################
Expand Down Expand Up @@ -131,7 +131,7 @@ class CrossGradient(BaseSimilarityMeasure):
"""

def __init__(
self, mesh, wire_map, approx_hessian=True, units=["metric", "metric"], **kwargs
self, mesh, wire_map, approx_hessian=True, units=("metric", "metric"), **kwargs
):

Comment thread
domfournier marked this conversation as resolved.
super().__init__(mesh, wire_map=wire_map, units=units, **kwargs)
Expand All @@ -142,7 +142,29 @@ def __init__(
if regmesh.mesh.dim not in (2, 3):
raise ValueError("Cross-Gradient is only defined for 2D or 3D")
self._G = regmesh.cell_gradient
self._Av = sp.diags(np.sqrt(regmesh.vol)) * regmesh.average_face_to_cell
self.set_weights(volume=self.regularization_mesh.vol)

@property
def W(self) -> sp.csr_matrix:
r"""Weighting matrix.

Returns the weighting matrix for the objective function. To see how the
weighting matrix is constructed, see the *Notes* section for the
:class:`SmoothnessFirstOrder` regularization class.

Returns
-------
scipy.sparse.csr_matrix
The weighting matrix applied in the objective function.
"""
if getattr(self, "_W", None) is None:
weights = 1.0
for values in self._weights.values():
weights *= values
self._W = (
sdiag(weights**0.5) * self.regularization_mesh.average_face_to_cell
)
return self._W

@property
def approx_hessian(self):
Expand Down Expand Up @@ -236,7 +258,7 @@ def _model_gradients(self, models):
"""
gradients = []

for unit, (name, wire) in zip(self.units, self.wire_map.maps):
for unit, (_, wire) in zip(self.units, self.wire_map.maps):
model = wire * models
if unit == "radian":
gradient = []
Expand Down Expand Up @@ -276,12 +298,11 @@ def __call__(self, model):
float
The regularization function evaluated for the model provided.
"""

Av = self._Av

g_m1, g_m2 = self._model_gradients(model)

return np.sum((Av @ g_m1**2) * (Av @ g_m2**2) - (Av @ (g_m1 * g_m2)) ** 2)
return np.sum(
(self.W @ g_m1**2) * (self.W @ g_m2**2) - (self.W @ (g_m1 * g_m2)) ** 2
)

def deriv(self, model):
r"""Gradient of the regularization function evaluated for the model provided.
Expand Down Expand Up @@ -311,17 +332,16 @@ def deriv(self, model):
(n_param, ) numpy.ndarray
Gradient of the regularization function evaluated for the model provided.
"""
Av = self._Av
G = self._G
g_m1, g_m2 = self._model_gradients(model)

return self.wire_map.deriv(model).T * (
2
* np.r_[
(((Av @ g_m2**2) @ Av) * g_m1) @ G
- (((Av @ (g_m1 * g_m2)) @ Av) * g_m2) @ G,
(((Av @ g_m1**2) @ Av) * g_m2) @ G
- (((Av @ (g_m1 * g_m2)) @ Av) * g_m1) @ G,
(((self.W @ g_m2**2) @ self.W) * g_m1) @ G
- (((self.W @ (g_m1 * g_m2)) @ self.W) * g_m2) @ G,
(((self.W @ g_m1**2) @ self.W) * g_m2) @ G
- (((self.W @ (g_m1 * g_m2)) @ self.W) * g_m1) @ G,
]
) # factor of 2 from derviative of | grad m1 x grad m2 | ^2

Expand Down Expand Up @@ -366,27 +386,26 @@ def deriv2(self, model, v=None):
for the models provided is returned. If *v* is not ``None``,
the Hessian multiplied by the vector provided is returned.
"""
Av = self._Av
G = self._G

g_m1, g_m2 = self._model_gradients(model)

d11_mid = Av.T @ (Av @ g_m2**2)
d12_mid = -(Av.T @ (Av @ (g_m1 * g_m2)))
d22_mid = Av.T @ (Av @ g_m1**2)
d11_mid = self.W.T @ (self.W @ g_m2**2)
d12_mid = -(self.W.T @ (self.W @ (g_m1 * g_m2)))
d22_mid = self.W.T @ (self.W @ g_m1**2)

if v is None:
D11_mid = sp.diags(d11_mid)
D12_mid = sp.diags(d12_mid)
D22_mid = sp.diags(d22_mid)
if not self.approx_hessian:
D11_mid = D11_mid - sp.diags(g_m2) @ Av.T @ Av @ sp.diags(g_m2)
D11_mid = D11_mid - sp.diags(g_m2) @ self.W.T @ self.W @ sp.diags(g_m2)
D12_mid = (
D12_mid
+ 2 * sp.diags(g_m1) @ Av.T @ Av @ sp.diags(g_m2)
- sp.diags(g_m2) @ Av.T @ Av @ sp.diags(g_m1)
+ 2 * sp.diags(g_m1) @ self.W.T @ self.W @ sp.diags(g_m2)
- sp.diags(g_m2) @ self.W.T @ self.W @ sp.diags(g_m1)
)
D22_mid = D22_mid - sp.diags(g_m1) @ Av.T @ Av @ sp.diags(g_m1)
D22_mid = D22_mid - sp.diags(g_m1) @ self.W.T @ self.W @ sp.diags(g_m1)
D11 = G.T @ D11_mid @ G
D12 = G.T @ D12_mid @ G
D22 = G.T @ D22_mid @ G
Expand All @@ -407,22 +426,26 @@ def deriv2(self, model, v=None):
p2 = G.T @ (d12_mid * Gv1 + d22_mid * Gv2)
if not self.approx_hessian:
p1 += G.T @ (
-g_m2 * (Av.T @ (Av @ (g_m2 * Gv1))) # d11*v1 full addition
+ 2 * g_m1 * (Av.T @ (Av @ (g_m2 * Gv2))) # d12*v2 full addition
- g_m2 * (Av.T @ (Av @ (g_m1 * Gv2))) # d12*v2 continued
-g_m2 * (self.W.T @ (self.W @ (g_m2 * Gv1))) # d11*v1 full addition
+ 2
* g_m1
* (self.W.T @ (self.W @ (g_m2 * Gv2))) # d12*v2 full addition
- g_m2 * (self.W.T @ (self.W @ (g_m1 * Gv2))) # d12*v2 continued
)

p2 += G.T @ (
-g_m1 * (Av.T @ (Av @ (g_m1 * Gv2))) # d22*v2 full addition
+ 2 * g_m2 * (Av.T @ (Av @ (g_m1 * Gv1))) # d12.T*v1 full addition
- g_m1 * (Av.T @ (Av @ (g_m2 * Gv1))) # d12.T*v1 fcontinued
-g_m1 * (self.W.T @ (self.W @ (g_m1 * Gv2))) # d22*v2 full addition
+ 2
* g_m2
* (self.W.T @ (self.W @ (g_m1 * Gv1))) # d12.T*v1 full addition
- g_m1 * (self.W.T @ (self.W @ (g_m2 * Gv1))) # d12.T*v1 fcontinued
)
return (
2 * self.wire_map.deriv(model).T * np.r_[p1, p2]
) # factor of 2 from derviative of | grad m1 x grad m2 | ^2

@property
def units(self) -> list[str] | None:
def units(self) -> tuple[str]:
"""Units for the model parameters.

Some regularization classes behave differently depending on the units; e.g. 'radian'.
Expand All @@ -435,14 +458,18 @@ def units(self) -> list[str] | None:
return self._units

@units.setter
def units(self, units: list[str] | None):
def units(self, units: tuple[str] | None):
if (
units is not None
and not isinstance(units, list)
and not isinstance(units, list | tuple)
and not all(isinstance(u, str) for u in units)
):
raise TypeError(
f"'units' must be None or a list of str. "
f"Value of type {type(units)} provided."
)

if units is None:
units = ("metric", "metric")

self._units = units
Loading