-
Notifications
You must be signed in to change notification settings - Fork 191
Leo/nested interpolate #5097
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Leo/nested interpolate #5097
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |||||||||||||||||||||||||
| import tempfile | ||||||||||||||||||||||||||
| import abc | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from functools import cached_property, partial | ||||||||||||||||||||||||||
| from functools import cached_property, partial, singledispatchmethod | ||||||||||||||||||||||||||
| from typing import Hashable, Literal, Callable, Iterable | ||||||||||||||||||||||||||
| from dataclasses import asdict, dataclass | ||||||||||||||||||||||||||
| from numbers import Number | ||||||||||||||||||||||||||
|
|
@@ -15,6 +15,7 @@ | |||||||||||||||||||||||||
| from ufl.constantvalue import zero, as_ufl | ||||||||||||||||||||||||||
| from ufl.form import ZeroBaseForm, BaseForm | ||||||||||||||||||||||||||
| from ufl.core.interpolate import Interpolate as UFLInterpolate | ||||||||||||||||||||||||||
| from ufl.corealg.dag_traverser import DAGTraverser | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from pyop2 import op2 | ||||||||||||||||||||||||||
| from pyop2.caching import memory_and_disk_cache | ||||||||||||||||||||||||||
|
|
@@ -100,6 +101,25 @@ class InterpolateOptions: | |||||||||||||||||||||||||
| default_missing_val: float | None = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class NestedInterpolateLowerer(DAGTraverser): | ||||||||||||||||||||||||||
| """Lower nested interpolate nodes to assembled coefficients.""" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @singledispatchmethod | ||||||||||||||||||||||||||
| def process(self, o: Expr) -> Expr: | ||||||||||||||||||||||||||
| return super().process(o) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @process.register(Expr) | ||||||||||||||||||||||||||
| @process.register(BaseForm) | ||||||||||||||||||||||||||
| def _(self, o: Expr | BaseForm) -> Expr | BaseForm: | ||||||||||||||||||||||||||
| return self.reuse_if_untouched(o) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @process.register(UFLInterpolate) | ||||||||||||||||||||||||||
| @DAGTraverser.postorder | ||||||||||||||||||||||||||
| def _(self, o: UFLInterpolate, operand: Expr) -> Expr: | ||||||||||||||||||||||||||
| from firedrake.assemble import assemble | ||||||||||||||||||||||||||
| return as_ufl(assemble(o._ufl_expr_reconstruct_(operand))) | ||||||||||||||||||||||||||
|
Comment on lines
+116
to
+120
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of eagerly assembling each nested
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This mapping could then be used in a dedicated Interpolator sub-class that explicitly chains the composition of callables: class CompositeInterpolator(Interpolator):
"""
An interpolator for expressions containing nested interpolations
(possibly defined across different meshes).
"""
def __init__(self, outer_expr, subs):
super().__init__(outer_expr)
self.subs = subs # {Function: Interpolator} returned by NestedInterpolateLowerer
self._outer = get_interpolator(outer_expr)
def _get_callable(self, tensor=None, bcs=None, **kwargs):
inner_callables = [
interp._get_callable(tensor=fn)
for fn, interp in self.subs.items()
]
outer_callable = self._outer._get_callable(tensor=tensor, bcs=bcs)
def callable():
for c in inner_callables:
c()
return outer_callable()
return callableWith this we don't need to go through the entire
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still think this logic should all go in The current issue is that symbolic See this PR where we added caching for the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But wouldn't
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I agree with this.
I don't know. But if so we should look to cache it. Why not cache the preprocessed expression on the input expression? Then you can cache interpolators on the preprocessed expression that end up being persistent.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BaseFormAssembler already deals with both primal Expr/BaseForm, and dual Form/BaseForm nodes. The current BaseFormAssembler implementation could be improved if we turn it into a DAGTraverser that we could dispatch on any ufl type, but that's a separate issue.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it necessarily a separate issue? This discussion makes it seem a little like infrastructure surgery is required in order to enable us to handle nested interpolates without tremendous hackery.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I am not sure I understand here. Is the processed expression thrown away while the original one gets retained? If so, then making the processed expression persistent would solve the issue as Connor suggests. Is there a reason why the processed expression gets thrown away?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The bug can be fixed without restructuring the entire class. And the class restructuring might be done in a way that preserves the bug. The issues are related, but can be dealt with separately. |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class Interpolate(UFLInterpolate): | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def __init__(self, expr: Expr, V: WithGeometry | BaseForm, **kwargs): | ||||||||||||||||||||||||||
|
|
@@ -161,14 +181,18 @@ def _interpolator(self): | |||||||||||||||||||||||||
| An appropriate :class:`Interpolator` subclass for this | ||||||||||||||||||||||||||
| interpolation expression. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| operand, = self.ufl_operands | ||||||||||||||||||||||||||
| # Check for nested Interpolates first | ||||||||||||||||||||||||||
| lowered_operand = NestedInterpolateLowerer()(operand) | ||||||||||||||||||||||||||
| if lowered_operand is not operand: | ||||||||||||||||||||||||||
| return get_interpolator(self._ufl_expr_reconstruct_(lowered_operand)) | ||||||||||||||||||||||||||
|
Comment on lines
+186
to
+188
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't this approach freeze the nested expression? What if we change the numeric values between different calls to assemble on the same expression?
Comment on lines
+185
to
+188
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With my earlier suggestion on using
Suggested change
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| arguments = self.arguments() | ||||||||||||||||||||||||||
| has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) | ||||||||||||||||||||||||||
| if len(arguments) == 2 and has_mixed_arguments: | ||||||||||||||||||||||||||
| return MixedInterpolator(self) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| operand, = self.ufl_operands | ||||||||||||||||||||||||||
| target_mesh = self.target_space.mesh() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| source_mesh = extract_unique_domain(operand) or target_mesh | ||||||||||||||||||||||||||
| except ValueError: | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -741,3 +741,45 @@ def test_interpolate_form_mixed(): | |
|
|
||
| res3 = assemble(inner(u, q) * dx) # V x W -> R | ||
| assert mat_equals(res1, res3) | ||
|
|
||
|
|
||
| def test_nested_interpolate_expr_vom(): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you reproduce the bug with nested interpolate objects on a single mesh?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It fails when compiling the dual evaluation kernel:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might indicate that BaseFormAssembler is not recurring on the operand before construction the Interpolator
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a case for adding |
||
| mesh = UnitSquareMesh(3, 3) | ||
| x, y = SpatialCoordinate(mesh) | ||
| points = np.array([[0.5, 0.5], [0.6, 0.6]]) | ||
| vom = VertexOnlyMesh(mesh, points) | ||
|
|
||
| tfs = TensorFunctionSpace(vom, "DG", 0) | ||
| vfs = VectorFunctionSpace(vom, "DG", 0) | ||
|
|
||
| expr = as_tensor([[x, 0], [0, y]]) | ||
| v = Function(vfs).interpolate(as_vector([1.0, 2.0])) | ||
| inner_expr = interpolate(expr, tfs) * v | ||
|
|
||
| result = assemble(interpolate(inner_expr, vom.coordinates.function_space())) | ||
|
|
||
| expected = np.array([[0.5, 1.0], [0.6, 1.2]]) | ||
| assert np.allclose(result.dat.data_ro.reshape(-1, 2), expected) | ||
|
|
||
|
|
||
| @pytest.mark.parallel([1, 3]) | ||
| def test_nested_interpolate_expr(): | ||
| mesh1 = UnitSquareMesh(3, 3) | ||
| x1, y1 = SpatialCoordinate(mesh1) | ||
| mesh2 = UnitSquareMesh(4, 4) | ||
| x2, y2 = SpatialCoordinate(mesh2) | ||
| mesh3 = UnitSquareMesh(5, 5) | ||
| x3, y3 = SpatialCoordinate(mesh3) | ||
|
|
||
| tfs2 = TensorFunctionSpace(mesh1, "CG", 1) | ||
| tfs3 = TensorFunctionSpace(mesh2, "CG", 2) | ||
| expr1 = as_tensor([[x1, 0], [0, y1]]) | ||
| expr2 = as_tensor([[x2, 0], [0, y2]]) | ||
|
|
||
| inner_expr = interpolate(expr1, tfs3) * interpolate(expr2, tfs3) | ||
| result = assemble(interpolate(inner_expr, tfs2)) | ||
|
|
||
| res1 = assemble(interpolate(expr1, tfs3)) | ||
| res2 = assemble(interpolate(expr2, tfs3)) | ||
| expected = assemble(interpolate(res1 * res2, tfs2)) | ||
| assert np.allclose(result.dat.data_ro, expected.dat.data_ro) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't
BaseFormAssemblerthe intended way of traversing the DAG forassemble?