Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions demos/full_waveform_inversion/full_waveform_inversion.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -257,39 +257,51 @@ Next, the FWI problem is executed with the following steps:
:align: center


To have the step 4, we need first to tape the forward problem. That is done by calling::
To have the step 4, we need first to tape the forward problem.
That is done by calling :func:`~.pyadjoint.continue_annotation`.

from firedrake.adjoint import *
continue_annotation()
get_working_tape().progress_bar = ProgressBar

**Steps 2-3**: Solve the wave equation and compute the functional::
**Steps 2-3**: Solve the wave equation and compute the functional.
We create a ``ReducedFunctional`` for each source, which in our
case means one per ensemble member. By creating a ``ReducedFunctional``
per component that we are parallelising over (i.e. per source),
rather than one per ensemble member, we can change the ensemble
parallel partition with minimal changes to the code.::

from firedrake.adjoint import *
Comment on lines +264 to +271
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is better 🙃


f = Cofunction(V.dual()) # Wave equation forcing term.
solver, u_np1, u_n, u_nm1 = wave_equation_solver(c_guess, f, dt, V)
interpolate_receivers = interpolate(u_np1, V_r)
J_val = 0.0
for step in range(total_steps):
f.assign(ricker_wavelet(step * dt, frequency_peak) * q_s)
solver.solve()
u_nm1.assign(u_n)
u_n.assign(u_np1)
guess_receiver = assemble(interpolate_receivers)
misfit = guess_receiver - true_data_receivers[step]
J_val += 0.5 * assemble(inner(misfit, misfit) * dx)

We now instantiate :class:`~.EnsembleReducedFunctional`::

J_hat = EnsembleReducedFunctional(J_val,
Control(c_guess, riesz_map="l2"),
my_ensemble)
continue_annotation()
J_val = 0.0
with set_working_tape() as tape:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is with set_working_tape() as tape: better here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes it really clear which bit of the code is being recorded on that tape, and it makes sure that the ReducedFunctional for each $J_i$ has a different tape.

for step in range(total_steps):
f.assign(ricker_wavelet(step * dt, frequency_peak) * q_s)
solver.solve()
u_nm1.assign(u_n)
u_n.assign(u_np1)
guess_receiver = assemble(interpolate_receivers)
misfit = guess_receiver - true_data_receivers[step]
J_val += 0.5 * assemble(inner(misfit, misfit) * dx)

control = Control(c_guess)
Jhat_local = ReducedFunctional(J_val, control, tape=tape)
tape.progress_bar = ProgressBar
pause_annotation()

We now instantiate :class:`~.adjoint.ensemble_reduced_functional.EnsembleReducedFunctional`
with the local ``ReducedFunctional`` for each source::

J_hat = EnsembleReducedFunctional(J_val, control, Jhat_local, my_ensemble)

which enables us to recompute :math:`J` and its gradient :math:`\nabla_{\mathtt{c\_guess}} J`,
where each :math:`J_s` and its gradient :math:`\nabla_{\mathtt{c\_guess}} J_s` are computed in parallel
based on the ``my_ensemble`` configuration.


**Steps 4-6**: The instance of the :class:`~.EnsembleReducedFunctional`, named ``J_hat``,
**Steps 4-6**: The instance of the :class:`~.adjoint.ensemble_reduced_functional.EnsembleReducedFunctional`, named ``J_hat``,
is then passed as an argument to the ``minimize`` function. The default ``minimize`` function
uses ``scipy.minimize``, and wraps the ``ReducedFunctional`` in a ``ReducedFunctionalNumPy``
that handles transferring data between Firedrake and numpy data structures. However, because
Expand Down
6 changes: 5 additions & 1 deletion firedrake/adjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
from firedrake.adjoint.ufl_constraints import ( # noqa: F401
UFLInequalityConstraint, UFLEqualityConstraint
)
from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional # noqa F401
from firedrake.adjoint.ensemble_adjvec import EnsembleAdjVec # noqa F401
from firedrake.adjoint.ensemble_reduced_functional import ( # noqa F401
EnsembleBcastReducedFunctional, EnsembleReduceReducedFunctional,
EnsembleTransformReducedFunctional, EnsembleAllgatherReducedFunctional,
EnsembleReducedFunctional)
from firedrake.adjoint.transformed_functional import L2RieszMap, L2TransformedFunctional # noqa: F401
from firedrake.adjoint.covariance_operator import ( # noqa F401
WhiteNoiseGenerator, AutoregressiveCovariance,
Expand Down
121 changes: 121 additions & 0 deletions firedrake/adjoint/ensemble_adjvec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from functools import cached_property
from pyadjoint.overloaded_type import OverloadedType
from pyadjoint.adjfloat import AdjFloat
from firedrake.ensemble import Ensemble
from firedrake.adjoint_utils.checkpointing import disk_checkpointing


class EnsembleAdjVec(OverloadedType):
    """
    A vector of :class:`pyadjoint.AdjFloat` distributed
    over an :class:`.Ensemble`.

    Analogous to the :class:`.EnsembleFunction` and
    :class:`.EnsembleCofunction` types but for :class:`~pyadjoint.AdjFloat`.

    Implements basic :class:`pyadjoint.OverloadedType` functionality
    to be used as a :class:`pyadjoint.Control` or functional for the
    :class:`~.ensemble_reduced_functional.EnsembleReducedFunctional` types.

    Parameters
    ----------
    subvec :
        The local part of the vector.
    ensemble :
        The :class:`.Ensemble` communicator.

    See Also
    --------
    :class:`~.Ensemble`
    :class:`~.EnsembleFunction`
    :class:`~.EnsembleCofunction`
    :class:`~.EnsembleReducedFunctional`
    """

    def __init__(self, subvec: list[AdjFloat], ensemble: Ensemble):
        if not isinstance(ensemble, Ensemble):
            raise TypeError(
                f"EnsembleAdjVec needs an Ensemble, not a {type(ensemble).__name__}")
        if not all(isinstance(v, (AdjFloat, float)) for v in subvec):
            raise TypeError(
                f"EnsembleAdjVec must be instantiated with a list of AdjFloats, not {subvec}")
        # Copy into fresh AdjFloats so this vector never aliases the
        # caller's values (and so plain floats are promoted).
        self._subvec = [AdjFloat(x) for x in subvec]
        self.ensemble = ensemble
        OverloadedType.__init__(self)

    @property
    def subvec(self) -> list[AdjFloat]:
        """The part of the vector on the local spatial comm."""
        return self._subvec

    @cached_property
    def local_size(self) -> int:
        """The length of the part of the vector on the local spatial comm."""
        return len(self._subvec)

    @cached_property
    def global_size(self) -> int:
        """The global length of the vector."""
        # Sum the local lengths over the ensemble members via the raw
        # ensemble communicator (as _ad_dot does): Ensemble.allreduce
        # operates on Functions, not plain Python ints.
        return self.ensemble.ensemble_comm.allreduce(self.local_size)

    def _ad_init_zero(self, dual: bool = False) -> "EnsembleAdjVec":
        """Return a zero vector with the same layout (dual layout if requested)."""
        return type(self)(
            [v._ad_init_zero(dual=dual) for v in self.subvec],
            ensemble=self.ensemble)

    def _ad_dot(self, other: OverloadedType) -> float:
        """Return the global inner product with ``other``."""
        # Local contribution first, then sum across ensemble members.
        local_dot = sum(s._ad_dot(o)
                        for s, o in zip(self.subvec, other.subvec))
        return self.ensemble.ensemble_comm.allreduce(local_dot)

    def _ad_add(self, other) -> "EnsembleAdjVec":
        """Return the componentwise sum ``self + other`` as a new vector."""
        return type(self)(
            [s._ad_add(o) for s, o in zip(self.subvec, other.subvec)],
            ensemble=self.ensemble)

    def _ad_mul(self, other) -> "EnsembleAdjVec":
        """Return ``self*other`` (scalar or componentwise) as a new vector."""
        return type(self)(
            [s._ad_mul(o) for s, o in zip(self.subvec,
                                          self._maybe_scalar(other))],
            ensemble=self.ensemble)

    def _ad_iadd(self, other) -> "EnsembleAdjVec":
        """Add ``other`` into this vector componentwise and return self."""
        # AdjFloat is float-backed and immutable, so _ad_iadd returns a
        # new object rather than mutating in place - capture the results.
        self._subvec = [s._ad_iadd(o)
                        for s, o in zip(self._subvec, other.subvec)]
        return self

    def _ad_imul(self, other) -> "EnsembleAdjVec":
        """Scale this vector by ``other`` componentwise and return self."""
        # Capture the returned values for the same reason as _ad_iadd.
        self._subvec = [s._ad_imul(o)
                        for s, o in zip(self._subvec, self._maybe_scalar(other))]
        return self

    def _maybe_scalar(self, val):
        """Broadcast a plain scalar over the local components.

        An :class:`EnsembleAdjVec` passes through as its local components.
        """
        if isinstance(val, EnsembleAdjVec):
            return val.subvec
        return [val for _ in self.subvec]

    def _ad_copy(self) -> "EnsembleAdjVec":
        """Return a deep copy of this vector."""
        return type(self)(
            [v._ad_copy() for v in self.subvec],
            ensemble=self.ensemble)

    def _ad_convert_riesz(self, value, riesz_map=None) -> "EnsembleAdjVec":
        """Apply the Riesz map componentwise and return a new vector."""
        return type(self)(
            [s._ad_convert_riesz(v, riesz_map=riesz_map)
             for s, v in zip(self.subvec, self._maybe_scalar(value))],
            ensemble=self.ensemble)

    def _ad_create_checkpoint(self):
        """Return an in-memory checkpoint (disk checkpointing unsupported)."""
        if disk_checkpointing():
            raise NotImplementedError(
                f"Disk checkpointing not implemented for {type(self).__name__}")
        return self._ad_copy()

    def _ad_restore_at_checkpoint(self, checkpoint):
        """Restore from an in-memory checkpoint created by _ad_create_checkpoint."""
        if type(checkpoint) is type(self):
            return checkpoint
        raise NotImplementedError(
            f"Disk checkpointing not implemented for {type(self).__name__}")
Loading
Loading