Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def __init__(self, expr: Interpolate):
# Delay calling .unique() because MixedInterpolator is fine with MeshSequence
self.target_mesh = self.target_space.mesh()
"""The domain we are interpolating into."""
self.source_mesh = extract_unique_domain(operand) or self.target_mesh
self.source_mesh = extract_unique_domain(operand, expand_mesh_sequence=False) or self.target_mesh
"""The domain we are interpolating from."""

# Interpolation options
Expand Down Expand Up @@ -434,6 +434,7 @@ class CrossMeshInterpolator(Interpolator):
def __init__(self, expr: Interpolate):
super().__init__(expr)
self.target_mesh = self.target_mesh.unique()
self.source_mesh = self.source_mesh.unique()
if self.access and self.access != op2.WRITE:
raise NotImplementedError(
"Access other than op2.WRITE not implemented for cross-mesh interpolation."
Expand Down Expand Up @@ -616,6 +617,7 @@ class SameMeshInterpolator(Interpolator):
def __init__(self, expr):
super().__init__(expr)
self.target_mesh = self.target_mesh.unique()
self.source_mesh = self.source_mesh.unique()
subset = self.subset
if subset is None:
target = self.target_mesh.topology
Expand Down Expand Up @@ -1697,7 +1699,7 @@ def _build_aij(

def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None):
mat_type = mat_type or "aij"
sub_mat_type = sub_mat_type or "baij"
sub_mat_type = sub_mat_type or "aij"
Isub = self._get_sub_interpolators(bcs=bcs)
V_dest = self.ufl_interpolate.function_space() or self.target_space
f = tensor or Function(V_dest)
Expand Down
47 changes: 43 additions & 4 deletions tests/firedrake/regression/test_interpolate_cross_mesh.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from firedrake import *
from firedrake.petsc import DEFAULT_PARTITIONER
from firedrake.ufl_expr import extract_unique_domain
from firedrake.mesh import Mesh, plex_from_cell_list
from firedrake.formmanipulation import split_form
import numpy as np
import pytest
from ufl import product
Expand Down Expand Up @@ -613,8 +615,8 @@ def test_line_integral():
# Create a 1D line mesh in 2D from (0, 0) to (1, 1) with 1 cell
cells = np.asarray([[0, 1]])
vertex_coords = np.asarray([[0.0, 0.0], [1.0, 1.0]])
plex = mesh.plex_from_cell_list(1, cells, vertex_coords, comm=m.comm)
line = mesh.Mesh(plex, dim=2)
plex = plex_from_cell_list(1, cells, vertex_coords, comm=m.comm)
line = Mesh(plex, dim=2)
x, y = SpatialCoordinate(line)
V_line = FunctionSpace(line, "CG", 2)
f_line = Function(V_line).interpolate(x * y)
Expand All @@ -623,8 +625,8 @@ def test_line_integral():
# Create a 1D line around the unit square (2D) with 4 cells
cells = np.asarray([[0, 1], [1, 2], [2, 3], [3, 0]])
vertex_coords = np.asarray([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
plex = mesh.plex_from_cell_list(1, cells, vertex_coords, comm=m.comm)
line_square = mesh.Mesh(plex, dim=2)
plex = plex_from_cell_list(1, cells, vertex_coords, comm=m.comm)
line_square = Mesh(plex, dim=2)
x, y = SpatialCoordinate(line_square)
V_line_square = FunctionSpace(line_square, "CG", 2)
f_line_square = Function(V_line_square).interpolate(x * y)
Expand Down Expand Up @@ -750,3 +752,40 @@ def test_interpolate_cross_mesh_interval(periodic):
f_dest = Function(V_dest).interpolate(f_src)
x_dest, = SpatialCoordinate(m_dest)
assert abs(assemble((f_dest - (-(x_dest - .5) ** 2)) ** 2 * dx)) < 1.e-16


def test_mixed_interpolator_cross_mesh():
# Tests assembly of mixed interpolator across meshes
mesh1 = UnitSquareMesh(4, 4)
mesh2 = UnitSquareMesh(3, 3, quadrilateral=True)
mesh3 = UnitDiskMesh(2)
mesh4 = UnitTriangleMesh(3)
V1 = FunctionSpace(mesh1, "CG", 1)
V2 = FunctionSpace(mesh2, "CG", 2)
V3 = FunctionSpace(mesh3, "CG", 3)
V4 = FunctionSpace(mesh4, "CG", 4)

W = V1 * V2
U = V3 * V4

w = TrialFunction(W)
w0, w1 = split(w)
expr = as_vector([w0 + w1, w0 + w1])
mixed_interp = interpolate(expr, U, allow_missing_dofs=True) # Interpolating from W to U

# The block matrix structure is
# | V1 -> V3 V2 -> V3 |
# | V1 -> V4 V2 -> V4 |

res = assemble(mixed_interp, mat_type="nest")
assert isinstance(res, AssembledMatrix)
assert res.petscmat.type == "nest"

split_interp = dict(split_form(mixed_interp))

for i in range(2):
for j in range(2):
interp_ij = split_interp[(i, j)]
assert isinstance(interp_ij, Interpolate)
res_block = assemble(interpolate(TrialFunction(W.sub(j)), U.sub(i), allow_missing_dofs=True))
assert np.allclose(res.petscmat.getNestSubMatrix(i, j)[:, :], res_block.petscmat[:, :])
4 changes: 2 additions & 2 deletions tests/firedrake/regression/test_interpolator_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def test_mixed_same_mesh_mattype(value_shape, mat_type, sub_mat_type):
# Always seqaij for scalar
assert sub_mat.type == "seqaij"
else:
# matnest sub_mat_type defaults to baij
assert sub_mat.type == "seq" + (sub_mat_type if sub_mat_type else "baij")
# matnest sub_mat_type defaults to aij
assert sub_mat.type == "seq" + (sub_mat_type if sub_mat_type else "aij")

with pytest.raises(NotImplementedError):
assemble(interp, mat_type="baij")
Expand Down
Loading