Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tempfile
import abc

from functools import cached_property, partial
from functools import cached_property, partial, singledispatchmethod
from typing import Hashable, Literal, Callable, Iterable
from dataclasses import asdict, dataclass
from numbers import Number
Expand All @@ -15,6 +15,7 @@
from ufl.constantvalue import zero, as_ufl
from ufl.form import ZeroBaseForm, BaseForm
from ufl.core.interpolate import Interpolate as UFLInterpolate
from ufl.corealg.dag_traverser import DAGTraverser

from pyop2 import op2
from pyop2.caching import memory_and_disk_cache
Expand Down Expand Up @@ -100,6 +101,25 @@ class InterpolateOptions:
default_missing_val: float | None = None


class NestedInterpolateLowerer(DAGTraverser):
"""Lower nested interpolate nodes to assembled coefficients."""

@singledispatchmethod
def process(self, o: Expr) -> Expr:
return super().process(o)

@process.register(Expr)
@process.register(BaseForm)
def _(self, o: Expr | BaseForm) -> Expr | BaseForm:
return self.reuse_if_untouched(o)

@process.register(UFLInterpolate)
@DAGTraverser.postorder
def _(self, o: UFLInterpolate, operand: Expr) -> Expr:
from firedrake.assemble import assemble
return as_ufl(assemble(o._ufl_expr_reconstruct_(operand)))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't BaseFormAssembler the intended way of traversing the DAG for assemble?

Comment on lines +116 to +120
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of eagerly assembling each nested Interpolate node, I suggest we replace it with a placeholder Function and have the DAG traverser built a mapping {Function: Interpolator}. This way we delay populating the Functions with values until the outer interpolation has to be assembled . This ensures that the parloops corresponding to the inner interpolations always get on the stack and each is executed before the immediate outer interpolation.

Suggested change
@process.register(UFLInterpolate)
@DAGTraverser.postorder
def _(self, o: UFLInterpolate, operand: Expr) -> Expr:
from firedrake.assemble import assemble
return as_ufl(assemble(o._ufl_expr_reconstruct_(operand)))
@process.register(UFLInterpolate)
@DAGTraverser.postorder
def _(self, o: UFLInterpolate, operand: Expr) -> Expr:
inner_node = o._ufl_expr_reconstruct_(operand)
fn = Function(o.ufl_function_space())
self.subs[fn] = get_interpolator(inner_node)
return as_ufl(fn)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mapping could then be used in a dedicated Interpolator sub-class that explicitly chains the composition of callables:

class CompositeInterpolator(Interpolator):
    """
    An interpolator for expressions containing nested interpolations 
    (possibly defined across different meshes).
    """

    def __init__(self, outer_expr, subs):
        super().__init__(outer_expr)
        self.subs = subs # {Function: Interpolator} returned by NestedInterpolateLowerer
        self._outer = get_interpolator(outer_expr)

    def _get_callable(self, tensor=None, bcs=None, **kwargs):
        inner_callables = [
            interp._get_callable(tensor=fn)
            for fn, interp in self.subs.items()
        ]
        outer_callable = self._outer._get_callable(tensor=tensor, bcs=bcs)

        def callable():
            for c in inner_callables:
                c()
            return outer_callable()

        return callable

With this we don't need to go through the entire assemble dispatch every single time. Instead, simply executing the callable returned by CompositeInterpolator suffices to ensure everything that's nested gets re-evaluated properly.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think this logic should all go in assemble, as it handles compositions of BaseForm more generically.

The current issue is that symbolic Interpolate objects get reconstructed in BaseFormAssembler as we processes them, but the resulting numerical Interpolator does not get cached on the original expression, but on the processed one, which is then thrown away.

See this PR where we added caching for the Interpolator #4827

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But wouldn't assemble preprocess the expression every time it is called? Wouldn't the symbolic processing introduce overhead as opposed to targeting the callables that handle the execution directly?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think this logic should all go in assemble, as it handles compositions of BaseForm more generically.

I agree with this.

But wouldn't assemble preprocess the expression every time it is called?

I don't know. But if so we should look to cache it. Why not cache the preprocessed expression on the input expression? Then you can cache interpolators on the preprocessed expression that end up being persistent.

Copy link
Copy Markdown
Contributor

@pbrubeck pbrubeck May 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseFormAssembler already deals with both primal Expr/BaseForm, and dual Form/BaseForm nodes. The current BaseFormAssembler implementation could be improved if we turn it into a DAGTraverser that we could dispatch on any ufl type, but that's a separate issue.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessarily a separate issue? This discussion makes it seem a little like infrastructure surgery is required in order to enable us to handle nested interpolates without tremendous hackery.

Copy link
Copy Markdown
Contributor

@achanbour achanbour May 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but the resulting numerical Interpolator does not get cached on the original expression, but on the processed one, which is then thrown away.

I am not sure I understand here. Is the processed expression thrown away while the original one gets retained? If so, then making the processed expression persistent would solve the issue as Connor suggests.

Is there a reason why the processed expression gets thrown away?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assemble will just return the numerical Function/Cofunction. Any intermidiate symbolic expression used to arrive to the numerical result will not be returned by assemble. The right thing would be to cache them, or cache the assembler/interpolator on the original symbolic expression.

Copy link
Copy Markdown
Contributor

@pbrubeck pbrubeck May 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessarily a separate issue?

The bug can be fixed without restructuring the entire class. And the class restructuring might be done in a way that preserves the bug. The issues are related, but can be dealt with separately.



class Interpolate(UFLInterpolate):

def __init__(self, expr: Expr, V: WithGeometry | BaseForm, **kwargs):
Expand Down Expand Up @@ -161,14 +181,18 @@ def _interpolator(self):
An appropriate :class:`Interpolator` subclass for this
interpolation expression.
"""
operand, = self.ufl_operands
# Check for nested Interpolates first
lowered_operand = NestedInterpolateLowerer()(operand)
if lowered_operand is not operand:
return get_interpolator(self._ufl_expr_reconstruct_(lowered_operand))
Comment on lines +186 to +188
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this approach freeze the nested expression? What if we change the numeric values between different calls to assemble on the same expression?

Comment on lines +185 to +188
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With my earlier suggestion on using CompositeInterpolator:

Suggested change
# Check for nested Interpolates first
lowered_operand = NestedInterpolateLowerer()(operand)
if lowered_operand is not operand:
return get_interpolator(self._ufl_expr_reconstruct_(lowered_operand))
# Check for nested Interpolates first
lowerer = NestedInterpolateLowerer()
lowered_operand = lowerer(operand)
if lowered_operand is not operand:
return CompositeInterpolator(self._ufl_expr_reconstruct_(lowered_operand), lowerer.subs)


arguments = self.arguments()
has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments)
if len(arguments) == 2 and has_mixed_arguments:
return MixedInterpolator(self)

operand, = self.ufl_operands
target_mesh = self.target_space.mesh()

try:
source_mesh = extract_unique_domain(operand) or target_mesh
except ValueError:
Expand Down
42 changes: 42 additions & 0 deletions tests/firedrake/regression/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,45 @@ def test_interpolate_form_mixed():

res3 = assemble(inner(u, q) * dx) # V x W -> R
assert mat_equals(res1, res3)


def test_nested_interpolate_expr_vom():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you reproduce the bug with nested interpolate objects on a single mesh?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It fails when compiling the dual evaluation kernel:

File "/Users/lac224/Coding/work/firedrake-dev/firedrake/tsfc/driver.py", line 346, in compile_expression_dual_evaluation
    evaluation, basis_indices = to_element.dual_evaluation(fn, coordinate_mapping)
                                ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/lac224/Coding/work/firedrake-dev/fiat/finat/tensorfiniteelement.py", line 178, in dual_evaluation
    expr = fn(x)
  File "/Users/lac224/Coding/work/firedrake-dev/firedrake/tsfc/driver.py", line 437, in __call__
    gem_expr, = fem.compile_ufl(self.expression, translation_context, point_sum=False)
                ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/lac224/Coding/work/firedrake-dev/firedrake/tsfc/fem.py", line 854, in compile_ufl
    result = map_expr_dags(context.translator, expressions)
  File "/Users/lac224/Coding/work/firedrake-dev/ufl/ufl/corealg/map_dag.py", line 114, in map_expr_dags
    r = handlers[v._ufl_typecode_](v, *(vcache[u] for u in v.ufl_operands))
  File "/Users/lac224/Coding/work/firedrake-dev/ufl/ufl/corealg/multifunction.py", line 99, in undefined
    raise ValueError(f"No handler defined for {o._ufl_class_.__name__}.")
ValueError: No handler defined for Interpolate.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might indicate that BaseFormAssembler is not recurring on the operand before construction the Interpolator

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a case for adding Interpolate to the form compiler. We would be able to directly generate cell kernels to assemble an aij matrix from inner(Interpolate(grad(u), V2), ...)*dx. In theory we would require to combine compile_dual_evaluation with compile_form

mesh = UnitSquareMesh(3, 3)
x, y = SpatialCoordinate(mesh)
points = np.array([[0.5, 0.5], [0.6, 0.6]])
vom = VertexOnlyMesh(mesh, points)

tfs = TensorFunctionSpace(vom, "DG", 0)
vfs = VectorFunctionSpace(vom, "DG", 0)

expr = as_tensor([[x, 0], [0, y]])
v = Function(vfs).interpolate(as_vector([1.0, 2.0]))
inner_expr = interpolate(expr, tfs) * v

result = assemble(interpolate(inner_expr, vom.coordinates.function_space()))

expected = np.array([[0.5, 1.0], [0.6, 1.2]])
assert np.allclose(result.dat.data_ro.reshape(-1, 2), expected)


@pytest.mark.parallel([1, 3])
def test_nested_interpolate_expr():
mesh1 = UnitSquareMesh(3, 3)
x1, y1 = SpatialCoordinate(mesh1)
mesh2 = UnitSquareMesh(4, 4)
x2, y2 = SpatialCoordinate(mesh2)
mesh3 = UnitSquareMesh(5, 5)
x3, y3 = SpatialCoordinate(mesh3)

tfs2 = TensorFunctionSpace(mesh1, "CG", 1)
tfs3 = TensorFunctionSpace(mesh2, "CG", 2)
expr1 = as_tensor([[x1, 0], [0, y1]])
expr2 = as_tensor([[x2, 0], [0, y2]])

inner_expr = interpolate(expr1, tfs3) * interpolate(expr2, tfs3)
result = assemble(interpolate(inner_expr, tfs2))

res1 = assemble(interpolate(expr1, tfs3))
res2 = assemble(interpolate(expr2, tfs3))
expected = assemble(interpolate(res1 * res2, tfs2))
assert np.allclose(result.dat.data_ro, expected.dat.data_ro)
Loading