@@ -804,11 +804,10 @@ class SparseFirst(SparseFunction):
804804 rec = s .interpolate (expr = s + fs , implicit_dims = grid .stepping_dim )
805805 op = Operator (eqs + rec )
806806
807- # The sparse interp now lowers to a Call into an
808- # ElementalFunction, so the parent carries that Call in
809- # addition to any halo exchanges.
810- calls = FindNodes (Call ).visit (op )
811- assert any (c .name .startswith ('interpolate_' ) for c in calls )
807+ # Generated code: one halo exchange for ``fs`` and one Call
808+ # to the ``interpolate_s0`` ElementalFunction.
809+ call_names = [c .name for c in FindNodes (Call ).visit (op )]
810+ assert call_names == ['haloupdate0' , 'interpolate_s0' ]
812811
813812 op (time_M = 10 )
814813 expected = 10 * 11 / 2 # n (n+1)/2
@@ -838,10 +837,10 @@ class CoordSlowSparseFunction(SparseFunction):
838837
839838 op = Operator ([Eq (u , 1 )] + rec_eq )
840839
841- # Expected Calls : one halo exchange + one interpolate efunc.
842- call_names = sorted ( c . name for c in FindNodes ( Call ). visit ( op ))
843- assert any ( n . startswith ( 'haloupdate' ) for n in call_names )
844- assert any ( n . startswith ( 'interpolate_' ) for n in call_names )
840+ # Generated code : one halo exchange for ``u`` and one Call to
841+ # the ``interpolate_s0`` ElementalFunction.
842+ call_names = [ c . name for c in FindNodes ( Call ). visit ( op )]
843+ assert call_names == [ 'haloupdate0' , 'interpolate_s0' ]
845844
846845 op .apply ()
847846 assert np .all (s .data == 1 )
@@ -870,10 +869,10 @@ class CoordSlowSparseFunction(SparseTimeFunction):
870869
871870 op = Operator ([Eq (u , 1 )] + rec_eq )
872871
873- # Expected Calls : one halo exchange + one interpolate efunc.
874- call_names = sorted ( c . name for c in FindNodes ( Call ). visit ( op ))
875- assert any ( n . startswith ( 'haloupdate' ) for n in call_names )
876- assert any ( n . startswith ( 'interpolate_' ) for n in call_names )
872+ # Generated code : one halo exchange for ``u`` and one Call to
873+ # the ``interpolate_s0`` ElementalFunction.
874+ call_names = [ c . name for c in FindNodes ( Call ). visit ( op )]
875+ assert call_names == [ 'haloupdate0' , 'interpolate_s0' ]
877876
878877 op .apply (time_M = 5 )
879878 assert np .all (s .data == 1 )
0 commit comments