Skip to content

Commit d2cbc85

Browse files
committed
compiler: Fix fission via logical dependence
1 parent fef0335 commit d2cbc85

6 files changed

Lines changed: 44 additions & 38 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,17 +228,11 @@ def _break_for_parallelism(self, scope, dim, timestamp):
228228
# Would break a dependence on storage
229229
return False
230230

231-
if any(dep.is_carried(i) for i in candidates):
231+
if any(dep.as_logical.is_carried(i) for i in candidates):
232+
# If, from a semantic viewpoint, `i` is a purely sequential
233+
# Dimension, give up
232234
test0 = dep.is_flow and dep.is_lex_negative
233235
test1 = dep.is_anti and dep.is_lex_positive
234-
if test0:
235-
# If the same access pair is not a flow under logical distance,
236-
# the dep is a buffer/modulo-aliasing artifact and fission is OK
237-
ldist = dep.source.distance(dep.sink, logical=True)
238-
real_flow = (ldist > 0) or \
239-
(ldist == 0 and dep.sink.lex_ge(dep.source))
240-
if not real_flow:
241-
test0 = real_flow
242236
if test0 or test1:
243237
return False
244238

devito/ir/equations/equation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class IREq(sympy.Eq, Pickable):
3030
__rargs__ = ('lhs', 'rhs')
3131
__rkwargs__ = ('ispace', 'conditionals', 'implicit_dims', 'operation')
3232

33+
def _hashable_content(self):
34+
return (*super()._hashable_content(),
35+
*tuple(getattr(self, i) for i in self.__rkwargs__))
36+
3337
@property
3438
def is_Scalar(self):
3539
return self.lhs.is_Symbol
@@ -302,7 +306,7 @@ def __new__(cls, *args, **kwargs):
302306
setattr(expr, f'_{i}', v)
303307
else:
304308
expr._ispace = kwargs['ispace']
305-
expr._conditionals = kwargs.get('conditionals', frozendict())
309+
expr._conditionals = kwargs.get('conditionals', {})
306310
expr._implicit_dims = input_expr.implicit_dims
307311
expr._operation = Operation.detect(input_expr)
308312
elif len(args) == 2:
@@ -313,6 +317,10 @@ def __new__(cls, *args, **kwargs):
313317
else:
314318
raise ValueError(f"Cannot construct ClusterizedEq from args={str(args)} "
315319
f"and kwargs={str(kwargs)}")
320+
321+
# Immutability (and thus hashability, etc)
322+
expr._conditionals = frozendict(expr._conditionals)
323+
316324
return expr
317325

318326
func = IREq._rebuild

devito/ir/support/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,10 @@ def is_iaw(self):
698698
def is_reduction(self):
699699
return self.source.is_reduction or self.sink.is_reduction
700700

701+
@cached_property
702+
def as_logical(self):
703+
return LogicalDependence(self.source, self.sink)
704+
701705
@memoized_meth
702706
def is_const(self, dim):
703707
"""

devito/ir/support/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def detect_io(exprs, relax=False):
278278
if rule(f):
279279
writes.append(f)
280280

281-
return filter_sorted(reads), filter_sorted(writes)
281+
return tuple(filter_sorted(reads)), tuple(filter_sorted(writes))
282282

283283

284284
def pull_dims(exprs, flag=True):

tests/test_dse.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
get_params, skipif
1010
)
1111
from devito import ( # noqa
12-
NODE, Abs, Buffer, ConditionalDimension, Constant, DefaultDimension, Derivative,
13-
Dimension, Eq, Function, Ge, Grid, Inc, Lt, Operator, SparseTimeFunction,
14-
SubDimension, TimeFunction, configuration, cos, dimensions, div, exp,
15-
first_derivative, floor, grad, norm, sin, solve, sqrt, switchconfig, transpose
12+
NODE, Abs, ConditionalDimension, Constant, DefaultDimension, Derivative, Dimension,
13+
Eq, Function, Ge, Grid, Inc, Lt, Operator, SparseTimeFunction, SubDimension,
14+
TimeFunction, configuration, cos, dimensions, div, exp, first_derivative, floor, grad,
15+
norm, sin, solve, sqrt, switchconfig, transpose
1616
)
1717
from devito.exceptions import InvalidArgument, InvalidOperator
1818
from devito.ir import (
@@ -58,26 +58,6 @@ def test_scheduling_after_rewrite():
5858
assert all(trees[1].root.dim is tree.root.dim for tree in trees[1:])
5959

6060

61-
def test_scheduling_no_deriv():
62-
grid = Grid((11, 11, 11))
63-
x, y, z = grid.dimensions
64-
65-
image_vs = Function(name='image_vs', grid=grid, space_order=1, staggered=NODE)
66-
p_back_xy = TimeFunction(name='p_back_xy', grid=grid, staggered=(x, y),
67-
space_order=4, time_order=1, save=Buffer(1))
68-
69-
eqns = [Eq(image_vs, p_back_xy + image_vs),
70-
Eq(p_back_xy.backward, p_back_xy)]
71-
72-
op = Operator(eqns)
73-
74-
assert_structure(
75-
op,
76-
['t,x0_blk0,y0_blk0,x,y,z', 't,x1_blk0,y1_blk0,x,y,z'],
77-
'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz'
78-
)
79-
80-
8161
@pytest.mark.parametrize('expr,expected', [
8262
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
8363
('fa[x]**2', 'fa[x]*fa[x]'),

tests/test_fission.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from conftest import assert_structure
44
from devito import (
5-
Buffer, Eq, Function, Grid, Inc, Operator, SubDimension, SubDomain, TimeFunction,
6-
solve
5+
NODE, Buffer, Eq, Function, Grid, Inc, Operator, SubDimension, SubDomain,
6+
TimeFunction, solve
77
)
88
from devito.ir.iet import retrieve_iteration_tree
99
from devito.ir.support.properties import PARALLEL
@@ -131,7 +131,7 @@ def test_issue_1921():
131131
assert np.all(g.data == g1.data)
132132

133133

134-
def test_buffer1_fissioning():
134+
def test_buffer1_v0():
135135
"""
136136
Tests an edge case whereby inability to spot the equivalence of
137137
`f.forward`/`backward` and `f` when using `Buffer(1)` would cause
@@ -196,3 +196,23 @@ def define(self, dimensions):
196196
# Two loop nests: free-surface-like and update-like
197197
assert_structure(op, ['t,x,y,z', 't,x0_blk0,y0_blk0,x,y,z'],
198198
't,x,y,z,x0_blk0,y0_blk0,x,y,z')
199+
200+
201+
def test_buffer1_v1():
202+
grid = Grid((11, 11, 11))
203+
x, y, z = grid.dimensions
204+
205+
image_vs = Function(name='image_vs', grid=grid, space_order=1, staggered=NODE)
206+
p_back_xy = TimeFunction(name='p_back_xy', grid=grid, staggered=(x, y),
207+
space_order=4, time_order=1, save=Buffer(1))
208+
209+
eqns = [Eq(image_vs, p_back_xy + image_vs),
210+
Eq(p_back_xy.backward, p_back_xy)]
211+
212+
op = Operator(eqns)
213+
214+
assert_structure(
215+
op,
216+
['t,x0_blk0,y0_blk0,x,y,z', 't,x1_blk0,y1_blk0,x,y,z'],
217+
'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz'
218+
)

0 commit comments

Comments
 (0)