Refactor EnsembleReducedFunctional
#4965
@@ -257,39 +257,51 @@ Next, the FWI problem is executed with the following steps:

     :align: center

-To have the step 4, we need first to tape the forward problem. That is done by calling::
+To perform step 4, we first need to tape the forward problem.
+That is done by calling :func:`~.pyadjoint.continue_annotation`.

-    from firedrake.adjoint import *
-    continue_annotation()
-    get_working_tape().progress_bar = ProgressBar

-**Steps 2-3**: Solve the wave equation and compute the functional::
+**Steps 2-3**: Solve the wave equation and compute the functional.
+We create a ``ReducedFunctional`` for each source, which in our
+case means one per ensemble member. By creating a ``ReducedFunctional``
+per component that we are parallelising over (i.e. per source),
+rather than one per ensemble member, we can change the ensemble
+parallel partition with minimal changes to the code::

+    from firedrake.adjoint import *

     f = Cofunction(V.dual())  # Wave equation forcing term.
     solver, u_np1, u_n, u_nm1 = wave_equation_solver(c_guess, f, dt, V)
     interpolate_receivers = interpolate(u_np1, V_r)
-    J_val = 0.0
-    for step in range(total_steps):
-        f.assign(ricker_wavelet(step * dt, frequency_peak) * q_s)
-        solver.solve()
-        u_nm1.assign(u_n)
-        u_n.assign(u_np1)
-        guess_receiver = assemble(interpolate_receivers)
-        misfit = guess_receiver - true_data_receivers[step]
-        J_val += 0.5 * assemble(inner(misfit, misfit) * dx)

-We now instantiate :class:`~.EnsembleReducedFunctional`::

-    J_hat = EnsembleReducedFunctional(J_val,
-                                      Control(c_guess, riesz_map="l2"),
-                                      my_ensemble)
+    continue_annotation()
+    J_val = 0.0
+    with set_working_tape() as tape:
Contributor: Why is

Member (Author): It makes it really clear which bit of the code is being recorded on that tape.
+        for step in range(total_steps):
+            f.assign(ricker_wavelet(step * dt, frequency_peak) * q_s)
+            solver.solve()
+            u_nm1.assign(u_n)
+            u_n.assign(u_np1)
+            guess_receiver = assemble(interpolate_receivers)
+            misfit = guess_receiver - true_data_receivers[step]
+            J_val += 0.5 * assemble(inner(misfit, misfit) * dx)

+    control = Control(c_guess)
+    Jhat_local = ReducedFunctional(J_val, control, tape=tape)
+    tape.progress_bar = ProgressBar
+    pause_annotation()

+We now instantiate :class:`~.adjoint.ensemble_reduced_functional.EnsembleReducedFunctional`
+with the local ``ReducedFunctional`` for each source::

+    J_hat = EnsembleReducedFunctional(J_val, control, Jhat_local, my_ensemble)

 which enables us to recompute :math:`J` and its gradient :math:`\nabla_{\mathtt{c\_guess}} J`,
 where the :math:`J_s` and its gradients :math:`\nabla_{\mathtt{c\_guess}} J_s` are computed in parallel
 based on the ``my_ensemble`` configuration.

-**Steps 4-6**: The instance of the :class:`~.EnsembleReducedFunctional`, named ``J_hat``,
+**Steps 4-6**: The instance of the :class:`~.adjoint.ensemble_reduced_functional.EnsembleReducedFunctional`, named ``J_hat``,
 is then passed as an argument to the ``minimize`` function. The default ``minimize`` function
 uses ``scipy.minimize``, and wraps the ``ReducedFunctional`` in a ``ReducedFunctionalNumPy``
 that handles transferring data between Firedrake and numpy data structures. However, because
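A minimal, Firedrake-free sketch of the idea above: each ensemble member owns a per-source reduced functional, and the ensemble-level object sums values and gradients across members (in real use that sum is an MPI allreduce over the ensemble communicator). ``LocalReducedFunctional`` and ``FakeEnsembleReducedFunctional`` are hypothetical stand-ins, not the Firedrake API.

```python
class LocalReducedFunctional:
    """Stand-in for one per-source ReducedFunctional: J_s(c) = w_s * c**2."""
    def __init__(self, weight):
        self.weight = weight

    def __call__(self, c):
        return self.weight * c * c

    def derivative(self, c):
        return 2.0 * self.weight * c


class FakeEnsembleReducedFunctional:
    """Sums per-source values and gradients, mimicking the ensemble reduction."""
    def __init__(self, local_rfs):
        # In an MPI run each local RF would live on a different ensemble member
        # and the sums below would be allreduces; serially they are plain sums.
        self.local_rfs = local_rfs

    def __call__(self, c):
        return sum(rf(c) for rf in self.local_rfs)

    def derivative(self, c):
        return sum(rf.derivative(c) for rf in self.local_rfs)


J_hat = FakeEnsembleReducedFunctional(
    [LocalReducedFunctional(w) for w in (1.0, 2.0, 3.0)])
print(J_hat(2.0))             # 24.0 = (1 + 2 + 3) * 2**2
print(J_hat.derivative(2.0))  # 24.0 = 2 * (1 + 2 + 3) * 2
```

Because the per-source functionals are independent objects, redistributing sources across ensemble members changes only how the list is partitioned, not the functional itself.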
@@ -0,0 +1,121 @@
+from functools import cached_property
+from pyadjoint.overloaded_type import OverloadedType
+from pyadjoint.adjfloat import AdjFloat
+from firedrake.ensemble import Ensemble
+from firedrake.adjoint_utils.checkpointing import disk_checkpointing
+
+
+class EnsembleAdjVec(OverloadedType):
Contributor: Why not
| """ | ||
| A vector of :class:`pyadjoint.AdjFloat` distributed | ||
| over an :class:`.Ensemble`. | ||
|
|
||
| Analagous to the :class:`.EnsembleFunction` and | ||
| :class:`.EnsembleCofunction` types but for :class:`~pyadjoint.AdjFloat`. | ||
|
|
||
| Implements basic :class:`pyadjoint.OverloadedType` functionality | ||
| to be used as a :class:`pyadjoint.Control` or functional for the | ||
| :class:`~.ensemble_reduced_functional.EnsembleReducedFunctional` types. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| subvec : | ||
| The local part of the vector. | ||
| ensemble : | ||
| The :class:`.Ensemble` communicator. | ||
|
|
||
| See Also | ||
| -------- | ||
| :class:`~.Ensemble` | ||
| :class:`~.EnsembleFunction` | ||
| :class:`~.EnsembleCofunction` | ||
| :class:`~.EnsembleReducedFunctional` | ||
| """ | ||
|
|
||
| def __init__(self, subvec: list[AdjFloat], ensemble: Ensemble): | ||
| if not isinstance(ensemble, Ensemble): | ||
| raise TypeError( | ||
| f"EnsembleAdjVec needs an Ensemble, not a {type(ensemble).__name__}") | ||
| if not all(isinstance(v, (AdjFloat, float)) for v in subvec): | ||
| raise TypeError( | ||
| f"EnsembleAdjVec must be instantiated with a list of AdjFloats, not {subvec}") | ||
| self._subvec = [AdjFloat(x) for x in subvec] | ||
| self.ensemble = ensemble | ||
| OverloadedType.__init__(self) | ||
|
|
||
| @property | ||
| def subvec(self) -> list[AdjFloat]: | ||
| """The part of the vector on the local spatial comm.""" | ||
| return self._subvec | ||
|
|
||
| @cached_property | ||
| def local_size(self) -> int: | ||
| """The length of the part of the vector on the local spatial comm.""" | ||
| return len(self._subvec) | ||
|
|
||
| @cached_property | ||
| def global_size(self) -> int: | ||
| """The global length of vector.""" | ||
| return self.ensemble.allreduce(self.local_size) | ||
|
|
||
| def _ad_init_zero(self, dual: bool = False) -> "EnsembleAdjVec": | ||
| return type(self)( | ||
| [v._ad_init_zero(dual=dual) for v in self.subvec], | ||
| self.ensemble) | ||
|
|
||
| def _ad_dot(self, other: OverloadedType) -> float: | ||
| local_dot = sum(s._ad_dot(o) | ||
| for s, o in zip(self.subvec, other.subvec)) | ||
| global_dot = self.ensemble.ensemble_comm.allreduce(local_dot) | ||
| return global_dot | ||
|
|
||
| def _ad_add(self, other) -> "EnsembleAdjVec": | ||
| return EnsembleAdjVec( | ||
| [s._ad_add(o) for s, o in zip(self.subvec, other.subvec)], | ||
| ensemble=self.ensemble) | ||
|
|
||
| def _ad_mul(self, other) -> "EnsembleAdjVec": | ||
| return EnsembleAdjVec( | ||
| [s._ad_mul(o) for s, o in zip(self.subvec, | ||
| self._maybe_scalar(other))], | ||
| ensemble=self.ensemble) | ||
|
|
||
| def _ad_iadd(self, other) -> "EnsembleAdjVec": | ||
| for s, o in zip(self.subvec, other.subvec): | ||
| s._ad_iadd(o) | ||
| return self | ||
|
|
||
| def _ad_imul(self, other) -> "EnsembleAdjVec": | ||
| for s, o in zip(self.subvec, self._maybe_scalar(other)): | ||
| s._ad_imul(o) | ||
| return self | ||
|
|
||
| def _maybe_scalar(self, val): | ||
| if isinstance(val, EnsembleAdjVec): | ||
| return val.subvec | ||
| else: | ||
| return [val for _ in self.subvec] | ||
|
|
||
| def _ad_copy(self) -> "EnsembleAdjVec": | ||
| return EnsembleAdjVec( | ||
| [v._ad_copy() for v in self.subvec], | ||
| ensemble=self.ensemble) | ||
|
|
||
| def _ad_convert_riesz(self, value, riesz_map=None) -> "EnsembleAdjVec": | ||
| return EnsembleAdjVec( | ||
| [s._ad_convert_riesz(v, riesz_map=riesz_map) | ||
| for s, v in zip(self.subvec, self._maybe_scalar(value))], | ||
| ensemble=self.ensemble) | ||
|
|
||
| def _ad_create_checkpoint(self): | ||
| if disk_checkpointing(): | ||
| raise NotImplementedError( | ||
| f"Disk checkpointing not implemented for {type(self).__name__}") | ||
| else: | ||
| return self._ad_copy() | ||
|
|
||
| def _ad_restore_at_checkpoint(self, checkpoint): | ||
| if type(checkpoint) is type(self): | ||
| return checkpoint | ||
| raise NotImplementedError( | ||
| f"Disk checkpointing not implemented for {type(self).__name__}") | ||
Reviewer: This is better 🙃
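The core reduction pattern in ``EnsembleAdjVec._ad_dot`` (a local sum followed by an allreduce over the ensemble communicator) can be sketched without MPI. ``FakeComm`` and ``AdjVecSketch`` below are illustrative stand-ins, not Firedrake types; for a single ensemble member the allreduce is just the identity.

```python
class FakeComm:
    """Single-process stand-in for ensemble.ensemble_comm: allreduce over
    one member is the identity."""
    def allreduce(self, value):
        return value


class AdjVecSketch:
    """Toy ensemble-distributed vector: a local sub-vector plus a comm."""
    def __init__(self, subvec, comm):
        self.subvec = list(subvec)
        self.comm = comm

    def dot(self, other):
        # Local contribution on this ensemble member...
        local = sum(a * b for a, b in zip(self.subvec, other.subvec))
        # ...summed across members by the ensemble communicator.
        return self.comm.allreduce(local)


u = AdjVecSketch([1.0, 2.0], FakeComm())
v = AdjVecSketch([3.0, 4.0], FakeComm())
print(u.dot(v))  # 11.0
```

With several ensemble members, each would compute its own ``local`` over its slice of the vector and the allreduce would produce the same global dot product on every member.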