3737__all__ = ['lower_sparse_ops' ]
3838
3939
40- def lower_sparse_ops (graph , ** kwargs ):
40+ @iet_pass
41+ def lower_sparse_ops (iet , sregistry = None , ** kwargs ):
4142 """
4243 Replace each sparse-op iteration nest in the IET with a Call to an
4344 ElementalFunction that materialises the position temporaries and
4445 the inner accumulator/increment pattern.
4546 """
46- _lower_sparse_ops (graph , ** kwargs )
47-
48-
49- @iet_pass
50- def _lower_sparse_ops (iet , sregistry = None , ** kwargs ):
5147 if not isinstance (iet , EntryFunction ):
5248 return iet , {}
5349
@@ -72,7 +68,7 @@ def _lower_sparse_ops(iet, sregistry=None, **kwargs):
7268 groups .setdefault (nest , []).append (expr )
7369
7470 # If a sparse-op nest sits inside a HaloSpot whose halo scheme is
75- # void (e.g. the reduction-only halo got dropped by
71+ # void (the reduction-only halo got dropped by
7672 # ``_drop_reduction_halospots``), replace the HaloSpot rather than
7773 # just the nest so we don't leave behind an empty HaloSpot — the
7874 # MPI overlap machinery would otherwise try to wrap our Call with
@@ -81,38 +77,20 @@ def _lower_sparse_ops(iet, sregistry=None, **kwargs):
8177
8278 mapper = {}
8379 efuncs = []
84-
8580 for nest , exprs in groups .items ():
8681 new_nest = _materialise_nest (nest , exprs )
8782
88- name = sregistry .make_name (prefix = _efunc_prefix (exprs [0 ].expr ))
89- efunc = make_callable (name , new_nest )
83+ lse = exprs [0 ].expr
84+ prefix = f'{ lse .kind } _{ lse .interpolator .sfunction .name } '
85+ efunc = make_callable (sregistry .make_name (prefix = prefix ), new_nest )
9086 efuncs .append (efunc )
9187
92- call = Call (efunc .name , list (efunc .parameters ))
93- target = parents [nest ] or nest
94- mapper [target ] = call
88+ mapper [parents [nest ] or nest ] = Call (efunc .name , list (efunc .parameters ))
9589
9690 if not mapper :
9791 return iet , {}
9892
99- iet = Transformer (mapper ).visit (iet )
100-
101- return iet , {'efuncs' : efuncs }
102-
103-
104- def _enclosing_void_halospot (iet , nest ):
105- """
106- Return the HaloSpot directly wrapping ``nest`` if it carries an
107- empty (void) HaloScheme, otherwise None. Such HaloSpots are leftover
108- after ``_drop_reduction_halospots`` cleared all entries.
109- """
110- for hs in FindNodes (HaloSpot ).visit (iet ):
111- if not hs .is_void :
112- continue
113- if nest in FindNodes (Iteration ).visit (hs ):
114- return hs
115- return None
93+ return Transformer (mapper ).visit (iet ), {'efuncs' : efuncs }
11694
11795
11896def _is_head (eq ):
@@ -127,8 +105,6 @@ def _is_head(eq):
127105 f = eq .lhs .function
128106 if eq .kind == 'interpolate' :
129107 return f is sf
130- # 'inject': head writes into a DiscreteFunction (the grid field),
131- # not into a scalar temporary
132108 return f .is_DiscreteFunction and f is not sf
133109
134110
@@ -139,13 +115,23 @@ def _find_outer_iteration(iet, expr):
139115 """
140116 sparse_dim = expr .expr .interpolator .sfunction ._sparse_dim
141117 for it in FindNodes (Iteration ).visit (iet ):
142- if it .dim .root is not sparse_dim :
143- continue
144- if expr in FindNodes (Expression ).visit (it ):
118+ if it .dim .root is sparse_dim and expr in FindNodes (Expression ).visit (it ):
145119 return it
146120 return None
147121
148122
123+ def _enclosing_void_halospot (iet , nest ):
124+ """
125+ Return the HaloSpot directly wrapping ``nest`` if it carries an
126+ empty (void) HaloScheme, otherwise None. Such HaloSpots are leftover
127+ after ``_drop_reduction_halospots`` cleared all entries.
128+ """
129+ for hs in FindNodes (HaloSpot ).visit (iet ):
130+ if hs .is_void and nest in FindNodes (Iteration ).visit (hs ):
131+ return hs
132+ return None
133+
134+
149135def _materialise_nest (nest , exprs ):
150136 """
151137 Rewrite the sparse Dimension's Iteration body to compute the
@@ -154,45 +140,38 @@ def _materialise_nest(nest, exprs):
154140 pattern. Multiple sparse-op Expressions sharing the same outer
155141 Iteration are materialised in one pass and reuse the same temps.
156142 """
157- interp = exprs [0 ].expr . interpolator
158- sample_lse = exprs [ 0 ]. expr
143+ sample = exprs [0 ].expr
144+ interp = sample . interpolator
159145
160146 # Position + coefficient temporaries as IET Expressions. These are
161147 # the same for every Expression in the group, so we emit them once.
162- temps = interp ._sparse_temps (
163- sample_lse .kind , _user_expr (sample_lse ),
164- field = _user_field (sample_lse ),
165- implicit_dims = sample_lse .implicit_dims ,
166- )
148+ field = sample .lhs .function if sample .kind == 'inject' else None
149+ temps = interp ._sparse_temps (sample .kind , sample .rhs , field = field ,
150+ implicit_dims = sample .implicit_dims )
167151 temp_exprs = tuple (Expression (DummyEq (e .lhs , e .rhs ))
168152 for e in lower_exprs (temps ))
169153
170- # For each interpolation Expression in the group, build its
171- # accumulator-wrapped radius nest. Injection Exprs are left where
172- # they are in the radius nest (their Inc is already the right
173- # form); injection Exprs share a single copy of the radius nest.
174- inner = _drop_outer (nest )
154+ # The radius nest is what runs once per sparse point. For each
155+ # interpolation Expression in the group, build its
156+ # accumulator-wrapped copy of the radius nest. Injection Exprs
157+ # share a single copy of the radius nest (their ``Inc`` already
158+ # carries the right ``weights * rhs`` form).
159+ inner = nest .nodes [0 ] if len (nest .nodes ) == 1 else List (body = nest .nodes )
175160 interp_exprs = [e for e in exprs if e .expr .kind == 'interpolate' ]
176161 inject_exprs = [e for e in exprs if e .expr .kind == 'inject' ]
177162
178163 body = []
179164 for expr in interp_exprs :
180- # Build the per-interpolation accumulator: substitute siblings
181- # out and replace ``expr`` with the increment in a single
182- # Transformer pass so the radius sub-tree contains only the
183- # head's increment.
184- body .append (_interp_inner_block (inner , expr , expr .expr .interpolator ,
185- siblings = [e for e in exprs if e is not expr ]))
165+ siblings = [e for e in exprs if e is not expr ]
166+ body .append (_interp_inner_block (inner , expr , expr .expr .interpolator , siblings ))
186167 if inject_exprs :
187- # Injections share one radius nest with no interpolation heads.
188- others = {e : None for e in interp_exprs }
189- local_inner = Transformer (others , nested = True ).visit (inner ) if others else inner
190- body .append (local_inner )
168+ drop = {e : None for e in interp_exprs }
169+ body .append (Transformer (drop , nested = True ).visit (inner ) if drop else inner )
191170
192171 return nest ._rebuild (nodes = temp_exprs + tuple (body ))
193172
194173
195- def _interp_inner_block (inner , expr , interp , siblings = () ):
174+ def _interp_inner_block (inner , expr , interp , siblings ):
196175 """
197176 Build the accumulator/radius/write-back triple for an interpolation:
198177
@@ -232,8 +211,7 @@ def _interp_inner_block(inner, expr, interp, siblings=()):
232211 for rd in weights .free_symbols
233212 if getattr (rd , 'is_Conditional' , False ) and rd .name in rdims_concrete
234213 })
235- weights_expr = lower_exprs (_make_eq (acc , weights )).rhs
236- weighted_rhs = weights_expr * rhs
214+ weighted_rhs = lower_exprs (Eq (acc , weights )).rhs * rhs
237215
238216 init = Expression (DummyEq (acc , 0 ))
239217 inc = Increment (DummyEq (acc , weighted_rhs ))
@@ -249,15 +227,14 @@ def _interp_inner_block(inner, expr, interp, siblings=()):
249227
250228 radius_root = _find_radius_root (inner , interp .sfunction )
251229 if radius_root is None or radius_root is inner :
252- # No intermediate Iteration: wrap the whole ``inner`` directly.
253230 return List (body = (init ,
254231 Transformer (mapper , nested = True ).visit (inner ),
255232 write_back ))
256233
257234 # Wrap the accumulator pattern around just the radius sub-tree,
258235 # leaving the enclosing non-radius Iterations in place.
259- new_radius = Transformer (mapper , nested = True ).visit (radius_root )
260- wrapped = List ( body = ( init , new_radius , write_back ))
236+ wrapped = List ( body = ( init , Transformer (mapper , nested = True ).visit (radius_root ),
237+ write_back ))
261238 return Transformer ({radius_root : wrapped }, nested = True ).visit (inner )
262239
263240
@@ -273,36 +250,3 @@ def _find_radius_root(inner, sfunction):
273250 if it .dim .name .startswith (prefix ):
274251 return it
275252 return None
276-
277-
278- def _drop_outer (nest ):
279- """
280- Return the sub-IET inside ``nest`` (the Iteration over the sparse
281- Dim) — i.e. the radius nest. ``nest.nodes`` is what runs once per
282- sparse point.
283- """
284- if len (nest .nodes ) == 1 :
285- return nest .nodes [0 ]
286- return List (body = nest .nodes )
287-
288-
289- def _make_eq (lhs , rhs ):
290- """Helper to wrap a (lhs, rhs) pair as a symbolic Eq for ``lower_exprs``."""
291- return Eq (lhs , rhs )
292-
293-
294- def _efunc_prefix (lse ):
295- """Pick an ElementalFunction name prefix based on the sparse-op kind."""
296- return f'{ lse .kind } _{ lse .interpolator .sfunction .name } '
297-
298-
299- def _user_expr (lse ):
300- """The user-side expression to feed ``_sparse_temps`` (rhs of the LSE)."""
301- return lse .rhs
302-
303-
304- def _user_field (lse ):
305- """For injection, the destination Function appearing in lhs."""
306- if lse .kind == 'inject' :
307- return lse .lhs .function
308- return None
0 commit comments