@@ -43,21 +43,26 @@ class Annotations:
4343
4444 unroll_factor : int | None = None
4545 unroll_specified : bool = False
46- vectorize : bool = False
47- parallelize : bool = False
46+ vectorize : bool | str = False
47+ parallelize : bool | str = False
48+ partial : bool = False
49+ full : bool = False
4850
4951
5052@dataclass (frozen = True )
5153class SplitDecl :
5254 """AST Type: a split declaration like 'axis[start:end]'."""
5355
5456 axis : str
55- start : int | None
56- end : int | None
57+ start : int | str | None
58+ end : int | str | None
5759 body : ScheduleSpec
60+ size : int | str | None = None
5861
5962 @override
6063 def __str__ (self ) -> str :
64+ if self .size is not None :
65+ return f"{ self .axis } [:{ self .size } :]"
6166 start_str = "" if self .start is None else str (self .start )
6267 end_str = "" if self .end is None else str (self .end )
6368 decl = f"{ self .axis } [{ start_str } :{ end_str } ]"
@@ -69,7 +74,7 @@ class TileDecl:
6974 """AST Type: a tile declaration like 'axis#size'."""
7075
7176 axis : str
72- size : int
77+ size : int | str
7378 annotations : Annotations
7479
7580 @override
@@ -85,7 +90,36 @@ class AxisDecl:
8590 annotations : Annotations
8691
8792
88- ScheduleItem = SplitDecl | TileDecl | AxisDecl
93+ @dataclass (frozen = True )
94+ class FusionDecl :
95+ """AST Type: a fusion declaration"""
96+
97+
98+ @dataclass (frozen = True )
99+ class PackDecl :
100+ """AST Type: a packing declaration"""
101+
102+ param : str | bool
103+ input : str
104+ pad : str | bool
105+
106+
107+ @dataclass (frozen = True )
108+ class BufferDecl :
109+ """AST Type: a bufferisation declaration"""
110+
111+ param : str | bool
112+ pad : str
113+
114+
115+ @dataclass (frozen = True )
116+ class ExploreDecl :
117+ level : str
118+
119+
120+ ScheduleItem = (
121+ SplitDecl | TileDecl | AxisDecl | FusionDecl | PackDecl | BufferDecl | ExploreDecl
122+ )
89123
90124
91125@dataclass (frozen = True )
@@ -144,10 +178,12 @@ def _parse_tile(self, declaration: str, value: dict) -> TileDecl:
144178
145179 axis_name , size_str = parts
146180
147- try :
148- size = int (size_str )
149- except ValueError :
150- raise ScheduleParseError (f"`{ declaration } `: { size_str } is not an integer." )
181+ size = int (size_str ) if size_str .isnumeric () else size_str
182+
183+ # try:
184+ # size = int(size_str)
185+ # except ValueError:
186+ # raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.")
151187
152188 annotations = self ._parse_annotations (value , declaration )
153189 return TileDecl (axis = axis_name , size = size , annotations = annotations )
@@ -165,23 +201,39 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations
165201 unroll_specified = False
166202 vectorize = False
167203 parallelize = False
204+ partial = False
205+ full = False
168206
169207 for key , param in value .items ():
170208 if key == "unroll" :
171209 unroll_factor = param
172210 unroll_specified = True
173211 elif key == "vectorize" :
174- if param is not None :
212+ if isinstance (param , str ):
213+ vectorize = param
214+ elif param is not None :
175215 raise ScheduleParseError (
176216 f'`{{"vectorize" = { param } }}`: parameterized vectorization not implemented.'
177217 )
178- vectorize = True
218+ else :
219+ vectorize = True
179220 elif key == "parallelize" :
180- if param is not None :
221+ if isinstance (param , str ):
222+ parallelize = param
223+ elif param is not None :
181224 raise ScheduleParseError (
182225 f'`{{"parallelize" = { param } }}`: parameterized parallelization not implemented.'
183226 )
184- parallelize = True
227+ else :
228+ parallelize = True
229+ elif key == "partial" :
230+ if full :
231+ raise ScheduleParseError ("Tile cannot be full and partial." )
232+ partial = True
233+ elif key == "full" :
234+ if partial :
235+ raise ScheduleParseError ("Tile cannot be partial and full." )
236+ full = True
185237 else :
186238 raise ScheduleParseError (f"Unknown annotation on { context } : { key } " )
187239
@@ -190,6 +242,8 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations
190242 unroll_specified = unroll_specified ,
191243 vectorize = vectorize ,
192244 parallelize = parallelize ,
245+ partial = partial ,
246+ full = full ,
193247 )
194248
195249 def _parse_split_syntax (
@@ -224,8 +278,8 @@ def _interpret_spec(
224278 slice = loop_nest .build_slice (root )
225279
226280 # Track state during interpretation
227- sizes : dict [str , int ] = {}
228- previous_cut : dict [str , int | None ] = {a : 0 for a in self .abstract_axis }
281+ sizes : dict [str , int | str ] = {}
282+ previous_cut : dict [str , int | str | None ] = {a : 0 for a in self .abstract_axis }
229283 interchange : list [str ] = list (head )
230284
231285 for item in spec .items :
@@ -257,7 +311,7 @@ def _interpret_split(
257311 loop_nest : LoopNest ,
258312 root : str ,
259313 interchange : list [str ],
260- previous_cut : dict [str , int | None ],
314+ previous_cut : dict [str , int | str | None ],
261315 ) -> None :
262316 """Interpret a split declaration."""
263317 axis_name = item .axis
@@ -273,10 +327,8 @@ def _interpret_split(
273327 # it is the previous cut
274328 if x is None :
275329 x = cut
276- assert x is not None
277-
278330 self ._check_splitting_intervals (item , cut , x )
279-
331+ assert x is not None
280332 # Update the previous cut
281333 previous_cut [axis_name ] = y
282334
@@ -298,12 +350,13 @@ def _interpret_tile(
298350 item : TileDecl ,
299351 slice : LoopNestSlice ,
300352 interchange : list [str ],
301- sizes : dict [str , int ],
353+ sizes : dict [str , int | str ],
302354 ) -> str :
303355 """Interpret a tile declaration. Returns the loop name."""
304356 self ._check_axis_existence (item .axis )
305357 tile_num = len (slice .tiles [item .axis ])
306358 loop_name = f"{ item .axis } { tile_num } "
359+ assert isinstance (item .size , int )
307360 if item .size <= 0 :
308361 raise ScheduleInterpretError (
309362 f"`{ item } `: tile sizes should be strictly positive."
@@ -344,7 +397,7 @@ def _apply_annotations(
344397 self ,
345398 annotations : Annotations ,
346399 loop_name : str ,
347- sizes : dict [str , int ],
400+ sizes : dict [str , int | str ],
348401 slice : LoopNestSlice ,
349402 ) -> None :
350403 """Apply annotations to a loop in the slice."""
@@ -357,7 +410,7 @@ def _apply_annotations(
357410 f"{ loop_name } 's size being unknown, an unroll factor is needed."
358411 )
359412 unroll_factor = sizes [loop_name ]
360- elif unroll_factor <= 0 :
413+ elif isinstance ( unroll_factor , int ) and unroll_factor <= 0 :
361414 raise ScheduleInterpretError (
362415 f'`{{"unroll" = { unroll_factor } }}`: unroll parameter should be strictly positive.'
363416 )
@@ -372,27 +425,46 @@ def _apply_annotations(
372425 def _check_splitting_intervals (
373426 self ,
374427 item : SplitDecl ,
375- cut : int | None ,
376- x : int ,
377- ) -> None :
428+ cut : int | str | None ,
429+ x : int | str | None ,
430+ ) -> int | str | None :
378431 """Check that split intervals are valid and contiguous."""
379-
432+ y = item . end
380433 if cut is None :
381434 raise ScheduleInterpretError (f"{ item } : { item .axis } already covered." )
382435
383- if x > cut :
384- raise ScheduleInterpretError (
385- f"{ item } : splitting doesn't fully cover { item .axis } (jumps from { cut } to { x } )."
386- )
387- elif x < cut :
436+ if x is None :
388437 raise ScheduleInterpretError (
389- f"{ item } : the segment begins at { x } but the previous one ends at { cut } ."
438+ f"x is None, but cut: { cut } is not, this should be unreachable ."
390439 )
440+ if isinstance (x , int ) and isinstance (cut , int ):
441+ if x > cut :
442+ raise ScheduleInterpretError (
443+ f"{ item } : splitting doesn't fully cover { item .axis } (jumps from { cut } to { x } )."
444+ )
445+ elif x < cut :
446+ raise ScheduleInterpretError (
447+ f"{ item } : the segment begins at { x } but the previous one ends at { cut } ."
448+ )
449+ else :
450+ if x != cut :
451+ raise ScheduleInterpretError (
452+ f"{ item } : Splitting ends at { cut } and begins at { x } . These need to be the same."
453+ )
454+ if y is None :
455+ return None
391456
392- if item .end is not None and x >= item .end :
393- raise ScheduleInterpretError (
394- f"{ item } : the ending point should be greater than the starting point."
395- )
457+ if isinstance (x , int ):
458+ if isinstance (y , int ):
459+ if x >= y :
460+ raise ScheduleInterpretError (
461+ f"{ item } : the ending point should be greater than the starting point."
462+ )
463+ else :
464+ return y - x
465+ if x == 0 :
466+ return y
467+ return None
396468
397469
398470@dataclass
@@ -496,6 +568,34 @@ def tiles_to_sizes(self) -> dict[str, int]:
496568 tiles_to_sizes [loop ] = size
497569 return tiles_to_sizes
498570
571+ @property
572+ def int_tiles (self ) -> dict [str , dict [str , int ]]:
573+ return self ._int_dict (self .tiles )
574+
575+ @property
576+ def int_splits (self ) -> dict [str , dict [str , int ]]:
577+ return self ._int_dict (self .splits )
578+
579+ @property
580+ def int_unroll (self ) -> dict [str , int ]:
581+ out = {}
582+ for x , v in self .unroll .items ():
583+ if isinstance (v , str ) and v .isnumeric ():
584+ v = int (v )
585+ assert isinstance (v , int )
586+ out [x ] = v
587+ return out
588+
589+ def _int_dict (self , input : dict [str , dict [str , Any ]]) -> dict [str , dict [str , int ]]:
590+ out : dict [str , dict [str , int ]] = {}
591+ for x , v in input .items ():
592+ v_dict : dict [str , int ] = {}
593+ for x_v , v_v in v .items ():
594+ assert isinstance (v_v , int )
595+ v_dict [x_v ] = v_v
596+ out [x ] = v_dict
597+ return out
598+
499599
500600@dataclass
501601class LoopNest :
@@ -606,6 +706,7 @@ def _check_sizes(self):
606706
607707 if loop_name in sched .unroll :
608708 unroll_factor = sched .unroll [loop_name ]
709+ assert isinstance (unroll_factor , int )
609710 if loop_size and loop_size < unroll_factor :
610711 raise ScheduleValidationError (
611712 f'`{{"unroll" = { unroll_factor } }}`: unroll factor should be smaller than { loop_size } .'
@@ -640,19 +741,11 @@ def descript_scheduler(
640741 abstract_axis: The list of abstract axis names (e.g., ["m", "n", "k"]).
641742 spec: The schedule specification as a nested dict.
642743 """
643- descript = Descript (scheduler = scheduler , abstract_axis = abstract_axis )
644- descript .apply (node_name = node_name , spec = spec )
645-
646-
647- def correct_type (d : dict [str , int | str ]) -> dict [str , int ]:
648- out_d : dict [str , int ] = {}
649- for k , v in d .items ():
650- assert isinstance (v , int )
651- out_d [k ] = v
652- return out_d
744+ descript = Descript (abstract_axis = abstract_axis )
745+ descript .apply (scheduler = scheduler , node_name = node_name , spec = spec )
653746
654747
655- @dataclass (frozen = True )
748+ @dataclass (frozen = False )
656749class Descript :
657750 """Applies a parsed and interpreted schedule to a Scheduler.
658751
@@ -664,10 +757,11 @@ class Descript:
664757 4. Apply: LoopNest -> Scheduler
665758 """
666759
667- scheduler : Scheduler
668760 abstract_axis : list [str ]
669761
670- def apply (self , node_name : str , spec : dict [str , dict [str , Any ]]) -> None :
762+ def apply (
763+ self , node_name : str , spec : dict [str , dict [str , Any ]], scheduler : Scheduler
764+ ) -> None :
671765 """Parse, interpret, validate, and apply a schedule specification.
672766
673767 Args:
@@ -691,22 +785,22 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
691785 loop_nest .check ()
692786
693787 # Apply the schedule to the scheduler
694- self ._apply_loop_nest (loop_nest )
788+ self ._apply_loop_nest (loop_nest , scheduler )
695789
696- def _apply_loop_nest (self , loop_nest : LoopNest ) -> None :
790+ def _apply_loop_nest (self , loop_nest : LoopNest , scheduler : Scheduler ) -> None :
697791 """Apply a LoopNest to the scheduler."""
698- self . scheduler .set_dims (self .abstract_axis )
792+ scheduler .set_dims (self .abstract_axis )
699793
700794 for slice in loop_nest .slices :
701795 root = slice .root
702796
703- for d , s in slice .splits .items ():
704- self . scheduler .split (d , s , root = root )
797+ for d , s in slice .int_splits .items ():
798+ scheduler .split (d , s , root = root )
705799
706- for d , s in slice .tiles .items ():
707- self . scheduler .tile (d , s , root = root )
800+ for d , s in slice .int_tiles .items ():
801+ scheduler .tile (d , s , root = root )
708802
709- self . scheduler .interchange (slice .interchange , root = root )
710- self . scheduler .vectorize (slice .vectorize , root = root )
711- self . scheduler .parallelize (slice .parallelize , root = root )
712- self . scheduler .unroll (slice .unroll , root = root )
803+ scheduler .interchange (slice .interchange , root = root )
804+ scheduler .vectorize (slice .vectorize , root = root )
805+ scheduler .parallelize (slice .parallelize , root = root )
806+ scheduler .unroll (slice .int_unroll , root = root )
0 commit comments