Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ def __init__(self, A, *, P=None, **kwargs):

test, trial = A.arguments()
self.x = Function(trial.function_space())
self.b = Cofunction(test.function_space().dual())

problem = LinearVariationalProblem(A, self.b, self.x, aP=P,
problem = LinearVariationalProblem(A, 0, self.x, aP=P,
form_compiler_parameters=A.form_compiler_parameters,
constant_jacobian=True)
super().__init__(problem, **kwargs)
Expand Down Expand Up @@ -76,12 +75,10 @@ def solve(self, x, b):

# When solving `Ax = b`, with A: V x U -> R, or equivalently A: V -> U*,
# we need to make sure that x and b belong to V and U*, respectively.
if x.function_space() != self.x.function_space():
raise ValueError(f"x must be a Function in {self.x.function_space()}.")
if b.function_space() != self.b.function_space():
raise ValueError(f"b must be a Cofunction in {self.b.function_space()}.")
test, trial = self.A.arguments()
if x.function_space() != test.function_space():
raise ValueError(f"x must be a Function in {test.function_space()}.")
if b.function_space() != trial.function_space().dual():
raise ValueError(f"b must be a Cofunction in {trial.function_space().dual()}.")

self.x.assign(x)
self.b.assign(b)
super().solve()
x.assign(self.x)
super().solve(x=x, b=b)
47 changes: 39 additions & 8 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import ufl
from typing import Tuple
from itertools import chain
from contextlib import ExitStack
from types import MappingProxyType
from petsctools import OptionsManager, flatten_parameters

from firedrake import dmhooks, slate, solving, solving_utils, ufl_expr, utils
from firedrake.petsc import PETSc, DEFAULT_KSP_PARAMETERS, DEFAULT_SNES_PARAMETERS
from firedrake.cofunction import Cofunction
from firedrake.function import Function
from firedrake.interpolation import interpolate
from firedrake.matrix import MatrixBase
Expand Down Expand Up @@ -302,6 +304,7 @@ def update_diffusivity(current_solution):

self._ctx = ctx
self._work = problem.u_restrict.dof_dset.layout_vec.duplicate()
self._work_cofunction = Function(problem.u_restrict.function_space().dual())
self.snes.setDM(problem.dm)

ctx.set_function(self.snes)
Expand Down Expand Up @@ -340,17 +343,30 @@ def set_transfer_manager(self, manager):

@PETSc.Log.EventDecorator()
@NonlinearVariationalSolverMixin._ad_annotate_solve
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The adjoint will certainly fail if we tape solver.solve(x=x, b=b), this needs more work

def solve(self, bounds=None):
def solve(self,
bounds: Tuple[Function, Function] | None = None,
x: Function | None = None,
b: Cofunction | None = None):
r"""Solve the variational problem.

:arg bounds: Optional bounds on the solution (lower, upper).
``lower`` and ``upper`` must both be
:class:`~.Function`\s.
Parameters
----------

.. note::
bounds
Optional bounds on the solution (lower, upper).
x
Optional solution buffer with the initial guess on
entry and the converged solution on exit.
b
Optional RHS source term. This enables solving
F == b.

Notes
-----

If bounds are provided the ``snes_type`` must be set to
``vinewtonssls`` or ``vinewtonrsls``.

"""
# Make sure the DM has this solver's callback functions
self._ctx.set_function(self.snes)
Expand All @@ -372,8 +388,17 @@ def solve(self, bounds=None):
problem_dms.append(dm)
problem_dms.append(solution_dm)

if problem.restrict:
# Transfer the initial guess into the RestrictedFunctionSpace
# Transfer the rhs into the RestrictedFunctionSpace
if b is not None:
b = self._work_cofunction.assign(b)
# Zero bc nodes on the rhs
for bc in problem.dirichlet_bcs():
bc.zero(b)

# Transfer the initial guess into the RestrictedFunctionSpace
if x is not None:
problem.u_restrict.assign(x)
elif problem.restrict:
problem.u_restrict.assign(problem.u)

if self._ctx.pre_apply_bcs:
Expand All @@ -395,11 +420,17 @@ def solve(self, bounds=None):
[dmhooks.add_hooks(dm, self, appctx=self._ctx) for dm in problem_dms],
self._transfer_operators):
stack.enter_context(ctx)
self.snes.solve(None, work)
if b is not None:
with b.dat.vec as bvec:
self.snes.solve(bvec, work)
else:
self.snes.solve(None, work)
work.copy(u)
self._setup = True
if problem.restrict:
problem.u.assign(problem.u_restrict)
if x is not None:
x.assign(problem.u)
solving_utils.check_snes_convergence(self.snes)

# Grab the comm associated with the `_problem` and call PETSc's garbage cleanup routine
Expand Down
Loading