Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2203,11 +2203,11 @@ def _get_expanded_dm_dg_coords(dm: PETSc.DM, ndofs: np.ndarray):

def _get_periodicity(dm: PETSc.DM) -> tuple[tuple[bool, bool], ...]:
"""Return mesh periodicity information.

This function returns a 2-tuple of bools per dimension where the first entry indicates
whether the mesh is periodic in that dimension, and the second indicates whether the
mesh is single-cell periodic in that dimension.

"""
cdef:
const PetscReal *maxCell, *L
Expand Down Expand Up @@ -4325,3 +4325,24 @@ def get_dm_cell_types(PETSc.DM dm):
return tuple(
polytope_type_enum for polytope_type_enum, found in enumerate(found_all) if found
)


def intersectIS(PETSc.IS i1, PETSc.IS i2):
"""Return the intersection of two IS objects.

Parameters
----------
i1 : PETSc.IS
The first IS.
i2 : PETSc.IS
The second IS.

Returns
-------
PETSc.IS
A PETSc.IS with the intersection.

"""
cdef PETSc.IS iout = PETSc.IS()
CHKERR(ISIntersect(i1.iset, i2.iset, &iout.iset))
return iout
1 change: 1 addition & 0 deletions firedrake/cython/petschdr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ cdef extern from "petscis.h" nogil:
PetscErrorCode ISLocalToGlobalMappingGetBlockIndices(PETSc.PetscLGMap, const PetscInt**)
PetscErrorCode ISLocalToGlobalMappingRestoreBlockIndices(PETSc.PetscLGMap, const PetscInt**)
PetscErrorCode ISDestroy(PETSc.PetscIS*)
PetscErrorCode ISIntersect(PETSc.PetscIS, PETSc.PetscIS, PETSc.PetscIS*)

cdef extern from "petscsf.h" nogil:
struct PetscSFNode_:
Expand Down
107 changes: 102 additions & 5 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pyop2.mpi import (
MPI, COMM_WORLD, temp_internal_comm
)
from functools import cached_property
from functools import cached_property, reduce
from pyop2.utils import as_tuple
import petsctools
from petsctools import OptionsManager, get_external_packages
Expand Down Expand Up @@ -4806,10 +4806,12 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
subdim : int | None
Topological dimension of the submesh.
Defaults to ``mesh.topological_dimension``.
subdomain_id : int | None
subdomain_id : int | Sequence | None
Subdomain ID representing the submesh.
If `None` the submesh will cover the entire domain.
This is useful to obtain a codim-1 submesh over all facets or
If multiple subdomain IDs are provided, their union is taken.
If nested lists of subdomain IDs are provided, their intersection is taken.
Comment thread
pbrubeck marked this conversation as resolved.
If `None` the submesh will cover the entire domain,
this is useful to obtain a codim-1 submesh over all facets or
a submesh over a different communicator.
label_name : str | None
Name of the label to search ``subdomain_id`` in.
Expand Down Expand Up @@ -4852,13 +4854,70 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
ridges to be contained in the quad mesh are shared by at most two
facets to make the quad mesh orientation algorithm work.

Examples
--------
>>> mesh = UnitSquareMesh(4, 4)
>>> x, y = SpatialCoordinate(mesh)
>>> DG = FunctionSpace(mesh, "DG", 0)
>>> DGT = FunctionSpace(mesh, "DGT", 0)

Mark a cell subdomain and construct a codim-0 submesh from all cells in the subdomain

>>> cell_marker = assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG))
>>> mesh.mark_entities(cell_marker, 111)
>>> submesh = Submesh(mesh, subdomain_id=111)

Mark a facet subdomain and construct a codim-1 submesh from all facets in the subdomain

>>> facet_marker = assemble(interpolate(conditional(lt(abs(x-0.5), 1E-12), 1, 0), DGT))
>>> mesh.mark_entities(facet_marker, 222)
>>> submesh = Submesh(mesh, subdim=mesh.topological_dimension-1, subdomain_id=222)

Construct a codim-0 submesh of the union of multiple subdomains by passing a list

>>> mesh.mark_entities(assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)), 1)
>>> mesh.mark_entities(assemble(interpolate(conditional(lt(y, 0.5), 1, 0), DG)), 2)
>>> submesh = Submesh(mesh, subdomain_id=[1, 2])

Construct a codim-0 submesh of the intersection of multiple subdomains by passing a nested list

>>> submesh = Submesh(mesh, subdomain_id=[(1, 2)])

Construct a codim-1 submesh of all the facets (the skeleton mesh)

>>> submesh = Submesh(mesh, subdim=1)

Construct a codim-1 submesh of the entire boundary

>>> submesh = Submesh(mesh, subdomain_id="on_boundary")

Construct a codim-1 submesh of the union of multiple boundaries

>>> submesh = Submesh(mesh, subdim=mesh.topological_dimension-1, subdomain_id=[1, 2, 3])

Construct a codim-0 submesh of the part of the mesh owned by each MPI rank

>>> submesh = Submesh(mesh, ignore_halo=True, comm=COMM_SELF)

"""
if not isinstance(mesh, MeshGeometry):
raise TypeError("Parent mesh must be a `MeshGeometry`")
if isinstance(mesh.topology, ExtrudedMeshTopology):
raise NotImplementedError("Can not create a submesh of an ``ExtrudedMesh``")
elif isinstance(mesh.topology, VertexOnlyMeshTopology):
raise NotImplementedError("Can not create a submesh of a ``VertexOnlyMesh``")

if subdomain_id == "on_boundary":
if subdim is None:
subdim = mesh.topological_dimension - 1
elif subdim != mesh.topological_dimension - 1:
raise ValueError('subdomain_id="on_boundary" requires subdim=dim-1')
if label_name is None:
label_name = "exterior_facets"
elif label_name != "exterior_facets":
raise ValueError('subdomain_id="on_boundary" requires label_name="exterior_facets"')
subdomain_id = 1

if subdim is None:
subdim = mesh.topological_dimension
plex = mesh.topology_dm
Expand All @@ -4876,15 +4935,53 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
label_name = dmcommon.CELL_SETS_LABEL
elif subdim == dim - 1:
label_name = dmcommon.FACE_SETS_LABEL

# Parse non-integer subdomain_id
if isinstance(subdomain_id, str):
raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.")
elif isinstance(subdomain_id, Sequence):
label = plex.getLabel(label_name)
if subdim != dim:
plex.labelComplete(label)

def get_points(sub):
if sub == "on_boundary":
return plex.getStratumIS("exterior_facets", 1)
elif isinstance(sub, str):
raise ValueError(f"Submesh construction got invalid subdomain_id {sub}.")
else:
return label.getStratumIS(sub)

# Take the union of the labels in the list
iset = PETSc.IS().createGeneral([], comm=mesh.comm)
for sub in subdomain_id:
if isinstance(sub, Sequence) and not isinstance(sub, str):
# Take the intersection of the labels from nested lists
if len(sub) == 0:
continue
cur = reduce(dmcommon.intersectIS, map(get_points, sub))
else:
cur = get_points(sub)
iset = iset.union(cur)
# Create a temporary label
label_name = "temp_label"
subdomain_id = 1
plex.createLabel(label_name)
label = plex.getLabel(label_name)
label.setStratumIS(subdomain_id, iset)

subplex = dmcommon.submesh_create(plex, subdim, label_name, subdomain_id, ignore_halo, comm=comm)

if label_name == "temp_label":
plex.removeLabel(label_name)

comm = comm or mesh.comm
name = name or _generate_default_submesh_name(mesh.name)
subplex.setName(_generate_default_mesh_topology_name(name))
if subplex.getDimension() != subdim:
raise RuntimeError(f"Found subplex dim ({subplex.getDimension()}) != expected ({subdim})")
if reorder is None:
# Ideally we should set perm_is = mesh.dm_reordering[label_indices]
# Ideally we should set perm_is = mesh._dm_renumbering[label_indices]
reorder = mesh._did_reordering

submesh = Mesh(
Expand Down
107 changes: 107 additions & 0 deletions tests/firedrake/submesh/test_submesh_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import pytest
import numpy as np
from firedrake import *


def test_submesh_subdomain_id_union():
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
M = FunctionSpace(mesh, "DG", 0)
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

subdomain_id = [111, 222]
submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id)

m3 = Function(M).interpolate(m1 + m2 - m1 * m2)
expected = assemble(m3*dx)
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

mesh.mark_entities(m3, 333)
submesh2 = Submesh(mesh, mesh.topological_dimension, 333)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


def test_submesh_subdomain_id_intersection():
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
M = FunctionSpace(mesh, "DG", 0)
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

subdomain_id = [(111, 222)]
submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id)

m3 = Function(M).interpolate(m1 * m2)
expected = assemble(m3*dx)
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

mesh.mark_entities(m3, 333)
submesh2 = Submesh(mesh, mesh.topological_dimension, 333)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


@pytest.mark.parametrize("subdomain_id", ["on_boundary", (1, 3, 6)])
def test_submesh_facet_subdomain_id_union(subdomain_id):
mesh = UnitCubeMesh(2, 2, 2)
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id)
if subdomain_id == "on_boundary":
area = assemble(1*ds(domain=mesh))
else:
area = assemble(1*ds(subdomain_id, domain=mesh))
assert abs(assemble(1*dx(domain=submesh1)) - area) < 1E-12

DGT = FunctionSpace(mesh, "DGT", 0)
facet_function = Function(DGT)
DirichletBC(DGT, 1, subdomain_id).apply(facet_function)
facet_value = 999
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


@pytest.mark.parametrize("sub", ["cell-cell", "cell-boundary"])
def test_submesh_facet_subdomain_id_intersection(sub):
if sub == "cell-cell":
# (x <= 0.5) & (x >= 0.5)
subdomain_id = [(111, 222)]
expected = 1
elif sub == "cell-boundary":
# (x <= 0.5) & (x == 0 | y == 0 | y == 1)
subdomain_id = [(111, "on_boundary")]
expected = 2

mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
DG = FunctionSpace(mesh, "DG", 0)
DGT = FunctionSpace(mesh, "DGT", 0)
m1 = Function(DG).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(DG).interpolate(conditional(lt(x, 0.5), 0, 1))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id, label_name="Cell Sets")

assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

facet_function = Function(DGT)
if sub == "cell-cell":
facet_function.interpolate(conditional(lt(abs(x-0.5), 1E-8), 1, 0))
elif sub == "cell-boundary":
facet_function.interpolate(conditional(lt(x, 0.5), 1, 0))
bnd = Function(DGT)
DirichletBC(DGT, 1, "on_boundary").apply(bnd)
facet_function.dat.data[:] *= bnd.dat.data_ro[:]

facet_value = 999
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)
Loading