Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions grudge/symbolic/mappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import pymbolic.mapper.constant_converter
import pymbolic.mapper.flop_counter
from pymbolic.mapper import CSECachingMapperMixin
from pymbolic.mapper.equality import EqualityMapper as EqualityMapperBase

from grudge import sym
import grudge.dof_desc as dof_desc
Expand Down Expand Up @@ -1295,4 +1296,46 @@ def map_common_subexpression(self, expr):
# }}}


# {{{ equality

class EqualityMapper(EqualityMapperBase):
def map_ones(self, expr, other) -> bool:
return expr.dd == other.dd

def map_grudge_variable(self, expr, other) -> bool:
return (
expr.name == other.name
and expr.dd == other.dd)

def map_node_coordinate_component(self, expr, other) -> bool:
return (
expr.axis == other.axis
and expr.dd == other.dd)

def map_operator_binding(self, expr, other) -> bool:
return (
self.rec(expr.op, other.op)
and self.rec(expr.field, other.field))

def map_ref_diff(self, expr, other) -> bool:
return (
expr.rst_axis == other.rst_axis
and expr.dd_in == other.dd_in
and expr.dd_out == other.dd_out)

map_ref_stiffness_t = map_ref_diff

def map_elementwise_linear(self, expr, other) -> bool:
return (
expr.dd_in == other.dd_in
and expr.dd_out == other.dd_out)

map_ref_mass = map_elementwise_linear
map_ref_inverse_mass = map_elementwise_linear
map_face_mass_operator = map_elementwise_linear
map_ref_face_mass_operator = map_elementwise_linear
map_projection = map_elementwise_linear

# }}}

# vim: foldmethod=marker
5 changes: 5 additions & 0 deletions grudge/symbolic/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def __getinitargs__(self):
def make_stringifier(self, originating_stringifier=None):
from grudge.symbolic.mappers import StringifyMapper
return StringifyMapper()

def make_equality_mapper(self):
from grudge.symbolic.mappers import EqualityMapper
return EqualityMapper()

# }}}


Expand Down
4 changes: 4 additions & 0 deletions grudge/symbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def make_stringifier(self, originating_stringifier=None):
from grudge.symbolic.mappers import StringifyMapper
return StringifyMapper()

def make_equality_mapper(self):
from grudge.symbolic.mappers import EqualityMapper
return EqualityMapper()


__doc__ = """

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
numpy
mpi4py
git+https://github.com/inducer/pytools.git#egg=pytools
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
git+https://github.com/inducer/islpy.git#egg=islpy
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
git+https://github.com/inducer/loopy.git#egg=loopy
git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy
git+https://github.com/inducer/dagrt.git#egg=dagrt
git+https://github.com/inducer/leap.git#egg=leap
git+https://github.com/inducer/meshpy.git#egg=meshpy
Expand All @@ -14,7 +14,7 @@ git+https://github.com/inducer/meshmode.git#egg=meshmode
git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile
git+https://github.com/inducer/pymetis.git#egg=pymetis
git+https://github.com/illinois-ceesd/logpyle.git#egg=logpyle
git+https://github.com/inducer/pytato.git#egg=pytato
git+https://github.com/alexfikl/pytato.git@equality-mapper#egg=pytato

# for test_wave_dt_estimate
sympy
2 changes: 2 additions & 0 deletions test/test_mpi_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def run_test_with_mpi_inner():

# {{{ func_comparison

@pytest.mark.mpi
@pytest.mark.parametrize("actx_class", DISTRIBUTED_ACTXS)
@pytest.mark.parametrize("num_ranks", [2])
def test_func_comparison_mpi(actx_class, num_ranks):
Expand Down Expand Up @@ -177,6 +178,7 @@ def hopefully_zero():

# {{{ wave operator

@pytest.mark.mpi
@pytest.mark.parametrize("actx_class", DISTRIBUTED_ACTXS)
@pytest.mark.parametrize("num_ranks", [2])
def test_mpi_wave_op(actx_class, num_ranks):
Expand Down