Skip to content

Commit 2dfe71e

Browse files
author
Leon Frenot
committed
Fixes after more rebasing
1 parent df4dd0a commit 2dfe71e

15 files changed

Lines changed: 853 additions & 887 deletions

src/xtc/schedules/descript.py

Lines changed: 155 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,26 @@ class Annotations:
4343

4444
unroll_factor: int | None = None
4545
unroll_specified: bool = False
46-
vectorize: bool = False
47-
parallelize: bool = False
46+
vectorize: bool | str = False
47+
parallelize: bool | str = False
48+
partial: bool = False
49+
full: bool = False
4850

4951

5052
@dataclass(frozen=True)
5153
class SplitDecl:
5254
"""AST Type: a split declaration like 'axis[start:end]'."""
5355

5456
axis: str
55-
start: int | None
56-
end: int | None
57+
start: int | str | None
58+
end: int | str | None
5759
body: ScheduleSpec
60+
size: int | str | None = None
5861

5962
@override
6063
def __str__(self) -> str:
64+
if self.size is not None:
65+
return f"{self.axis}[:{self.size}:]"
6166
start_str = "" if self.start is None else str(self.start)
6267
end_str = "" if self.end is None else str(self.end)
6368
decl = f"{self.axis}[{start_str}:{end_str}]"
@@ -69,7 +74,7 @@ class TileDecl:
6974
"""AST Type: a tile declaration like 'axis#size'."""
7075

7176
axis: str
72-
size: int
77+
size: int | str
7378
annotations: Annotations
7479

7580
@override
@@ -85,7 +90,36 @@ class AxisDecl:
8590
annotations: Annotations
8691

8792

88-
ScheduleItem = SplitDecl | TileDecl | AxisDecl
93+
@dataclass(frozen=True)
94+
class FusionDecl:
95+
"""AST Type: a fusion declaration"""
96+
97+
98+
@dataclass(frozen=True)
99+
class PackDecl:
100+
"""AST Type: a packing declaration"""
101+
102+
param: str | bool
103+
input: str
104+
pad: str | bool
105+
106+
107+
@dataclass(frozen=True)
108+
class BufferDecl:
109+
"""AST Type: a bufferisation declaration"""
110+
111+
param: str | bool
112+
pad: str
113+
114+
115+
@dataclass(frozen=True)
116+
class ExploreDecl:
117+
level: str
118+
119+
120+
ScheduleItem = (
121+
SplitDecl | TileDecl | AxisDecl | FusionDecl | PackDecl | BufferDecl | ExploreDecl
122+
)
89123

90124

91125
@dataclass(frozen=True)
@@ -144,10 +178,12 @@ def _parse_tile(self, declaration: str, value: dict) -> TileDecl:
144178

145179
axis_name, size_str = parts
146180

147-
try:
148-
size = int(size_str)
149-
except ValueError:
150-
raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.")
181+
size = int(size_str) if size_str.isnumeric() else size_str
182+
183+
# try:
184+
# size = int(size_str)
185+
# except ValueError:
186+
# raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.")
151187

152188
annotations = self._parse_annotations(value, declaration)
153189
return TileDecl(axis=axis_name, size=size, annotations=annotations)
@@ -165,23 +201,39 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations
165201
unroll_specified = False
166202
vectorize = False
167203
parallelize = False
204+
partial = False
205+
full = False
168206

169207
for key, param in value.items():
170208
if key == "unroll":
171209
unroll_factor = param
172210
unroll_specified = True
173211
elif key == "vectorize":
174-
if param is not None:
212+
if isinstance(param, str):
213+
vectorize = param
214+
elif param is not None:
175215
raise ScheduleParseError(
176216
f'`{{"vectorize" = {param}}}`: parameterized vectorization not implemented.'
177217
)
178-
vectorize = True
218+
else:
219+
vectorize = True
179220
elif key == "parallelize":
180-
if param is not None:
221+
if isinstance(param, str):
222+
parallelize = param
223+
elif param is not None:
181224
raise ScheduleParseError(
182225
f'`{{"parallelize" = {param}}}`: parameterized parallelization not implemented.'
183226
)
184-
parallelize = True
227+
else:
228+
parallelize = True
229+
elif key == "partial":
230+
if full:
231+
raise ScheduleParseError("Tile cannot be full and partial.")
232+
partial = True
233+
elif key == "full":
234+
if partial:
235+
raise ScheduleParseError("Tile cannot be partial and full.")
236+
full = True
185237
else:
186238
raise ScheduleParseError(f"Unknown annotation on {context}: {key}")
187239

@@ -190,6 +242,8 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations
190242
unroll_specified=unroll_specified,
191243
vectorize=vectorize,
192244
parallelize=parallelize,
245+
partial=partial,
246+
full=full,
193247
)
194248

195249
def _parse_split_syntax(
@@ -224,8 +278,8 @@ def _interpret_spec(
224278
slice = loop_nest.build_slice(root)
225279

226280
# Track state during interpretation
227-
sizes: dict[str, int] = {}
228-
previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis}
281+
sizes: dict[str, int | str] = {}
282+
previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis}
229283
interchange: list[str] = list(head)
230284

231285
for item in spec.items:
@@ -257,7 +311,7 @@ def _interpret_split(
257311
loop_nest: LoopNest,
258312
root: str,
259313
interchange: list[str],
260-
previous_cut: dict[str, int | None],
314+
previous_cut: dict[str, int | str | None],
261315
) -> None:
262316
"""Interpret a split declaration."""
263317
axis_name = item.axis
@@ -273,10 +327,8 @@ def _interpret_split(
273327
# it is the previous cut
274328
if x is None:
275329
x = cut
276-
assert x is not None
277-
278330
self._check_splitting_intervals(item, cut, x)
279-
331+
assert x is not None
280332
# Update the previous cut
281333
previous_cut[axis_name] = y
282334

@@ -298,12 +350,13 @@ def _interpret_tile(
298350
item: TileDecl,
299351
slice: LoopNestSlice,
300352
interchange: list[str],
301-
sizes: dict[str, int],
353+
sizes: dict[str, int | str],
302354
) -> str:
303355
"""Interpret a tile declaration. Returns the loop name."""
304356
self._check_axis_existence(item.axis)
305357
tile_num = len(slice.tiles[item.axis])
306358
loop_name = f"{item.axis}{tile_num}"
359+
assert isinstance(item.size, int)
307360
if item.size <= 0:
308361
raise ScheduleInterpretError(
309362
f"`{item}`: tile sizes should be strictly positive."
@@ -344,7 +397,7 @@ def _apply_annotations(
344397
self,
345398
annotations: Annotations,
346399
loop_name: str,
347-
sizes: dict[str, int],
400+
sizes: dict[str, int | str],
348401
slice: LoopNestSlice,
349402
) -> None:
350403
"""Apply annotations to a loop in the slice."""
@@ -357,7 +410,7 @@ def _apply_annotations(
357410
f"{loop_name}'s size being unknown, an unroll factor is needed."
358411
)
359412
unroll_factor = sizes[loop_name]
360-
elif unroll_factor <= 0:
413+
elif isinstance(unroll_factor, int) and unroll_factor <= 0:
361414
raise ScheduleInterpretError(
362415
f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.'
363416
)
@@ -372,27 +425,46 @@ def _apply_annotations(
372425
def _check_splitting_intervals(
373426
self,
374427
item: SplitDecl,
375-
cut: int | None,
376-
x: int,
377-
) -> None:
428+
cut: int | str | None,
429+
x: int | str | None,
430+
) -> int | str | None:
378431
"""Check that split intervals are valid and contiguous."""
379-
432+
y = item.end
380433
if cut is None:
381434
raise ScheduleInterpretError(f"{item}: {item.axis} already covered.")
382435

383-
if x > cut:
384-
raise ScheduleInterpretError(
385-
f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})."
386-
)
387-
elif x < cut:
436+
if x is None:
388437
raise ScheduleInterpretError(
389-
f"{item}: the segment begins at {x} but the previous one ends at {cut}."
438+
f"x is None, but cut: {cut} is not, this should be unreachable."
390439
)
440+
if isinstance(x, int) and isinstance(cut, int):
441+
if x > cut:
442+
raise ScheduleInterpretError(
443+
f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})."
444+
)
445+
elif x < cut:
446+
raise ScheduleInterpretError(
447+
f"{item}: the segment begins at {x} but the previous one ends at {cut}."
448+
)
449+
else:
450+
if x != cut:
451+
raise ScheduleInterpretError(
452+
f"{item}: Splitting ends at {cut} and begins at {x}. These need to be the same."
453+
)
454+
if y is None:
455+
return None
391456

392-
if item.end is not None and x >= item.end:
393-
raise ScheduleInterpretError(
394-
f"{item}: the ending point should be greater than the starting point."
395-
)
457+
if isinstance(x, int):
458+
if isinstance(y, int):
459+
if x >= y:
460+
raise ScheduleInterpretError(
461+
f"{item}: the ending point should be greater than the starting point."
462+
)
463+
else:
464+
return y - x
465+
if x == 0:
466+
return y
467+
return None
396468

397469

398470
@dataclass
@@ -496,6 +568,34 @@ def tiles_to_sizes(self) -> dict[str, int]:
496568
tiles_to_sizes[loop] = size
497569
return tiles_to_sizes
498570

571+
@property
572+
def int_tiles(self) -> dict[str, dict[str, int]]:
573+
return self._int_dict(self.tiles)
574+
575+
@property
576+
def int_splits(self) -> dict[str, dict[str, int]]:
577+
return self._int_dict(self.splits)
578+
579+
@property
580+
def int_unroll(self) -> dict[str, int]:
581+
out = {}
582+
for x, v in self.unroll.items():
583+
if isinstance(v, str) and v.isnumeric():
584+
v = int(v)
585+
assert isinstance(v, int)
586+
out[x] = v
587+
return out
588+
589+
def _int_dict(self, input: dict[str, dict[str, Any]]) -> dict[str, dict[str, int]]:
590+
out: dict[str, dict[str, int]] = {}
591+
for x, v in input.items():
592+
v_dict: dict[str, int] = {}
593+
for x_v, v_v in v.items():
594+
assert isinstance(v_v, int)
595+
v_dict[x_v] = v_v
596+
out[x] = v_dict
597+
return out
598+
499599

500600
@dataclass
501601
class LoopNest:
@@ -606,6 +706,7 @@ def _check_sizes(self):
606706

607707
if loop_name in sched.unroll:
608708
unroll_factor = sched.unroll[loop_name]
709+
assert isinstance(unroll_factor, int)
609710
if loop_size and loop_size < unroll_factor:
610711
raise ScheduleValidationError(
611712
f'`{{"unroll" = {unroll_factor}}}`: unroll factor should be smaller than {loop_size}.'
@@ -640,19 +741,11 @@ def descript_scheduler(
640741
abstract_axis: The list of abstract axis names (e.g., ["m", "n", "k"]).
641742
spec: The schedule specification as a nested dict.
642743
"""
643-
descript = Descript(scheduler=scheduler, abstract_axis=abstract_axis)
644-
descript.apply(node_name=node_name, spec=spec)
645-
646-
647-
def correct_type(d: dict[str, int | str]) -> dict[str, int]:
648-
out_d: dict[str, int] = {}
649-
for k, v in d.items():
650-
assert isinstance(v, int)
651-
out_d[k] = v
652-
return out_d
744+
descript = Descript(abstract_axis=abstract_axis)
745+
descript.apply(scheduler=scheduler, node_name=node_name, spec=spec)
653746

654747

655-
@dataclass(frozen=True)
748+
@dataclass(frozen=False)
656749
class Descript:
657750
"""Applies a parsed and interpreted schedule to a Scheduler.
658751
@@ -664,10 +757,11 @@ class Descript:
664757
4. Apply: LoopNest -> Scheduler
665758
"""
666759

667-
scheduler: Scheduler
668760
abstract_axis: list[str]
669761

670-
def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
762+
def apply(
763+
self, node_name: str, spec: dict[str, dict[str, Any]], scheduler: Scheduler
764+
) -> None:
671765
"""Parse, interpret, validate, and apply a schedule specification.
672766
673767
Args:
@@ -691,22 +785,22 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
691785
loop_nest.check()
692786

693787
# Apply the schedule to the scheduler
694-
self._apply_loop_nest(loop_nest)
788+
self._apply_loop_nest(loop_nest, scheduler)
695789

696-
def _apply_loop_nest(self, loop_nest: LoopNest) -> None:
790+
def _apply_loop_nest(self, loop_nest: LoopNest, scheduler: Scheduler) -> None:
697791
"""Apply a LoopNest to the scheduler."""
698-
self.scheduler.set_dims(self.abstract_axis)
792+
scheduler.set_dims(self.abstract_axis)
699793

700794
for slice in loop_nest.slices:
701795
root = slice.root
702796

703-
for d, s in slice.splits.items():
704-
self.scheduler.split(d, s, root=root)
797+
for d, s in slice.int_splits.items():
798+
scheduler.split(d, s, root=root)
705799

706-
for d, s in slice.tiles.items():
707-
self.scheduler.tile(d, s, root=root)
800+
for d, s in slice.int_tiles.items():
801+
scheduler.tile(d, s, root=root)
708802

709-
self.scheduler.interchange(slice.interchange, root=root)
710-
self.scheduler.vectorize(slice.vectorize, root=root)
711-
self.scheduler.parallelize(slice.parallelize, root=root)
712-
self.scheduler.unroll(slice.unroll, root=root)
803+
scheduler.interchange(slice.interchange, root=root)
804+
scheduler.vectorize(slice.vectorize, root=root)
805+
scheduler.parallelize(slice.parallelize, root=root)
806+
scheduler.unroll(slice.int_unroll, root=root)

0 commit comments

Comments
 (0)