Skip to content

Commit ea8ad37

Browse files
committed
compiler: Trim sparse-op IET pass and lift its import
1 parent 5ad6313 commit ea8ad37

2 files changed

Lines changed: 41 additions & 97 deletions

File tree

devito/passes/iet/mpi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from devito.mpi.reduction_scheme import DistReduce
1212
from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder
1313
from devito.passes.iet.engine import iet_pass
14+
from devito.passes.iet.sparse import lower_sparse_ops
1415
from devito.symbolics import VectorAccess, search
1516
from devito.tools import generator
1617
from devito.types import TensorMove
@@ -402,7 +403,6 @@ def mpiize(graph, **kwargs):
402403
if options['opt-comms']:
403404
optimize_halospots(graph, **kwargs)
404405

405-
from devito.passes.iet.sparse import lower_sparse_ops
406406
lower_sparse_ops(graph, **kwargs)
407407

408408
mpimode = options['mpi']

devito/passes/iet/sparse.py

Lines changed: 40 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,13 @@
3737
__all__ = ['lower_sparse_ops']
3838

3939

40-
def lower_sparse_ops(graph, **kwargs):
40+
@iet_pass
41+
def lower_sparse_ops(iet, sregistry=None, **kwargs):
4142
"""
4243
Replace each sparse-op iteration nest in the IET with a Call to an
4344
ElementalFunction that materialises the position temporaries and
4445
the inner accumulator/increment pattern.
4546
"""
46-
_lower_sparse_ops(graph, **kwargs)
47-
48-
49-
@iet_pass
50-
def _lower_sparse_ops(iet, sregistry=None, **kwargs):
5147
if not isinstance(iet, EntryFunction):
5248
return iet, {}
5349

@@ -72,7 +68,7 @@ def _lower_sparse_ops(iet, sregistry=None, **kwargs):
7268
groups.setdefault(nest, []).append(expr)
7369

7470
# If a sparse-op nest sits inside a HaloSpot whose halo scheme is
75-
# void (e.g. the reduction-only halo got dropped by
71+
# void (the reduction-only halo got dropped by
7672
# ``_drop_reduction_halospots``), replace the HaloSpot rather than
7773
# just the nest so we don't leave behind an empty HaloSpot — the
7874
# MPI overlap machinery would otherwise try to wrap our Call with
@@ -81,38 +77,20 @@ def _lower_sparse_ops(iet, sregistry=None, **kwargs):
8177

8278
mapper = {}
8379
efuncs = []
84-
8580
for nest, exprs in groups.items():
8681
new_nest = _materialise_nest(nest, exprs)
8782

88-
name = sregistry.make_name(prefix=_efunc_prefix(exprs[0].expr))
89-
efunc = make_callable(name, new_nest)
83+
lse = exprs[0].expr
84+
prefix = f'{lse.kind}_{lse.interpolator.sfunction.name}'
85+
efunc = make_callable(sregistry.make_name(prefix=prefix), new_nest)
9086
efuncs.append(efunc)
9187

92-
call = Call(efunc.name, list(efunc.parameters))
93-
target = parents[nest] or nest
94-
mapper[target] = call
88+
mapper[parents[nest] or nest] = Call(efunc.name, list(efunc.parameters))
9589

9690
if not mapper:
9791
return iet, {}
9892

99-
iet = Transformer(mapper).visit(iet)
100-
101-
return iet, {'efuncs': efuncs}
102-
103-
104-
def _enclosing_void_halospot(iet, nest):
105-
"""
106-
Return the HaloSpot directly wrapping ``nest`` if it carries an
107-
empty (void) HaloScheme, otherwise None. Such HaloSpots are leftover
108-
after ``_drop_reduction_halospots`` cleared all entries.
109-
"""
110-
for hs in FindNodes(HaloSpot).visit(iet):
111-
if not hs.is_void:
112-
continue
113-
if nest in FindNodes(Iteration).visit(hs):
114-
return hs
115-
return None
93+
return Transformer(mapper).visit(iet), {'efuncs': efuncs}
11694

11795

11896
def _is_head(eq):
@@ -127,8 +105,6 @@ def _is_head(eq):
127105
f = eq.lhs.function
128106
if eq.kind == 'interpolate':
129107
return f is sf
130-
# 'inject': head writes into a DiscreteFunction (the grid field),
131-
# not into a scalar temporary
132108
return f.is_DiscreteFunction and f is not sf
133109

134110

@@ -139,13 +115,23 @@ def _find_outer_iteration(iet, expr):
139115
"""
140116
sparse_dim = expr.expr.interpolator.sfunction._sparse_dim
141117
for it in FindNodes(Iteration).visit(iet):
142-
if it.dim.root is not sparse_dim:
143-
continue
144-
if expr in FindNodes(Expression).visit(it):
118+
if it.dim.root is sparse_dim and expr in FindNodes(Expression).visit(it):
145119
return it
146120
return None
147121

148122

123+
def _enclosing_void_halospot(iet, nest):
124+
"""
125+
Return the HaloSpot directly wrapping ``nest`` if it carries an
126+
empty (void) HaloScheme, otherwise None. Such HaloSpots are leftover
127+
after ``_drop_reduction_halospots`` cleared all entries.
128+
"""
129+
for hs in FindNodes(HaloSpot).visit(iet):
130+
if hs.is_void and nest in FindNodes(Iteration).visit(hs):
131+
return hs
132+
return None
133+
134+
149135
def _materialise_nest(nest, exprs):
150136
"""
151137
Rewrite the sparse Dimension's Iteration body to compute the
@@ -154,45 +140,38 @@ def _materialise_nest(nest, exprs):
154140
pattern. Multiple sparse-op Expressions sharing the same outer
155141
Iteration are materialised in one pass and reuse the same temps.
156142
"""
157-
interp = exprs[0].expr.interpolator
158-
sample_lse = exprs[0].expr
143+
sample = exprs[0].expr
144+
interp = sample.interpolator
159145

160146
# Position + coefficient temporaries as IET Expressions. These are
161147
# the same for every Expression in the group, so we emit them once.
162-
temps = interp._sparse_temps(
163-
sample_lse.kind, _user_expr(sample_lse),
164-
field=_user_field(sample_lse),
165-
implicit_dims=sample_lse.implicit_dims,
166-
)
148+
field = sample.lhs.function if sample.kind == 'inject' else None
149+
temps = interp._sparse_temps(sample.kind, sample.rhs, field=field,
150+
implicit_dims=sample.implicit_dims)
167151
temp_exprs = tuple(Expression(DummyEq(e.lhs, e.rhs))
168152
for e in lower_exprs(temps))
169153

170-
# For each interpolation Expression in the group, build its
171-
# accumulator-wrapped radius nest. Injection Exprs are left where
172-
# they are in the radius nest (their Inc is already the right
173-
# form); injection Exprs share a single copy of the radius nest.
174-
inner = _drop_outer(nest)
154+
# The radius nest is what runs once per sparse point. For each
155+
# interpolation Expression in the group, build its
156+
# accumulator-wrapped copy of the radius nest. Injection Exprs
157+
# share a single copy of the radius nest (their ``Inc`` already
158+
# carries the right ``weights * rhs`` form).
159+
inner = nest.nodes[0] if len(nest.nodes) == 1 else List(body=nest.nodes)
175160
interp_exprs = [e for e in exprs if e.expr.kind == 'interpolate']
176161
inject_exprs = [e for e in exprs if e.expr.kind == 'inject']
177162

178163
body = []
179164
for expr in interp_exprs:
180-
# Build the per-interpolation accumulator: substitute siblings
181-
# out and replace ``expr`` with the increment in a single
182-
# Transformer pass so the radius sub-tree contains only the
183-
# head's increment.
184-
body.append(_interp_inner_block(inner, expr, expr.expr.interpolator,
185-
siblings=[e for e in exprs if e is not expr]))
165+
siblings = [e for e in exprs if e is not expr]
166+
body.append(_interp_inner_block(inner, expr, expr.expr.interpolator, siblings))
186167
if inject_exprs:
187-
# Injections share one radius nest with no interpolation heads.
188-
others = {e: None for e in interp_exprs}
189-
local_inner = Transformer(others, nested=True).visit(inner) if others else inner
190-
body.append(local_inner)
168+
drop = {e: None for e in interp_exprs}
169+
body.append(Transformer(drop, nested=True).visit(inner) if drop else inner)
191170

192171
return nest._rebuild(nodes=temp_exprs + tuple(body))
193172

194173

195-
def _interp_inner_block(inner, expr, interp, siblings=()):
174+
def _interp_inner_block(inner, expr, interp, siblings):
196175
"""
197176
Build the accumulator/radius/write-back triple for an interpolation:
198177
@@ -232,8 +211,7 @@ def _interp_inner_block(inner, expr, interp, siblings=()):
232211
for rd in weights.free_symbols
233212
if getattr(rd, 'is_Conditional', False) and rd.name in rdims_concrete
234213
})
235-
weights_expr = lower_exprs(_make_eq(acc, weights)).rhs
236-
weighted_rhs = weights_expr * rhs
214+
weighted_rhs = lower_exprs(Eq(acc, weights)).rhs * rhs
237215

238216
init = Expression(DummyEq(acc, 0))
239217
inc = Increment(DummyEq(acc, weighted_rhs))
@@ -249,15 +227,14 @@ def _interp_inner_block(inner, expr, interp, siblings=()):
249227

250228
radius_root = _find_radius_root(inner, interp.sfunction)
251229
if radius_root is None or radius_root is inner:
252-
# No intermediate Iteration: wrap the whole ``inner`` directly.
253230
return List(body=(init,
254231
Transformer(mapper, nested=True).visit(inner),
255232
write_back))
256233

257234
# Wrap the accumulator pattern around just the radius sub-tree,
258235
# leaving the enclosing non-radius Iterations in place.
259-
new_radius = Transformer(mapper, nested=True).visit(radius_root)
260-
wrapped = List(body=(init, new_radius, write_back))
236+
wrapped = List(body=(init, Transformer(mapper, nested=True).visit(radius_root),
237+
write_back))
261238
return Transformer({radius_root: wrapped}, nested=True).visit(inner)
262239

263240

@@ -273,36 +250,3 @@ def _find_radius_root(inner, sfunction):
273250
if it.dim.name.startswith(prefix):
274251
return it
275252
return None
276-
277-
278-
def _drop_outer(nest):
279-
"""
280-
Return the sub-IET inside ``nest`` (the Iteration over the sparse
281-
Dim) — i.e. the radius nest. ``nest.nodes`` is what runs once per
282-
sparse point.
283-
"""
284-
if len(nest.nodes) == 1:
285-
return nest.nodes[0]
286-
return List(body=nest.nodes)
287-
288-
289-
def _make_eq(lhs, rhs):
290-
"""Helper to wrap a (lhs, rhs) pair as a symbolic Eq for ``lower_exprs``."""
291-
return Eq(lhs, rhs)
292-
293-
294-
def _efunc_prefix(lse):
295-
"""Pick an ElementalFunction name prefix based on the sparse-op kind."""
296-
return f'{lse.kind}_{lse.interpolator.sfunction.name}'
297-
298-
299-
def _user_expr(lse):
300-
"""The user-side expression to feed ``_sparse_temps`` (rhs of the LSE)."""
301-
return lse.rhs
302-
303-
304-
def _user_field(lse):
305-
"""For injection, the destination Function appearing in lhs."""
306-
if lse.kind == 'inject':
307-
return lse.lhs.function
308-
return None

0 commit comments

Comments
 (0)