Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
4425eae
dsl: Introduce SparseEq base for Interpolation/Injection
mloubout May 13, 2026
3fe3fbb
compiler: Lower SparseEq via rcompile efunc (WIP)
mloubout May 13, 2026
cbd4c73
compiler: Run SparseEq efunc inside parent's time loop
mloubout May 13, 2026
fbb6960
tests: Update structure asserts for sparse-op efunc lowering
mloubout May 13, 2026
cbbbbf4
dsl: Collapse Interpolation/Injection into SparseEq
mloubout May 13, 2026
608b2de
compiler: Drop rcompile for SparseEq efunc lowering
mloubout May 14, 2026
be13e30
compiler: Strip time loop and headers from SparseEq efunc
mloubout May 14, 2026
b69abaa
compiler: Emit SparseEq efunc as plain Callable for GPU passes
mloubout May 14, 2026
5f3d4f9
compiler: Move SparseEq efunc lowering into core _lower_iet
mloubout May 14, 2026
6f6e1bc
compiler: Run SparseEq efunc lowering after halospot optimization
mloubout May 19, 2026
0e1da06
dsl: Tighten SparseEq lowering and SubDim handling
mloubout May 19, 2026
f0429c1
tests: Update structure asserts for sparse-op efunc lowering
mloubout May 19, 2026
596425c
compiler: Replace Eq.lower/clusterize methods with singledispatch
mloubout May 19, 2026
15aff36
compiler: Trim sparse-op IET pass and lift its import
mloubout May 19, 2026
09d1a7c
tests: Tighten sparse-op Call assertions to exact names
mloubout May 19, 2026
0f1f9eb
dsl: Replace SparseEq kind string with Interpolation/Injection subcla…
mloubout May 26, 2026
7e07f72
compiler: Lift lower_sparse_ops out of mpiize into _lower_iet
mloubout May 29, 2026
4519249
compiler: Shed reduction-only halos when lowering injections
mloubout May 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from devito.ir.equations import ClusterizedEq
from devito.ir.equations import clusterize_eq
from devito.ir.support import (
PARALLEL, PARALLEL_IF_PVT, BaseGuardBoundNext, DataSpace, Forward, Guards, Interval,
IntervalGroup, IterationSpace, PrefetchUpdate, Properties, Scope, WaitLock, WithLock,
Expand Down Expand Up @@ -50,7 +50,7 @@ class Cluster:

def __init__(self, exprs, ispace=null_ispace, guards=None, properties=None,
syncs=None, halo_scheme=None):
self._exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs))
self._exprs = tuple(clusterize_eq(e, ispace=ispace) for e in as_tuple(exprs))
self._ispace = ispace
self._guards = Guards(guards or {})
self._syncs = normalize_syncs(syncs or {})
Expand Down
175 changes: 161 additions & 14 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import cached_property
from functools import cached_property, singledispatch

import numpy as np
import sympy
Expand All @@ -11,17 +11,31 @@
)
from devito.symbolics import IntDiv, limits_mapper, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min
from devito.types import (
Eq, Inc, IncrInterpolation, Injection, InjectionMixin, Interpolation,
InterpolationMixin, ReduceMax, ReduceMin, ReduceMinMax, SparseEq, SparseOpMixin,
relational_min
)

__all__ = [
'ClusterizedEq',
'ClusterizedIncrInterpolation',
'ClusterizedInjection',
'ClusterizedInterpolation',
'ClusterizedSparseEq',
'DummyEq',
'LoweredEq',
'LoweredIncrInterpolation',
'LoweredInjection',
'LoweredInterpolation',
'LoweredSparseEq',
'OpInc',
'OpMax',
'OpMin',
'OpMinMax',
'clusterize_eq',
'identity_mapper',
'lower_eq',
]


Expand All @@ -30,6 +44,8 @@ class IREq(sympy.Eq, Pickable):
__rargs__ = ('lhs', 'rhs')
__rkwargs__ = ('ispace', 'conditionals', 'implicit_dims', 'operation')

is_SparseOperation = False

def _hashable_content(self):
return (*super()._hashable_content(),
*tuple(getattr(self, i) for i in self.__rkwargs__))
Expand Down Expand Up @@ -115,16 +131,15 @@ class Operation(Tag):

@classmethod
def detect(cls, expr):
reduction_mapper = {
Inc: OpInc,
ReduceMax: OpMax,
ReduceMin: OpMin,
ReduceMinMax: OpMinMax
}
try:
return reduction_mapper[type(expr)]
except KeyError:
pass
reduction_mapper = (
(ReduceMinMax, OpMinMax),
(ReduceMin, OpMin),
(ReduceMax, OpMax),
(Inc, OpInc),
)
for kls, op in reduction_mapper:
if isinstance(expr, kls):
return op

# NOTE: in the future we might want to track down other kinds
# of operations here (e.g., memcpy). However, we don't care for
Expand Down Expand Up @@ -204,8 +219,9 @@ def __new__(cls, *args, **kwargs):
accesses = detect_accesses(expr)
dimensions = Stencil.union(*accesses.values())

# Separate out the SubIterators from the main iteration Dimensions, that
# is those which define an actual iteration space
# Separate out the SubIterators from the main iteration
# Dimensions, that is those which define an actual
# iteration space
iterators = {}
for d in dimensions:
if d.is_SubIterator:
Expand Down Expand Up @@ -271,6 +287,43 @@ def func(self, *args):
return self._rebuild(*args, evaluate=False)


class LoweredSparseEq(SparseOpMixin, LoweredEq):

"""
The IR counterpart of ``SparseEq``: a regular ``LoweredEq`` that
also carries the ``interpolator`` metadata used by the IET pass
``lower_sparse_ops`` to wrap the resulting ``p_*, rp_*`` iteration
nest in an ElementalFunction. Subclassed by
``LoweredInterpolation`` / ``LoweredIncrInterpolation`` /
``LoweredInjection`` for the per-operation polymorphic behaviour.
"""

__rkwargs__ = LoweredEq.__rkwargs__ + ('interpolator',)


class LoweredInterpolation(InterpolationMixin, LoweredSparseEq):
"""IR counterpart of ``Interpolation``."""
pass


class LoweredIncrInterpolation(InterpolationMixin, LoweredSparseEq):
"""IR counterpart of ``IncrInterpolation``."""
pass


class LoweredInjection(InjectionMixin, LoweredSparseEq):
"""IR counterpart of ``Injection``."""
pass


# Map user-level sparse-op classes to their IR-level counterparts.
_lowered_sparse_cls = {
Interpolation: LoweredInterpolation,
IncrInterpolation: LoweredIncrInterpolation,
Injection: LoweredInjection,
}


class ClusterizedEq(IREq):

"""
Expand Down Expand Up @@ -326,6 +379,41 @@ def __new__(cls, *args, **kwargs):
func = IREq._rebuild


class ClusterizedSparseEq(SparseOpMixin, ClusterizedEq):

"""
Frozen counterpart of ``LoweredSparseEq``: the same regular
``ClusterizedEq`` augmented with ``interpolator`` so the IET pass
``lower_sparse_ops`` can identify and rewrite the sparse op's
iteration nest. Subclassed by ``ClusterizedInterpolation`` /
``ClusterizedIncrInterpolation`` / ``ClusterizedInjection``.
"""

__rkwargs__ = ClusterizedEq.__rkwargs__ + ('interpolator',)


class ClusterizedInterpolation(InterpolationMixin, ClusterizedSparseEq):
"""Frozen counterpart of ``LoweredInterpolation``."""
pass


class ClusterizedIncrInterpolation(InterpolationMixin, ClusterizedSparseEq):
"""Frozen counterpart of ``LoweredIncrInterpolation``."""
pass


class ClusterizedInjection(InjectionMixin, ClusterizedSparseEq):
"""Frozen counterpart of ``LoweredInjection``."""
pass


_clusterized_sparse_cls = {
LoweredInterpolation: ClusterizedInterpolation,
LoweredIncrInterpolation: ClusterizedIncrInterpolation,
LoweredInjection: ClusterizedInjection,
}


class DummyEq(ClusterizedEq):

"""
Expand All @@ -345,3 +433,62 @@ def __new__(cls, *args, **kwargs):
else:
raise ValueError(f"Cannot construct DummyEq from args={str(args)}")
return ClusterizedEq.__new__(cls, obj, ispace=obj.ispace)


@singledispatch
def lower_eq(eq):
"""
Promote a user-level ``Eq`` to its ``LoweredEq`` counterpart, ready
for the cluster pipeline. The dispatch matches the dynamic type of
``eq``; ``SparseEq`` and friends get their own branch.
"""
return LoweredEq(eq)


@lower_eq.register(SparseEq)
def _(eq):
# Augment ``implicit_dims`` with the SparseFunction's own iteration
# Dimensions (e.g. ``p_sf`` and any extra SparseFunction dims) so
# the cluster scheduler sees them. Grid Dimensions reached through
# the rhs Function are deliberately *not* added: SubDomain-derived
# SubDimensions would otherwise spuriously appear in the
# IterationSpace, and grid Dimensions are already discovered via
# the radius ConditionalDimensions in the rhs.
interp = eq.interpolator
sf_dims = tuple(interp.sfunction.dimensions)
user = tuple(eq.implicit_dims or ())
if interp.sfunction._sparse_position == -1:
augmented = sf_dims + user
else:
augmented = user + sf_dims

if augmented != tuple(eq.implicit_dims or ()):
eq = eq.func(eq.lhs, eq.rhs, interpolator=interp,
implicit_dims=augmented)

lowered_cls = _lowered_sparse_cls[type(eq)]
obj = lowered_cls(eq)
obj._interpolator = interp
return obj


@singledispatch
def clusterize_eq(eq, **kwargs):
"""
Freeze a ``LoweredEq`` into its ``ClusterizedEq`` counterpart,
suitable for use in a ``Cluster``. Subclasses with extra payload
(e.g. ``LoweredSparseEq``) dispatch to their frozen counterpart.
"""
return ClusterizedEq(eq, **kwargs)


@clusterize_eq.register(LoweredSparseEq)
def _(eq, **kwargs):
return _clusterized_sparse_cls[type(eq)](eq, **kwargs)


@clusterize_eq.register(ClusterizedSparseEq)
def _(eq, **kwargs):
# ``eq`` is already clusterized; rebuild via its own class to preserve
# the per-operation polymorphic behaviour.
return type(eq)(eq, **kwargs)
1 change: 0 additions & 1 deletion devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ class HaloScheme:
"""

def __init__(self, exprs, ispace):
# Derive the halo exchanges
self._mapper = frozendict(classify(exprs, ispace))

# Track the IterationSpace offsets induced by SubDomains/SubDimensions,
Expand Down
Loading
Loading