Skip to content

Commit 516047a

Browse files
committed
compiler: support mutli-buffering
1 parent c8ed5e1 commit 516047a

6 files changed

Lines changed: 330 additions & 189 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,17 @@ def guard(clusters):
259259
# Separate out the indirect ConditionalDimensions, which only serve
260260
# the purpose of protecting from OOB accesses
261261
cds = [d for d in cds if not d.indirect]
262+
modes = [cd.relation for cd in cds]
263+
if len({m == 'strict' for m in modes}) > 1:
264+
raise CompilationError("Only one `strict` condition"
265+
"can be used in an equation")
266+
elif 'strict' in modes:
267+
mode = 'strict'
268+
else:
269+
mode = sympy.And if sympy.And in modes else sympy.Or
262270

263271
# Chain together all `cds` conditions from all expressions in `c`
264272
guards = {}
265-
mode = sympy.Or
266273
for cd in cds:
267274
# `BOTTOM` parent implies a guard that lives outside of
268275
# any iteration space, which corresponds to the placeholder None
@@ -279,7 +286,6 @@ def guard(clusters):
279286

280287
# Pull `cd` from any expr
281288
condition = guards.setdefault(k, [])
282-
mode = mode and cd.relation
283289
for e in exprs:
284290
try:
285291
condition.append(e.conditionals[cd])
@@ -296,7 +302,10 @@ def guard(clusters):
296302

297303
# Combination `mode` is And by default.
298304
# If all conditions are Or then Or combination `mode` is used.
299-
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
305+
if mode == 'strict':
306+
guards = {d: v[0] for d, v in guards.items()}
307+
else:
308+
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
300309

301310
# Construct a guarded Cluster
302311
processed.append(c.rebuild(exprs=exprs, guards=guards))

devito/ir/equations/equation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@ def __new__(cls, *args, **kwargs):
241241
for d, cond in conditionals.items():
242242
# Replace dimension with index
243243
index = d.index
244-
index = index - relational_min(cond, d.parent)
244+
if d.condition is not None and d in expr.free_symbols:
245+
index = index - relational_min(cond, d.parent)
245246
shift = relational_shift(cond, d.parent)
246247
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
247248

@@ -254,11 +255,11 @@ def __new__(cls, *args, **kwargs):
254255
for cd in dict(conditionals):
255256
if cd.parent == d.parent and cd != d:
256257
cond = conditionals.pop(d)
257-
mode = cd.relation and d.relation
258-
if issubclass(mode, sympy.Or):
258+
if d.relation == 'strict':
259259
conditionals[d] = cond
260260
conditionals.pop(cd)
261261
else:
262+
mode = cd.relation and d.relation
262263
conditionals[cd] = mode(cond, conditionals[cd])
263264
break
264265

devito/passes/clusters/asynchrony.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22

3-
from sympy import true
3+
from sympy import Mod, true
44

55
from devito.ir import (
66
Backward, Forward, GuardBoundNext, PrefetchUpdate, Queue, ReleaseLock, SyncArray,
@@ -78,7 +78,8 @@ def callback(self, clusters, prefix):
7878
d = self.key0(c0)
7979
if d is not dim:
8080
continue
81-
81+
if d in c0.guards and not c0.guards[d].has(Mod):
82+
continue
8283
protected = self._schedule_waitlocks(c0, d, clusters, locks, syncs)
8384
self._schedule_withlocks(c0, d, protected, locks, syncs)
8485

0 commit comments

Comments
 (0)