Skip to content

Commit 6f2a15f

Browse files
committed
dsl: Replace SparseEq kind string with Interpolation/Injection subclasses
The 'kind' attribute and the 'kind == ...' / 'isinstance(eq, Inc)' branches that gated interpolation vs. injection behaviour are replaced by leaf classes (Interpolation, IncrInterpolation, Injection) sharing a small InterpolationMixin / InjectionMixin. The polymorphic surface (efunc_prefix, field, is_head_eq, sparse_temps) lives on the leaves, so the IET pass and the lowering pipeline drop their kind-discriminating conditionals.
1 parent 3c0ddea commit 6f2a15f

5 files changed

Lines changed: 236 additions & 123 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,22 @@
1212
from devito.symbolics import IntDiv, limits_mapper, uxreplace
1313
from devito.tools import Pickable, Tag, frozendict
1414
from devito.types import (
15-
Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, SparseEq, SparseOpMixin, relational_min
15+
Eq, Inc, IncrInterpolation, Injection, InjectionMixin, Interpolation,
16+
InterpolationMixin, ReduceMax, ReduceMin, ReduceMinMax, SparseEq, SparseOpMixin,
17+
relational_min
1618
)
1719

1820
__all__ = [
1921
'ClusterizedEq',
22+
'ClusterizedIncrInterpolation',
23+
'ClusterizedInjection',
24+
'ClusterizedInterpolation',
2025
'ClusterizedSparseEq',
2126
'DummyEq',
2227
'LoweredEq',
28+
'LoweredIncrInterpolation',
29+
'LoweredInjection',
30+
'LoweredInterpolation',
2331
'LoweredSparseEq',
2432
'OpInc',
2533
'OpMax',
@@ -283,12 +291,37 @@ class LoweredSparseEq(SparseOpMixin, LoweredEq):
283291

284292
"""
285293
The IR counterpart of ``SparseEq``: a regular ``LoweredEq`` that
286-
also carries the ``interpolator``/``kind`` metadata used by the IET
287-
pass ``lower_sparse_ops`` to wrap the resulting ``p_*, rp_*``
288-
iteration nest in an ElementalFunction.
294+
also carries the ``interpolator`` metadata used by the IET pass
295+
``lower_sparse_ops`` to wrap the resulting ``p_*, rp_*`` iteration
296+
nest in an ElementalFunction. Subclassed by
297+
``LoweredInterpolation`` / ``LoweredIncrInterpolation`` /
298+
``LoweredInjection`` for the per-operation polymorphic behaviour.
289299
"""
290300

291-
__rkwargs__ = LoweredEq.__rkwargs__ + ('interpolator', 'kind')
301+
__rkwargs__ = LoweredEq.__rkwargs__ + ('interpolator',)
302+
303+
304+
class LoweredInterpolation(InterpolationMixin, LoweredSparseEq):
305+
"""IR counterpart of ``Interpolation``."""
306+
pass
307+
308+
309+
class LoweredIncrInterpolation(InterpolationMixin, LoweredSparseEq):
310+
"""IR counterpart of ``IncrInterpolation``."""
311+
pass
312+
313+
314+
class LoweredInjection(InjectionMixin, LoweredSparseEq):
315+
"""IR counterpart of ``Injection``."""
316+
pass
317+
318+
319+
# Map user-level sparse-op classes to their IR-level counterparts.
320+
_lowered_sparse_cls = {
321+
Interpolation: LoweredInterpolation,
322+
IncrInterpolation: LoweredIncrInterpolation,
323+
Injection: LoweredInjection,
324+
}
292325

293326

294327
class ClusterizedEq(IREq):
@@ -350,12 +383,35 @@ class ClusterizedSparseEq(SparseOpMixin, ClusterizedEq):
350383

351384
"""
352385
Frozen counterpart of ``LoweredSparseEq``: the same regular
353-
``ClusterizedEq`` augmented with ``interpolator``/``kind`` so the
354-
IET pass ``lower_sparse_ops`` can identify and rewrite the sparse
355-
op's iteration nest.
386+
``ClusterizedEq`` augmented with ``interpolator`` so the IET pass
387+
``lower_sparse_ops`` can identify and rewrite the sparse op's
388+
iteration nest. Subclassed by ``ClusterizedInterpolation`` /
389+
``ClusterizedIncrInterpolation`` / ``ClusterizedInjection``.
356390
"""
357391

358-
__rkwargs__ = ClusterizedEq.__rkwargs__ + ('interpolator', 'kind')
392+
__rkwargs__ = ClusterizedEq.__rkwargs__ + ('interpolator',)
393+
394+
395+
class ClusterizedInterpolation(InterpolationMixin, ClusterizedSparseEq):
396+
"""Frozen counterpart of ``LoweredInterpolation``."""
397+
pass
398+
399+
400+
class ClusterizedIncrInterpolation(InterpolationMixin, ClusterizedSparseEq):
401+
"""Frozen counterpart of ``LoweredIncrInterpolation``."""
402+
pass
403+
404+
405+
class ClusterizedInjection(InjectionMixin, ClusterizedSparseEq):
406+
"""Frozen counterpart of ``LoweredInjection``."""
407+
pass
408+
409+
410+
_clusterized_sparse_cls = {
411+
LoweredInterpolation: ClusterizedInterpolation,
412+
LoweredIncrInterpolation: ClusterizedIncrInterpolation,
413+
LoweredInjection: ClusterizedInjection,
414+
}
359415

360416

361417
class DummyEq(ClusterizedEq):
@@ -408,11 +464,11 @@ def _(eq):
408464

409465
if augmented != tuple(eq.implicit_dims or ()):
410466
eq = eq.func(eq.lhs, eq.rhs, interpolator=interp,
411-
kind=eq.kind, implicit_dims=augmented)
467+
implicit_dims=augmented)
412468

413-
obj = LoweredSparseEq(eq)
469+
lowered_cls = _lowered_sparse_cls[type(eq)]
470+
obj = lowered_cls(eq)
414471
obj._interpolator = interp
415-
obj._kind = eq.kind
416472
return obj
417473

418474

@@ -427,6 +483,12 @@ def clusterize_eq(eq, **kwargs):
427483

428484

429485
@clusterize_eq.register(LoweredSparseEq)
486+
def _(eq, **kwargs):
487+
return _clusterized_sparse_cls[type(eq)](eq, **kwargs)
488+
489+
430490
@clusterize_eq.register(ClusterizedSparseEq)
431491
def _(eq, **kwargs):
432-
return ClusterizedSparseEq(eq, **kwargs)
492+
# ``eq`` is already clusterized; rebuild via its own class to preserve
493+
# the per-operation polymorphic behaviour.
494+
return type(eq)(eq, **kwargs)

devito/operations/interpolators.py

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from devito.logger import warning
1616
from devito.symbolics import INT, retrieve_function_carriers, retrieve_functions
1717
from devito.tools import as_tuple, filter_ordered, memoized_meth
18-
from devito.types import Eq, Inc, SparseEq, SparseInc, SubFunction, Symbol
18+
from devito.types import (
19+
Eq, Inc, IncrInterpolation, Injection, Interpolation, SubFunction, Symbol
20+
)
1921
from devito.types.utils import DimensionTuple
2022

2123
__all__ = ['LinearInterpolator', 'PrecomputedInterpolator', 'SincInterpolator']
@@ -82,40 +84,38 @@ def _extract_subdomain(variables):
8284

8385
def _build_interpolation(expr, increment, implicit_dims, self_subs, interpolator):
8486
"""
85-
Construct the SparseEq for an interpolation: the synthetic Eq is
86-
``Eq(sf[..., p_*], expr[..., rp_*])``; with ``increment`` it is
87+
Construct the sparse-op Eq for an interpolation: the synthetic Eq
88+
is ``Eq(sf[..., p_*], expr[..., rp_*])``; with ``increment`` it is
8789
an ``Inc``. User-supplied ``implicit_dims`` are carried as-is; the
8890
SparseFunction's iteration Dimensions are augmented in by
89-
``SparseEq.lower`` so the cluster pipeline sees them.
91+
``lower_eq`` so the cluster pipeline sees them.
9092
"""
9193
eq = interpolator._interpolate(expr=expr, increment=increment,
9294
self_subs=self_subs,
9395
implicit_dims=None)
94-
cls = SparseInc if isinstance(eq, Inc) else SparseEq
96+
cls = IncrInterpolation if isinstance(eq, Inc) else Interpolation
9597
return cls(eq.lhs, eq.rhs, interpolator=interpolator,
96-
kind='interpolate',
9798
implicit_dims=implicit_dims)
9899

99100

100101
def _build_injection(field, expr, implicit_dims, interpolator):
101102
"""
102-
Construct the SparseEq(s) for an injection: each synthetic Eq is
103-
``Inc(field[..., x, y, ...], weights * expr[..., rp_*])`` produced
104-
by ``interpolator._inject``. A multi-field injection expands into
105-
one ``SparseEq`` per ``(field, expr)`` pair so each target field is
106-
individually visible to the cluster pipeline. User-supplied
107-
``implicit_dims`` are carried as-is; sparse-function iteration
108-
Dimensions are augmented in by ``SparseEq.lower``.
103+
Construct the ``Injection``(s) for an injection: each synthetic Eq
104+
is ``Inc(field[..., x, y, ...], weights * expr[..., rp_*])``
105+
produced by ``interpolator._inject``. A multi-field injection
106+
expands into one ``Injection`` per ``(field, expr)`` pair so each
107+
target field is individually visible to the cluster pipeline.
108+
User-supplied ``implicit_dims`` are carried as-is; sparse-function
109+
iteration Dimensions are augmented in by ``lower_eq``.
109110
"""
110111
fields, exprs = as_tuple(field), as_tuple(expr)
111112
if len(exprs) == 1:
112113
exprs = tuple(exprs[0] for _ in fields)
113114
eqs = []
114115
for (f, e) in zip(fields, exprs, strict=True):
115116
inc = interpolator._inject(field=f, expr=e, implicit_dims=None)
116-
eqs.append(SparseEq(inc.lhs, inc.rhs, interpolator=interpolator,
117-
kind='inject',
118-
implicit_dims=implicit_dims))
117+
eqs.append(Injection(inc.lhs, inc.rhs, interpolator=interpolator,
118+
implicit_dims=implicit_dims))
119119
return eqs[0] if len(eqs) == 1 else eqs
120120

121121

@@ -261,6 +261,25 @@ def _positions(self, implicit_dims, shifts=None):
261261
return [Eq(v, INT(floor(k)), implicit_dims=implicit_dims)
262262
for k, v in self.sfunction._position_map(shifts=shifts).items()]
263263

264+
def sparse_temps(self, rhs, implicit_dims, field=None):
265+
"""
266+
Position/coefficient temps for a sparse op with right-hand side
267+
``rhs``. For an injection, ``field`` drives the per-Dimension
268+
shifts so the temps' lhs (``pos*`` symbols) match the rhs of a
269+
staggered injection; for an interpolation, ``field`` is None
270+
and no shifts are applied.
271+
"""
272+
if field is not None:
273+
extras = [field] + list(retrieve_function_carriers(rhs))
274+
shifts = self._field_shifts(field)
275+
else:
276+
extras = list(retrieve_function_carriers(rhs)) or None
277+
shifts = None
278+
279+
implicit_dims = self._augment_implicit_dims(implicit_dims, extras=extras)
280+
return list(self._positions(implicit_dims, shifts=shifts)) + \
281+
list(self._coeff_temps(implicit_dims, shifts=shifts))
282+
264283
def _interp_idx(self, variables, subdomain=None, shifts=None):
265284
"""
266285
Generate the indirect-access index substitutions for the
@@ -289,30 +308,6 @@ def _interp_idx(self, variables, subdomain=None, shifts=None):
289308

290309
return {v: v.subs(subs) for v in variables}
291310

292-
def _sparse_temps(self, kind, expr, field=None, implicit_dims=None):
293-
"""
294-
Position/coefficient temps emitted alongside the radius
295-
expansion. ``implicit_dims`` is augmented with the
296-
SparseFunction's iteration dimensions (and any dim carried by
297-
the operation inputs) so the temps share the radius-nest's
298-
iteration space. For injection, ``field``'s staggering drives
299-
the position-symbol shifts so the rhs `pos*` symbols match
300-
the temps' lhs.
301-
"""
302-
if kind == 'inject':
303-
extras = list(as_tuple(field))
304-
if expr is not None:
305-
extras.extend(retrieve_function_carriers(expr))
306-
shifts = self._field_shifts(field) if field is not None else None
307-
else:
308-
extras = list(retrieve_function_carriers(expr)) \
309-
if expr is not None else None
310-
shifts = None
311-
312-
implicit_dims = self._augment_implicit_dims(implicit_dims, extras=extras)
313-
return list(self._positions(implicit_dims, shifts=shifts)) + \
314-
list(self._coeff_temps(implicit_dims, shifts=shifts))
315-
316311
@check_radius
317312
@check_coords
318313
def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):

devito/passes/iet/sparse.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
Transformer, make_callable
3333
)
3434
from devito.passes.iet.engine import iet_pass
35-
from devito.types import Eq, Symbol
35+
from devito.types import Eq, InjectionMixin, InterpolationMixin, Symbol
3636

3737
__all__ = ['lower_sparse_ops']
3838

@@ -53,7 +53,9 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
5353
# from the original SparseEq (e.g. by ``factorize``/``cse``) are
5454
# left where they are inside the radius nest.
5555
sparse_exprs = [e for e in FindNodes(Expression).visit(iet)
56-
if e.expr.is_SparseOperation and _is_head(e.expr)]
56+
if e.expr.is_SparseOperation
57+
and type(e.expr).is_head_eq(e.expr,
58+
e.expr.interpolator.sfunction)]
5759
if not sparse_exprs:
5860
return iet, {}
5961

@@ -81,7 +83,7 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
8183
new_nest = _materialise_nest(nest, exprs)
8284

8385
lse = exprs[0].expr
84-
prefix = f'{lse.kind}_{lse.interpolator.sfunction.name}'
86+
prefix = f'{lse.efunc_prefix}_{lse.interpolator.sfunction.name}'
8587
efunc = make_callable(sregistry.make_name(prefix=prefix), new_nest)
8688
efuncs.append(efunc)
8789

@@ -93,21 +95,6 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
9395
return Transformer(mapper).visit(iet), {'efuncs': efuncs}
9496

9597

96-
def _is_head(eq):
97-
"""
98-
True if ``eq`` is the "head" of its sparse op: the Expression
99-
whose lhs is the SparseFunction (interpolation) or a
100-
DiscreteFunction grid field (injection), as opposed to an
101-
auxiliary scalar temporary extracted from the original SparseEq by
102-
a cluster pass.
103-
"""
104-
sf = eq.interpolator.sfunction
105-
f = eq.lhs.function
106-
if eq.kind == 'interpolate':
107-
return f is sf
108-
return f.is_DiscreteFunction and f is not sf
109-
110-
11198
def _find_outer_iteration(iet, expr):
11299
"""
113100
Walk up the IET from ``expr`` and return the outermost Iteration
@@ -140,25 +127,22 @@ def _materialise_nest(nest, exprs):
140127
pattern. Multiple sparse-op Expressions sharing the same outer
141128
Iteration are materialised in one pass and reuse the same temps.
142129
"""
143-
sample = exprs[0].expr
144-
interp = sample.interpolator
145-
146130
# Position + coefficient temporaries as IET Expressions. These are
147131
# the same for every Expression in the group, so we emit them once.
148-
field = sample.lhs.function if sample.kind == 'inject' else None
149-
temps = interp._sparse_temps(sample.kind, sample.rhs, field=field,
150-
implicit_dims=sample.implicit_dims)
132+
# The sample's leaf class (Interpolation/Injection) drives whether
133+
# the temps carry staggering shifts.
134+
sample = exprs[0].expr
151135
temp_exprs = tuple(Expression(DummyEq(e.lhs, e.rhs))
152-
for e in lower_exprs(temps))
136+
for e in lower_exprs(sample.sparse_temps()))
153137

154138
# The radius nest is what runs once per sparse point. For each
155139
# interpolation Expression in the group, build its
156140
# accumulator-wrapped copy of the radius nest. Injection Exprs
157141
# share a single copy of the radius nest (their ``Inc`` already
158142
# carries the right ``weights * rhs`` form).
159143
inner = nest.nodes[0] if len(nest.nodes) == 1 else List(body=nest.nodes)
160-
interp_exprs = [e for e in exprs if e.expr.kind == 'interpolate']
161-
inject_exprs = [e for e in exprs if e.expr.kind == 'inject']
144+
interp_exprs = [e for e in exprs if isinstance(e.expr, InterpolationMixin)]
145+
inject_exprs = [e for e in exprs if isinstance(e.expr, InjectionMixin)]
162146

163147
body = []
164148
for expr in interp_exprs:

0 commit comments

Comments
 (0)