Skip to content

Commit 3c0ddea

Browse files
committed
tests: Tighten sparse-op Call assertions to exact names
1 parent ea8ad37 commit 3c0ddea

3 files changed

Lines changed: 15 additions & 18 deletions

File tree

tests/test_dle.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,8 @@ def test_cache_blocking_structure_distributed(mode):
180180
# The sparse inject lives in its own efunc, which sits between the
181181
# two dense Eq's. Because they can no longer fuse across the
182182
# sparse-op Call, each dense Eq lands in its own MPI compute efunc.
183-
compute_names = sorted(n for n in op._func_table if n.startswith('compute'))
184-
bns0, _ = assert_blocking(op._func_table[compute_names[0]].root, {'x0_blk0'})
185-
bns1, _ = assert_blocking(op._func_table[compute_names[1]].root, {'x1_blk0'})
183+
bns0, _ = assert_blocking(op._func_table['compute0'].root, {'x0_blk0'})
184+
bns1, _ = assert_blocking(op._func_table['compute1'].root, {'x1_blk0'})
186185

187186
for blk_dim, bns in [('x0_blk0', bns0), ('x1_blk0', bns1)]:
188187
iters = FindNodes(Iteration).visit(bns[blk_dim])

tests/test_interpolation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -995,8 +995,7 @@ def test_interp_complex_and_real(self, dtype):
995995
# Both interpolations land in the same sparse-op efunc since they
996996
# share the `p_sc` Dimension (sce reuses sc's coordinates); two
997997
# radius nests sit side-by-side inside the single ``p_sc`` loop.
998-
[efunc_name] = [n for n in opC._func_table if n.startswith('interpolate_')]
999-
efunc = opC._func_table[efunc_name].root
998+
efunc = opC._func_table['interpolate_sc0'].root
1000999
assert_structure(
10011000
efunc,
10021001
['p_sc', 'p_sc,rp_scx,rp_scy,rp_scz', 'p_sc,rp_scx,rp_scy,rp_scz'],

tests/test_mpi.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -804,11 +804,10 @@ class SparseFirst(SparseFunction):
804804
rec = s.interpolate(expr=s+fs, implicit_dims=grid.stepping_dim)
805805
op = Operator(eqs + rec)
806806

807-
# The sparse interp now lowers to a Call into an
808-
# ElementalFunction, so the parent carries that Call in
809-
# addition to any halo exchanges.
810-
calls = FindNodes(Call).visit(op)
811-
assert any(c.name.startswith('interpolate_') for c in calls)
807+
# Generated code: one halo exchange for ``fs`` and one Call
808+
# to the ``interpolate_s0`` ElementalFunction.
809+
call_names = [c.name for c in FindNodes(Call).visit(op)]
810+
assert call_names == ['haloupdate0', 'interpolate_s0']
812811

813812
op(time_M=10)
814813
expected = 10*11/2 # n (n+1)/2
@@ -838,10 +837,10 @@ class CoordSlowSparseFunction(SparseFunction):
838837

839838
op = Operator([Eq(u, 1)] + rec_eq)
840839

841-
# Expected Calls: one halo exchange + one interpolate efunc.
842-
call_names = sorted(c.name for c in FindNodes(Call).visit(op))
843-
assert any(n.startswith('haloupdate') for n in call_names)
844-
assert any(n.startswith('interpolate_') for n in call_names)
840+
# Generated code: one halo exchange for ``u`` and one Call to
841+
# the ``interpolate_s0`` ElementalFunction.
842+
call_names = [c.name for c in FindNodes(Call).visit(op)]
843+
assert call_names == ['haloupdate0', 'interpolate_s0']
845844

846845
op.apply()
847846
assert np.all(s.data == 1)
@@ -870,10 +869,10 @@ class CoordSlowSparseFunction(SparseTimeFunction):
870869

871870
op = Operator([Eq(u, 1)] + rec_eq)
872871

873-
# Expected Calls: one halo exchange + one interpolate efunc.
874-
call_names = sorted(c.name for c in FindNodes(Call).visit(op))
875-
assert any(n.startswith('haloupdate') for n in call_names)
876-
assert any(n.startswith('interpolate_') for n in call_names)
872+
# Generated code: one halo exchange for ``u`` and one Call to
873+
# the ``interpolate_s0`` ElementalFunction.
874+
call_names = [c.name for c in FindNodes(Call).visit(op)]
875+
assert call_names == ['haloupdate0', 'interpolate_s0']
877876

878877
op.apply(time_M=5)
879878
assert np.all(s.data == 1)

0 commit comments

Comments
 (0)