@@ -818,3 +818,45 @@ def test_multi_cond_v1():
818818 for i in range (nt - 1 ):
819819 assert np .allclose (f .data [i ], i * 2 )
820820 assert np .allclose (f .data [nt - 1 ], ntmod - 2 )
821+
822+
823+ @pytest .mark .parametrize ("factor" , [1 , 2 , 3 ])
824+ def test_buffering_multi_cond (factor ):
825+ grid = Grid ((16 , 16 ))
826+
827+ nt = 5
828+ ntmod = (nt - 1 ) * factor + 1
829+
830+ ct0 = ConditionalDimension (name = "ct0" , parent = grid .time_dim , factor = factor ,
831+ relation = Or )
832+ f = TimeFunction (grid = grid , name = 'f' , time_order = 0 , space_order = 0 ,
833+ time_dim = ct0 , save = nt )
834+ T = TimeFunction (grid = grid , name = 'T' , time_order = 0 , space_order = 0 )
835+
836+ eqs = []
837+ eqs .append (Eq (T , grid .time_dim ))
838+
839+ # conditional dimension for the last sample in the operator
840+ ctend = ConditionalDimension (name = "ctend" , parent = grid .time_dim ,
841+ condition = CondEq (grid .time_dim , ntmod - 2 ),
842+ relation = Or )
843+
844+ eqs .append (Eq (f , T )) # this to save times from 0 to nt - 2
845+ # this to save the last time sample nt - 1
846+ eqs .append (Eq (f .forward , T + 1 , implicit_dims = ctend ))
847+
848+ # run operator with serialization
849+ op = Operator (eqs , opt = 'buffering' )
850+ op .apply (time_m = 0 , time_M = ntmod - 2 )
851+
852+ # Now run backward as well with buffering
853+
854+ f_all = TimeFunction (grid = grid , name = 'f_all' , time_order = 0 ,
855+ space_order = 0 , time_dim = ct0 , save = nt )
856+
857+ eq_all = [Eq (f_all , f )]
858+ eq_all .append (Eq (f_all .forward , f .forward , implicit_dims = ctend ))
859+ op_all = Operator (eq_all , opt = 'buffering' )
860+ op_all .apply (time_m = 0 , time_M = ntmod - 2 )
861+
862+ assert np .allclose (f_all .data [:, 11 , 11 ], factor * np .arange (nt ))
0 commit comments