11from collections import defaultdict , namedtuple
22from functools import cached_property
3- from itertools import chain
3+ from itertools import chain , groupby
44
55import numpy as np
6- from sympy import S , simplify
6+ from sympy import Mod , S , simplify
77
88from devito .exceptions import CompilationError
99from devito .ir import (
@@ -116,7 +116,7 @@ def key(f):
116116 # Then we inject them into the Clusters. This involves creating the
117117 # initializing Clusters, and replacing the buffered Functions with the buffers
118118 clusters = InjectBuffers (mapper , sregistry , options ).process (clusters )
119-
119+ print ( clusters )
120120 return clusters
121121
122122
@@ -142,14 +142,18 @@ def callback(self, clusters, prefix):
142142 return clusters
143143 d = prefix [- 1 ].dim
144144
145- key = lambda f , * args : f in self .mapper
def key(f, *args):
    """
    Return True if `f` equals the first component of any key in
    `self.mapper`, False otherwise. Extra positional arguments are
    accepted but ignored.
    """
    # The keys of `self.mapper` are unpacked as 2-tuples; only the first
    # entry is compared against `f`
    return any(f == candidate for candidate, _ in self.mapper)
146150 bfmap = map_buffered_functions (clusters , key )
147151
148152 # A BufferDescriptor is a simple data structure storing additional
149153 # information about a buffer, harvested from the subset of `clusters`
150154 # that access it
151- descriptors = {b : BufferDescriptor (f , b , bfmap [f ])
152- for f , b in self .mapper .items ()
155+ descriptors = {b : BufferDescriptor (f , b , bfmap [f ], g )
156+ for ( f , g ) , b in self .mapper .items ()
153157 if f in bfmap }
154158
155159 # Are we inside the right `d`?
@@ -184,6 +188,8 @@ def callback(self, clusters, prefix):
184188 continue
185189 if c not in v .firstread :
186190 continue
191+ if not c .guards .get (d ) == v .guards .get (d ):
192+ continue
187193
188194 idxf = v .last_idx [c ]
189195 idxb = mds [(v .xd , idxf )]
@@ -203,7 +209,7 @@ def callback(self, clusters, prefix):
203209 guards = c .guards
204210
205211 properties = c .properties .sequentialize (d )
206- if not isinstance (d , BufferDimension ):
212+ if not isinstance (d , BufferDimension ) and c . guards [ d ]. has ( Mod ) :
207213 properties = properties .prefetchable (d )
208214 # `c` may be a HaloTouch Cluster, so with no vision of the `bdims`
209215 properties = properties .parallelize (v .bdims ).affine (v .bdims )
@@ -227,6 +233,8 @@ def callback(self, clusters, prefix):
227233 continue
228234 if c not in v .lastwrite :
229235 continue
236+ if not c .guards .get (d ) == v .guards .get (d ):
237+ continue
230238
231239 idxf = v .last_idx [c ]
232240 idxb = mds [(v .xd , idxf )]
@@ -358,59 +366,60 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
358366 xds = {}
359367 mapper = {}
360368 for f , clusters in bfmap .items ():
361- exprs = flatten (c .exprs for c in clusters )
362-
363- bdims = key (f , exprs )
364-
365- dims = [d for d in f .dimensions if d not in bdims ]
366- if len (dims ) != 1 :
367- raise CompilationError (f"Unsupported multi-dimensional `buffering` "
368- f"required by `{ f } `" )
369- dim = dims .pop ()
370-
371- if is_buffering (exprs ):
372- # Multi-level buffering
373- # NOTE: a bit rudimentary (we could go through the exprs one by one
374- # instead), but it's much shorter this way
375- buffers = [f for f in retrieve_functions (exprs ) if f .is_Array ]
376- assert len (buffers ) == 1 , "Unexpected form of multi-level buffering"
377- buffer , = buffers
378- xd = buffer .indices [dim ]
379- else :
380- size = infer_buffer_size (f , dim , clusters )
381-
382- if async_degree is not None :
383- if async_degree < size :
384- warning (
385- 'Ignoring provided asynchronous degree as it would be '
386- f'too small for the required buffer (provided { async_degree } , '
387- f'but need at least { size } for `{ f .name } `)'
388- )
389- else :
390- size = async_degree
391-
392- # A special CustomDimension to use in place of `dim` in the buffer
393- try :
394- xd = xds [(dim , size )]
395- except KeyError :
396- name = sregistry .make_name (prefix = 'db' )
397- xd = xds [(dim , size )] = BufferDimension (name , 0 , size - 1 , size , dim )
398-
399- # The buffer dimensions
400- dimensions = list (f .dimensions )
401- assert dim in f .dimensions
402- dimensions [dimensions .index (dim )] = xd
403-
404- # Finally create the actual buffer
405- cls = callback or Array
406- name = sregistry .make_name (prefix = f'{ f .name } b' )
407- # We specify the padding to match the input Function's one, so that
408- # the array can be used in place of the Function with valid strides
409- # Plain Array do not track mapped so we default to no padding
410- padding = 0 if cls is Array else f .padding
411- mapper [f ] = cls (name = name , dimensions = dimensions , dtype = f .dtype ,
412- padding = padding , grid = f .grid , halo = f .halo ,
413- space = 'mapped' , mapped = f , f = f )
369+ for k , ck in groupby (clusters , key = lambda c : c .guards ):
370+ exprs = flatten (c .exprs for c in ck )
371+
372+ bdims = key (f , exprs )
373+
374+ dims = [d for d in f .dimensions if d not in bdims ]
375+ if len (dims ) != 1 :
376+ raise CompilationError (f"Unsupported multi-dimensional `buffering` "
377+ f"required by `{ f } `" )
378+ dim = dims .pop ()
379+
380+ if is_buffering (exprs ):
381+ # Multi-level buffering
382+ # NOTE: a bit rudimentary (we could go through the exprs one by one
383+ # instead), but it's much shorter this way
384+ buffers = [f for f in retrieve_functions (exprs ) if f .is_Array ]
385+ assert len (buffers ) == 1 , "Unexpected form of multi-level buffering"
386+ buffer , = buffers
387+ xd = buffer .indices [dim ]
388+ else :
389+ size = infer_buffer_size (f , dim , clusters )
390+
391+ if async_degree is not None :
392+ if async_degree < size :
393+ warning (
394+ 'Ignoring provided asynchronous degree as it would be '
395+ f'too small for the required buffer (provided { async_degree } , '
396+ f'but need at least { size } for `{ f .name } `)'
397+ )
398+ else :
399+ size = async_degree
400+
401+ # A special CustomDimension to use in place of `dim` in the buffer
402+ try :
403+ xd = xds [(dim , size , k )]
404+ except KeyError :
405+ name = sregistry .make_name (prefix = 'db' )
406+ xd = xds [(dim , size , k )] = BufferDimension (name , 0 , size - 1 , size , dim )
407+
408+ # The buffer dimensions
409+ dimensions = list (f .dimensions )
410+ assert dim in f .dimensions
411+ dimensions [dimensions .index (dim )] = xd
412+
413+ # Finally create the actual buffer
414+ cls = callback or Array
415+ name = sregistry .make_name (prefix = f'{ f .name } b' )
416+ # We specify the padding to match the input Function's one, so that
417+ # the array can be used in place of the Function with valid strides
418+ # Plain Array do not track mapped so we default to no padding
419+ padding = 0 if cls is Array else f .padding
420+ mapper [(f , k )] = cls (name = name , dimensions = dimensions , dtype = f .dtype ,
421+ padding = padding , grid = f .grid , halo = f .halo ,
422+ space = 'mapped' , mapped = f , f = f )
414423
415424 return mapper
416425
@@ -430,10 +439,11 @@ def map_buffered_functions(clusters, key):
430439
431440class BufferDescriptor :
432441
433- def __init__ (self , f , b , clusters ):
442+ def __init__ (self , f , b , clusters , guards ):
434443 self .f = f
435444 self .b = b
436445 self .clusters = clusters
446+ self .guards = guards
437447
438448 self .xd , = b .find (BufferDimension )
439449 self .bdims = tuple (d for d in b .dimensions if d is not self .xd )
@@ -674,8 +684,9 @@ def make_mds(descriptors, prefix, sregistry):
674684 # same strategy is also applied in clusters/algorithms/Stepper
675685 key = lambda i : - np .inf if i - p == 0 else (i - p ) # noqa: B023
676686 indices = sorted (v .indices , key = key )
687+ v_mds = None
677688
678- for i in indices :
689+ for k , i in enumerate ( indices ) :
679690 k = (v .xd , i )
680691 if k in mds :
681692 continue
0 commit comments