Skip to content

Commit b997156

Browse files
committed
compiler: fix various corner case of multi buffering
1 parent 91266de commit b997156

5 files changed

Lines changed: 65 additions & 21 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,14 @@ def __new__(cls, *args, **kwargs):
237237
cond = d.relation(cond, GuardFactor(d))
238238
conditionals[d] = cond
239239

240+
# Replace the ConditionalDimensions in `expr`
241+
for d, cond in conditionals.items():
242+
# Replace dimension with index
243+
index = d.index
244+
index = index - relational_min(cond, d.parent)
245+
shift = relational_shift(cond, d.parent)
246+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
247+
240248
# Merge conditionals when possible. E.g if we have an implicit_dim
241249
# and there is a dimension with the same parent, we ca merged
242250
# its condition
@@ -247,19 +255,13 @@ def __new__(cls, *args, **kwargs):
247255
if cd.parent == d.parent and cd != d:
248256
cond = conditionals.pop(d)
249257
mode = cd.relation and d.relation
250-
conditionals[cd] = mode(cond, conditionals[cd])
258+
if issubclass(mode, sympy.Or):
259+
conditionals[d] = cond
260+
conditionals.pop(cd)
261+
else:
262+
conditionals[cd] = mode(cond, conditionals[cd])
251263
break
252264

253-
conditionals = frozendict(conditionals)
254-
255-
# Replace the ConditionalDimensions in `expr`
256-
for d, cond in conditionals.items():
257-
# Replace dimension with index
258-
index = d.index
259-
index = index - relational_min(cond, d.parent)
260-
shift = relational_shift(cond, d.parent)
261-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
262-
263265
# Lower all Differentiable operations into SymPy operations
264266
rhs = diff2sympy(expr.rhs)
265267

devito/passes/clusters/buffering.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from itertools import chain
44

55
import numpy as np
6-
from sympy import S, simplify
6+
from sympy import Mod, S, simplify
77

88
from devito.exceptions import CompilationError
99
from devito.ir import (
@@ -203,7 +203,7 @@ def callback(self, clusters, prefix):
203203
guards = c.guards
204204

205205
properties = c.properties.sequentialize(d)
206-
if not isinstance(d, BufferDimension):
206+
if not isinstance(d, BufferDimension) and c.guards[d].has(Mod):
207207
properties = properties.prefetchable(d)
208208
# `c` may be a HaloTouch Cluster, so with no vision of the `bdims`
209209
properties = properties.parallelize(v.bdims).affine(v.bdims)
@@ -377,7 +377,12 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
377377
buffer, = buffers
378378
xd = buffer.indices[dim]
379379
else:
380-
size = infer_buffer_size(f, dim, clusters)
380+
if len({c.guards[dim.root] for c in clusters}) > 1:
381+
# Multiple clusters with different guards,
382+
# will lead to conflicts in asynchrony with multiple (modulo) slots
383+
size = 1
384+
else:
385+
size = infer_buffer_size(f, dim, clusters)
381386

382387
if async_degree is not None:
383388
if async_degree < size:

devito/symbolics/extended_sympy.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919
from devito.types import Symbol
2020
from devito.types.basic import Basic
21-
from devito.types.relational import Ge
2221

2322
__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa
2423
'LeftShift', 'RightShift', 'IntDiv', 'Terminal', 'CallFromPointer',
@@ -48,11 +47,6 @@ def canonical(self):
4847
def negated(self):
4948
return CondNe(*self.args, evaluate=False)
5049

51-
@property
52-
def _as_min(self):
53-
from devito.symbolics.extended_dtypes import INT
54-
return INT(Ge(*self.args))
55-
5650

5751
class CondNe(sympy.Ne):
5852

devito/types/relational.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,4 +320,5 @@ def _(expr, s):
320320
def _(expr, s):
321321
if isinstance(expr.lhs, sympy.Mod):
322322
return 0
323-
return expr._as_min
323+
from devito.symbolics.extended_dtypes import INT
324+
return INT(Ge(*expr.args))

tests/test_buffering.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,3 +818,45 @@ def test_multi_cond_v1():
818818
for i in range(nt-1):
819819
assert np.allclose(f.data[i], i*2)
820820
assert np.allclose(f.data[nt-1], ntmod - 2)
821+
822+
823+
@pytest.mark.parametrize("factor", [1, 2, 3])
824+
def test_buffering_multi_cond(factor):
825+
grid = Grid((16, 16))
826+
827+
nt = 5
828+
ntmod = (nt - 1) * factor + 1
829+
830+
ct0 = ConditionalDimension(name="ct0", parent=grid.time_dim, factor=factor,
831+
relation=Or)
832+
f = TimeFunction(grid=grid, name='f', time_order=0, space_order=0,
833+
time_dim=ct0, save=nt)
834+
T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0)
835+
836+
eqs = []
837+
eqs.append(Eq(T, grid.time_dim))
838+
839+
# conditional dimension for the last sample in the operator
840+
ctend = ConditionalDimension(name="ctend", parent=grid.time_dim,
841+
condition=CondEq(grid.time_dim, ntmod - 2),
842+
relation=Or)
843+
844+
eqs.append(Eq(f, T)) # this to save times from 0 to nt - 2
845+
# this to save the last time sample nt - 1
846+
eqs.append(Eq(f.forward, T+1, implicit_dims=ctend))
847+
848+
# run operator with serialization
849+
op = Operator(eqs, opt='buffering')
850+
op.apply(time_m=0, time_M=ntmod-2)
851+
852+
# Now run backward as well with buffering
853+
854+
f_all = TimeFunction(grid=grid, name='f_all', time_order=0,
855+
space_order=0, time_dim=ct0, save=nt)
856+
857+
eq_all = [Eq(f_all, f)]
858+
eq_all.append(Eq(f_all.forward, f.forward, implicit_dims=ctend))
859+
op_all = Operator(eq_all, opt='buffering')
860+
op_all.apply(time_m=0, time_M=ntmod-2)
861+
862+
assert np.allclose(f_all.data[:, 11, 11], factor * np.arange(nt))

0 commit comments

Comments
 (0)