Skip to content

Commit c8ed5e1

Browse files
committed
compiler: fix various corner case of multi buffering
1 parent 91266de commit c8ed5e1

6 files changed

Lines changed: 132 additions & 82 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,14 @@ def __new__(cls, *args, **kwargs):
237237
cond = d.relation(cond, GuardFactor(d))
238238
conditionals[d] = cond
239239

240+
# Replace the ConditionalDimensions in `expr`
241+
for d, cond in conditionals.items():
242+
# Replace dimension with index
243+
index = d.index
244+
index = index - relational_min(cond, d.parent)
245+
shift = relational_shift(cond, d.parent)
246+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
247+
240248
# Merge conditionals when possible. E.g if we have an implicit_dim
241249
# and there is a dimension with the same parent, we ca merged
242250
# its condition
@@ -247,19 +255,13 @@ def __new__(cls, *args, **kwargs):
247255
if cd.parent == d.parent and cd != d:
248256
cond = conditionals.pop(d)
249257
mode = cd.relation and d.relation
250-
conditionals[cd] = mode(cond, conditionals[cd])
258+
if issubclass(mode, sympy.Or):
259+
conditionals[d] = cond
260+
conditionals.pop(cd)
261+
else:
262+
conditionals[cd] = mode(cond, conditionals[cd])
251263
break
252264

253-
conditionals = frozendict(conditionals)
254-
255-
# Replace the ConditionalDimensions in `expr`
256-
for d, cond in conditionals.items():
257-
# Replace dimension with index
258-
index = d.index
259-
index = index - relational_min(cond, d.parent)
260-
shift = relational_shift(cond, d.parent)
261-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
262-
263265
# Lower all Differentiable operations into SymPy operations
264266
rhs = diff2sympy(expr.rhs)
265267

devito/passes/clusters/asynchrony.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
193193
if c.properties.is_prefetchable(d._defines):
194194
_actions_from_update_memcpy(c, d, clusters, actions, sregistry)
195195
elif d.is_Custom and is_integer(c.ispace[d].size):
196-
_actions_from_init(c, d, actions)
196+
_actions_from_init(c, d, clusters, actions)
197197

198198
# Attach the computed Actions
199199
processed = []
@@ -214,7 +214,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
214214
return processed
215215

216216

217-
def _actions_from_init(c, d, actions):
217+
def _actions_from_init(c, d, clusters, actions):
218218
e = c.exprs[0]
219219
function = e.rhs.function
220220
target = e.lhs.function

devito/passes/clusters/buffering.py

Lines changed: 73 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from collections import defaultdict, namedtuple
22
from functools import cached_property
3-
from itertools import chain
3+
from itertools import chain, groupby
44

55
import numpy as np
6-
from sympy import S, simplify
6+
from sympy import Mod, S, simplify
77

88
from devito.exceptions import CompilationError
99
from devito.ir import (
@@ -116,7 +116,7 @@ def key(f):
116116
# Then we inject them into the Clusters. This involves creating the
117117
# initializing Clusters, and replacing the buffered Functions with the buffers
118118
clusters = InjectBuffers(mapper, sregistry, options).process(clusters)
119-
119+
print(clusters)
120120
return clusters
121121

122122

@@ -142,14 +142,18 @@ def callback(self, clusters, prefix):
142142
return clusters
143143
d = prefix[-1].dim
144144

145-
key = lambda f, *args: f in self.mapper
145+
def key(f, *args):
146+
for (ff, _) in self.mapper:
147+
if f == ff:
148+
return True
149+
return False
146150
bfmap = map_buffered_functions(clusters, key)
147151

148152
# A BufferDescriptor is a simple data structure storing additional
149153
# information about a buffer, harvested from the subset of `clusters`
150154
# that access it
151-
descriptors = {b: BufferDescriptor(f, b, bfmap[f])
152-
for f, b in self.mapper.items()
155+
descriptors = {b: BufferDescriptor(f, b, bfmap[f], g)
156+
for (f, g), b in self.mapper.items()
153157
if f in bfmap}
154158

155159
# Are we inside the right `d`?
@@ -184,6 +188,8 @@ def callback(self, clusters, prefix):
184188
continue
185189
if c not in v.firstread:
186190
continue
191+
if not c.guards.get(d) == v.guards.get(d):
192+
continue
187193

188194
idxf = v.last_idx[c]
189195
idxb = mds[(v.xd, idxf)]
@@ -203,7 +209,7 @@ def callback(self, clusters, prefix):
203209
guards = c.guards
204210

205211
properties = c.properties.sequentialize(d)
206-
if not isinstance(d, BufferDimension):
212+
if not isinstance(d, BufferDimension) and c.guards[d].has(Mod):
207213
properties = properties.prefetchable(d)
208214
# `c` may be a HaloTouch Cluster, so with no vision of the `bdims`
209215
properties = properties.parallelize(v.bdims).affine(v.bdims)
@@ -227,6 +233,8 @@ def callback(self, clusters, prefix):
227233
continue
228234
if c not in v.lastwrite:
229235
continue
236+
if not c.guards.get(d) == v.guards.get(d):
237+
continue
230238

231239
idxf = v.last_idx[c]
232240
idxb = mds[(v.xd, idxf)]
@@ -358,59 +366,60 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
358366
xds = {}
359367
mapper = {}
360368
for f, clusters in bfmap.items():
361-
exprs = flatten(c.exprs for c in clusters)
362-
363-
bdims = key(f, exprs)
364-
365-
dims = [d for d in f.dimensions if d not in bdims]
366-
if len(dims) != 1:
367-
raise CompilationError(f"Unsupported multi-dimensional `buffering` "
368-
f"required by `{f}`")
369-
dim = dims.pop()
370-
371-
if is_buffering(exprs):
372-
# Multi-level buffering
373-
# NOTE: a bit rudimentary (we could go through the exprs one by one
374-
# instead), but it's much shorter this way
375-
buffers = [f for f in retrieve_functions(exprs) if f.is_Array]
376-
assert len(buffers) == 1, "Unexpected form of multi-level buffering"
377-
buffer, = buffers
378-
xd = buffer.indices[dim]
379-
else:
380-
size = infer_buffer_size(f, dim, clusters)
381-
382-
if async_degree is not None:
383-
if async_degree < size:
384-
warning(
385-
'Ignoring provided asynchronous degree as it would be '
386-
f'too small for the required buffer (provided {async_degree}, '
387-
f'but need at least {size} for `{f.name}`)'
388-
)
389-
else:
390-
size = async_degree
391-
392-
# A special CustomDimension to use in place of `dim` in the buffer
393-
try:
394-
xd = xds[(dim, size)]
395-
except KeyError:
396-
name = sregistry.make_name(prefix='db')
397-
xd = xds[(dim, size)] = BufferDimension(name, 0, size-1, size, dim)
398-
399-
# The buffer dimensions
400-
dimensions = list(f.dimensions)
401-
assert dim in f.dimensions
402-
dimensions[dimensions.index(dim)] = xd
403-
404-
# Finally create the actual buffer
405-
cls = callback or Array
406-
name = sregistry.make_name(prefix=f'{f.name}b')
407-
# We specify the padding to match the input Function's one, so that
408-
# the array can be used in place of the Function with valid strides
409-
# Plain Array do not track mapped so we default to no padding
410-
padding = 0 if cls is Array else f.padding
411-
mapper[f] = cls(name=name, dimensions=dimensions, dtype=f.dtype,
412-
padding=padding, grid=f.grid, halo=f.halo,
413-
space='mapped', mapped=f, f=f)
369+
for k, ck in groupby(clusters, key=lambda c: c.guards):
370+
exprs = flatten(c.exprs for c in ck)
371+
372+
bdims = key(f, exprs)
373+
374+
dims = [d for d in f.dimensions if d not in bdims]
375+
if len(dims) != 1:
376+
raise CompilationError(f"Unsupported multi-dimensional `buffering` "
377+
f"required by `{f}`")
378+
dim = dims.pop()
379+
380+
if is_buffering(exprs):
381+
# Multi-level buffering
382+
# NOTE: a bit rudimentary (we could go through the exprs one by one
383+
# instead), but it's much shorter this way
384+
buffers = [f for f in retrieve_functions(exprs) if f.is_Array]
385+
assert len(buffers) == 1, "Unexpected form of multi-level buffering"
386+
buffer, = buffers
387+
xd = buffer.indices[dim]
388+
else:
389+
size = infer_buffer_size(f, dim, clusters)
390+
391+
if async_degree is not None:
392+
if async_degree < size:
393+
warning(
394+
'Ignoring provided asynchronous degree as it would be '
395+
f'too small for the required buffer (provided {async_degree}, '
396+
f'but need at least {size} for `{f.name}`)'
397+
)
398+
else:
399+
size = async_degree
400+
401+
# A special CustomDimension to use in place of `dim` in the buffer
402+
try:
403+
xd = xds[(dim, size, k)]
404+
except KeyError:
405+
name = sregistry.make_name(prefix='db')
406+
xd = xds[(dim, size, k)] = BufferDimension(name, 0, size-1, size, dim)
407+
408+
# The buffer dimensions
409+
dimensions = list(f.dimensions)
410+
assert dim in f.dimensions
411+
dimensions[dimensions.index(dim)] = xd
412+
413+
# Finally create the actual buffer
414+
cls = callback or Array
415+
name = sregistry.make_name(prefix=f'{f.name}b')
416+
# We specify the padding to match the input Function's one, so that
417+
# the array can be used in place of the Function with valid strides
418+
# Plain Array do not track mapped so we default to no padding
419+
padding = 0 if cls is Array else f.padding
420+
mapper[(f, k)] = cls(name=name, dimensions=dimensions, dtype=f.dtype,
421+
padding=padding, grid=f.grid, halo=f.halo,
422+
space='mapped', mapped=f, f=f)
414423

415424
return mapper
416425

@@ -430,10 +439,11 @@ def map_buffered_functions(clusters, key):
430439

431440
class BufferDescriptor:
432441

433-
def __init__(self, f, b, clusters):
442+
def __init__(self, f, b, clusters, guards):
434443
self.f = f
435444
self.b = b
436445
self.clusters = clusters
446+
self.guards = guards
437447

438448
self.xd, = b.find(BufferDimension)
439449
self.bdims = tuple(d for d in b.dimensions if d is not self.xd)
@@ -674,8 +684,9 @@ def make_mds(descriptors, prefix, sregistry):
674684
# same strategy is also applied in clusters/algorithms/Stepper
675685
key = lambda i: -np.inf if i - p == 0 else (i - p) # noqa: B023
676686
indices = sorted(v.indices, key=key)
687+
v_mds = None
677688

678-
for i in indices:
689+
for k, i in enumerate(indices):
679690
k = (v.xd, i)
680691
if k in mds:
681692
continue

devito/symbolics/extended_sympy.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919
from devito.types import Symbol
2020
from devito.types.basic import Basic
21-
from devito.types.relational import Ge
2221

2322
__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa
2423
'LeftShift', 'RightShift', 'IntDiv', 'Terminal', 'CallFromPointer',
@@ -48,11 +47,6 @@ def canonical(self):
4847
def negated(self):
4948
return CondNe(*self.args, evaluate=False)
5049

51-
@property
52-
def _as_min(self):
53-
from devito.symbolics.extended_dtypes import INT
54-
return INT(Ge(*self.args))
55-
5650

5751
class CondNe(sympy.Ne):
5852

devito/types/relational.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,4 +320,5 @@ def _(expr, s):
320320
def _(expr, s):
321321
if isinstance(expr.lhs, sympy.Mod):
322322
return 0
323-
return expr._as_min
323+
from devito.symbolics.extended_dtypes import INT
324+
return INT(Ge(*expr.args))

tests/test_buffering.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,3 +818,45 @@ def test_multi_cond_v1():
818818
for i in range(nt-1):
819819
assert np.allclose(f.data[i], i*2)
820820
assert np.allclose(f.data[nt-1], ntmod - 2)
821+
822+
823+
@pytest.mark.parametrize("factor", [1, 2, 3])
824+
def test_buffering_multi_cond(factor):
825+
grid = Grid((16, 16))
826+
827+
nt = 5
828+
ntmod = (nt - 1) * factor + 1
829+
830+
ct0 = ConditionalDimension(name="ct0", parent=grid.time_dim, factor=factor,
831+
relation=Or)
832+
f = TimeFunction(grid=grid, name='f', time_order=0, space_order=0,
833+
time_dim=ct0, save=nt)
834+
T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0)
835+
836+
eqs = []
837+
eqs.append(Eq(T, grid.time_dim))
838+
839+
# conditional dimension for the last sample in the operator
840+
ctend = ConditionalDimension(name="ctend", parent=grid.time_dim,
841+
condition=CondEq(grid.time_dim, ntmod - 2),
842+
relation=Or)
843+
844+
eqs.append(Eq(f, T)) # this to save times from 0 to nt - 2
845+
# this to save the last time sample nt - 1
846+
eqs.append(Eq(f.forward, T+1, implicit_dims=ctend))
847+
848+
# run operator with serialization
849+
op = Operator(eqs, opt='buffering')
850+
op.apply(time_m=0, time_M=ntmod-2)
851+
852+
# Now run backward as well with buffering
853+
854+
f_all = TimeFunction(grid=grid, name='f_all', time_order=0,
855+
space_order=0, time_dim=ct0, save=nt)
856+
857+
eq_all = [Eq(f_all, f)]
858+
eq_all.append(Eq(f_all.forward, f.forward, implicit_dims=ctend))
859+
op_all = Operator(eq_all, opt='buffering')
860+
op_all.apply(time_m=0, time_M=ntmod-2)
861+
862+
assert np.allclose(f_all.data[:, 11, 11], factor * np.arange(nt))

0 commit comments

Comments
 (0)