Skip to content

Commit 7e07f72

Browse files
committed
compiler: Lift lower_sparse_ops out of mpiize into _lower_iet
Sparse-op lowering belongs at the IET-pass level alongside other target-independent structural lowerings, not buried inside the MPI pass. Lifting it also fixes a missed host-device transfer on GPU: the position/coefficient temps it materialises read the coordinate SubFunctions, but those reads were invisible to place_transfers because Graph.data_movs was snapshotted before mpiize ran. Graph.data_movs is now a property recomputed from the live efunc set so downstream passes see ExprStmts introduced after Graph construction.
1 parent 0f1f9eb commit 7e07f72

3 files changed

Lines changed: 29 additions & 25 deletions

File tree

devito/operator/operator.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from devito.parameters import configuration
3434
from devito.passes import (
3535
Graph, error_mapper, generate_implicit, generate_macros, is_on_device, lower_dtypes,
36-
lower_index_derivatives, minimize_symbols, optimize_pows, unevaluate
36+
lower_index_derivatives, lower_sparse_ops, minimize_symbols, optimize_pows, unevaluate
3737
)
3838
from devito.symbolics import estimate_cost, subs_op_args
3939
from devito.tools import (
@@ -490,13 +490,12 @@ def _lower_iet(cls, uiet, **kwargs):
490490
parameters = derive_parameters(uiet, True)
491491
iet = EntryFunction(name, uiet, 'int', parameters, ())
492492

493-
# Lower IET to a target-specific IET. Sparse-op lowering
494-
# (``lower_sparse_ops``) is now run from inside ``mpiize``,
495-
# between the halo-optimisation phase and the halo-exchange
496-
# injection, so the reduction heuristic gets a chance to drop
497-
# redundant halo exchanges around the sparse-op nest before
498-
# the nest is sealed into an ElementalFunction.
493+
# Materialise the sparse-op iteration nests into ElementalFunctions
494+
# before target specialisation, so the position/coefficient temps
495+
# the IET pass emits are visible to downstream passes (e.g. the
496+
# data-transfer placement on device).
499497
graph = Graph(iet, **kwargs)
498+
lower_sparse_ops(graph, **kwargs)
500499
graph = cls._specialize_iet(graph, **kwargs)
501500

502501
# Instrument the IET for C-level profiling

devito/passes/iet/engine.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,29 @@ def __init__(self, iet, options=None, sregistry=None, **kwargs):
8181
writes = FindSymbols('writes').visit(iet)
8282
self.writes_input = frozenset(f for f in writes if f.is_Input)
8383

84-
# All symbols requiring host-device data transfers when running
85-
# on device
86-
self.data_movs = rmovs, wmovs = set(), set()
87-
gpu_fit = (options or {}).get('gpu-fit', ())
88-
for i in FindNodes(ExprStmt).visit(iet):
89-
wmovs.update({w for w in i.writes
90-
if needs_transfer(w, gpu_fit)})
91-
for i in FindNodes(ExprStmt).visit(iet):
92-
rmovs.update({f for f in i.functions
93-
if needs_transfer(f, gpu_fit) and f not in wmovs})
84+
self._gpu_fit = (options or {}).get('gpu-fit', ())
85+
86+
@property
87+
def data_movs(self):
88+
"""
89+
``(reads, writes)`` of symbols requiring host-device data transfers.
90+
91+
Recomputed on access from the current state of the Graph (root
92+
plus every reachable efunc) so passes that introduce ExprStmts
93+
after construction — e.g. ``lower_sparse_ops`` materialising the
94+
sparse-op position temps inside an efunc — are reflected.
95+
"""
96+
rmovs, wmovs = set(), set()
97+
for efunc in self.efuncs.values():
98+
for i in FindNodes(ExprStmt).visit(efunc):
99+
wmovs.update(w for w in i.writes
100+
if needs_transfer(w, self._gpu_fit))
101+
for efunc in self.efuncs.values():
102+
for i in FindNodes(ExprStmt).visit(efunc):
103+
rmovs.update(f for f in i.functions
104+
if needs_transfer(f, self._gpu_fit)
105+
and f not in wmovs)
106+
return rmovs, wmovs
94107

95108
@property
96109
def root(self):

devito/passes/iet/mpi.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from devito.mpi.reduction_scheme import DistReduce
1212
from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder
1313
from devito.passes.iet.engine import iet_pass
14-
from devito.passes.iet.sparse import lower_sparse_ops
1514
from devito.symbolics import VectorAccess, search
1615
from devito.tools import generator
1716
from devito.types import TensorMove
@@ -390,11 +389,6 @@ def mpiize(graph, **kwargs):
390389
Perform three IET passes:
391390
392391
* Optimization of halo exchanges
393-
* Lower sparse operations (Interpolation/Injection) into Calls
394-
to ElementalFunctions. This runs after halo optimisation so
395-
the reduction heuristic gets a chance to drop redundant halo
396-
exchanges around the sparse-op nest before it is sealed into
397-
an efunc.
398392
* Injection of code for halo exchanges
399393
* Injection of code for reductions
400394
"""
@@ -403,8 +397,6 @@ def mpiize(graph, **kwargs):
403397
if options['opt-comms']:
404398
optimize_halospots(graph, **kwargs)
405399

406-
lower_sparse_ops(graph, **kwargs)
407-
408400
mpimode = options['mpi']
409401
if mpimode:
410402
make_halo_exchanges(graph, mpimode=mpimode, **kwargs)

0 commit comments

Comments
 (0)