Skip to content

Commit d1c7199

Browse files
Plotting updates & occupancy bug fix
1 parent 8399064 commit d1c7199

4 files changed

Lines changed: 112 additions & 41 deletions

File tree

pytimeloop/fastfusion/mapper/simexplore.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,32 @@ def fuse_sims(sims: list[SIM], resource2capacity: dict=None, return_nmappings_nb
3333
nbuckets.append(len(s))
3434
nmappings.append(sum(len(s2.mapping.data) for s2 in s))
3535
next_and_prev_live_tensors = next_live_tensors | s[0].tensor_names
36+
shared_tensors = set(s[0].tensor_names) & set(ns[0].tensor_names)
3637

3738
first_ns = ns[0]
3839
ns = SIM.group_by_left(ns, s[0].tensor_names)
3940
s = SIM.group_by_right(s, first_ns.tensor_names, keep_loops=True)
40-
41+
4142
for k, ns2 in ns.items():
4243
for ns3 in ns2:
43-
ns3.consolidate(next_live_tensors, resource2capacity)
44+
ns3.consolidate(next_live_tensors, resource2capacity, shared_tensors)
4445
ns[k] = SIM.combine_combineable(ns2, live_tensors)
46+
4547
for k, s2 in s.items():
4648
for s3 in s2:
47-
s3.consolidate(next_live_tensors, resource2capacity)
49+
s3.consolidate(next_live_tensors, resource2capacity, shared_tensors)
4850
s[k] = SIM.combine_combineable(s2, next_and_prev_live_tensors)
51+
52+
# We freed these in the consolidation step
53+
for ns2 in [s, ns]:
54+
for ns3 in ns2.values():
55+
for ns4 in ns3:
56+
for t in list(ns4.tensors):
57+
if t not in next_live_tensors:
58+
del ns4.tensors[t]
4959

5060
DO_PRINT = True
61+
DELAY_MERGE = True
5162

5263
combined: list[SIM] = []
5364
for k in s:
@@ -57,12 +68,19 @@ def fuse_sims(sims: list[SIM], resource2capacity: dict=None, return_nmappings_nb
5768
ns: SIM
5869
if DO_PRINT:
5970
print(f"\t{a.tiling_str()} {a.get_shared_loop_index(live_tensors)} <--> {b.tiling_str()}{b.get_shared_loop_index(next_and_prev_live_tensors)}. ({len(a.mapping.data)})x({len(b.mapping.data)})")
60-
combined.append(a.merge_next(b, next_live_tensors, resource2capacity, delay=True))
71+
if not sims:
72+
print(a.merge_next(b, next_live_tensors, resource2capacity, delay=False))
73+
combined.append(a.merge_next(b, next_live_tensors, resource2capacity, delay=DELAY_MERGE))
6174
elif DO_PRINT:
6275
print(f"\tNo match for {k} ||||||||| {s[k][0].tiling_str()}")
6376

64-
for c, mapping in zip(combined, Parallel(n_jobs=128)(c.mapping for c in combined)):
65-
c.mapping = mapping
77+
if DELAY_MERGE:
78+
for c, mapping in zip(combined, Parallel(n_jobs=128)(c.mapping for c in combined)):
79+
c.mapping = mapping
80+
else:
81+
for c, mapping in zip(combined, (c.mapping for c in combined)):
82+
c.mapping = mapping
83+
6684
print(f"\tCombining {sum(len(s2) for s2 in s)}({len(s)}) x {sum(len(s2) for s2 in ns)}({len(ns)}) -> {len(combined)}")
6785
if DO_PRINT:
6886
for k in ns:

pytimeloop/fastfusion/pareto.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def nameloop2col(name, nloops, left: bool=False):
5959
def is_left_col(x):
6060
return "_LEFT_LEVEL_" in x
6161

62-
MERGE_SUFFIXES = ["_RIGHT_MERGE", "_LEFT_MERGE"]
62+
MERGE_SUFFIXES = ["_LEFT_MERGE", "_RIGHT_MERGE"]
6363

6464
def is_merge_col(c):
6565
return any(c.endswith(s) for s in MERGE_SUFFIXES)
@@ -109,21 +109,23 @@ def makepareto(data: pd.DataFrame) -> pd.DataFrame:
109109
return data
110110
return data[paretoset(data[columns])].reset_index(drop=True)
111111

112-
def squish_left_right(data: pd.DataFrame):
112+
def squish_left_right(data: pd.DataFrame, shared_loop_index: int=None):
113113
nloops2left = defaultdict(set)
114+
dropcols = []
114115
for c in data.columns:
115116
if (name_nloops := col2nameloop(c)) is not None:
116117
if is_left_col(c):
117118
name, nloops = name_nloops
118-
nloops2left[nloops].add((c, name))
119+
if shared_loop_index is None or nloops == shared_loop_index:
120+
nloops2left[nloops].add((c, name))
121+
dropcols.append(c)
119122

120123
for n in nloops2left.keys():
121124
for c, name in nloops2left[n]:
122125
target = nameloop2col(name, n)
123126
max_to_col(data, target, c)
124127

125-
keepcols = [c for c in data.columns if not is_left_col(c)]
126-
return data[keepcols]
128+
return data[[c for c in data.columns if c not in dropcols]]
127129

128130
def free_to_loop_index(data: pd.DataFrame, shared_loop_index: int, skip_pareto: bool=False) -> pd.DataFrame:
129131
nloops2left = defaultdict(set)
@@ -179,6 +181,7 @@ def merge_cross(
179181
as_pareto: bool = False,
180182
) -> pd.DataFrame:
181183
left = free_to_loop_index(left, shared_loop_index + 1)
184+
left = squish_left_right(left, shared_loop_index + 1)
182185
for c in left.columns:
183186
if (name_nloops := col2nameloop(c)) is not None:
184187
if c not in right.columns:
@@ -221,7 +224,6 @@ def merge_cross(
221224
# * Can't bake into compatiblity unless we have a notion of left vs.
222225
# right pipelined.
223226

224-
225227
# PIPELINE CHANGES REQUIRED:
226228
# - Latency above above loop index (first tile), below (all subsequent tiles)
227229
# - Tiling includes information for how may be fused:
@@ -277,6 +279,31 @@ def merge_cross(
277279
# Update the IN_PROGRESS_STATS
278280
for i, r in df[cols].iterrows():
279281
df.at[i, IN_PROGRESS_STATS][last] = r.to_dict()
282+
283+
CHECK_CORRECTNESS = False
284+
if CHECK_CORRECTNESS:
285+
from pytimeloop.fastfusion.plot.looptree import tilings2looptree
286+
df_check = free_to_loop_index(df.copy(), -1, skip_pareto=True)
287+
for i, r in df_check.iterrows():
288+
looptree = tilings2looptree(r[MAPPING], r[STATS], r[TENSORS], r[IN_PROGRESS_STATS], skip_backing_tensors=next_live_tensors)
289+
reservations = dict(looptree.get_reservations())
290+
for k, v in reservations.items():
291+
col = nameloop2col(k, -1)
292+
if col not in df_check.columns:
293+
got = r[[c for c in df_check.columns if col2nameloop(c) is not None]]
294+
raise ValueError(f"Missing {k}: Expected {reservations}. Got: {got}")
295+
if r[col] != v:
296+
got = r[[c for c in df_check.columns if col2nameloop(c) is not None]]
297+
raise ValueError(f"Mismatched {k}: {v} != {r[col]}. Expected {reservations}. Got: {got}")
298+
# import pydot
299+
# graph = pydot.Dot(graph_type="digraph", ranksep="0.2", nodesep="0.2")
300+
# looptree.to_pydot(graph)
301+
# with open(f"test.png", "wb") as f:
302+
# f.write(graph.create_png())
303+
# all_tensors = set(t for tn in r[TENSORS].values() for t in tn)
304+
# for t in sorted(all_tensors):
305+
# print(f"{t.__repr__()},")
306+
280307

281308
# Assert no NaNs
282309
assert not df.isnull().values.any()
@@ -297,8 +324,6 @@ def concat(paretos: list["Pareto"]) -> "Pareto":
297324

298325
def merge(self, other: "Pareto", shared_loop_index: int, next_shared_loop_index: int, resource2capacity: dict[str, int], next_live_tensors: set[int], delay: bool=False) -> "Pareto":
299326
d = delayed(merge_cross)(self.data, other.data, shared_loop_index, next_shared_loop_index, resource2capacity, next_live_tensors=next_live_tensors, as_pareto=True)
300-
if not delay:
301-
print("AHH")
302327
return d if delay else d[0](*d[1], **d[2])
303328

304329
@staticmethod

pytimeloop/fastfusion/plot/looptree.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22
import pydot
3-
from typing import Any
3+
from typing import Any, Iterable
44
from pytimeloop.fastfusion.sim import Tiling, TensorStorage, Loop
55
from pytimeloop.fastfusion.util import expfmt
66
from pytimeloop.fastfusion.pareto import IN_PROGRESS_STATS
@@ -38,7 +38,7 @@ def _to_yaml(self):
3838
def to_yaml(self):
3939
return {"mapping": "fused", "nodes": self._to_yaml()}
4040

41-
def to_pydot(self, graph, parent=None, invisible_root: bool = True):
41+
def to_pydot(self, graph, parent=None, invisible_root: bool = False):
4242
label_lines = []
4343
for t in self.this_level:
4444
label_lines.append(t.pydot_str() if hasattr(t, "pydot_str") else str(t))
@@ -49,7 +49,8 @@ def to_pydot(self, graph, parent=None, invisible_root: bool = True):
4949
node = pydot.Node(id(self), label=node_label, **PYDOT_NODE_DEFAULTS)
5050
graph.add_node(node)
5151
if parent:
52-
graph.add_edge(pydot.Edge(parent, node))
52+
reservations = "\n".join(sorted(f"[{k}] {expfmt(v)}" for k, v in self.get_reservations().items()))
53+
graph.add_edge(pydot.Edge(parent, node, label=reservations))
5354
for child in self.children:
5455
child.to_pydot(graph, node, invisible_root=False)
5556

@@ -59,9 +60,18 @@ def add_stats(self, stats: dict[str, Any]):
5960
else:
6061
for k, v in stats.items():
6162
self.this_level.append(f"{k}: {expfmt(v)}")
63+
64+
def get_reservations(self) -> dict[str, int]:
65+
reservations = defaultdict(lambda: 0)
66+
for c in self.children:
67+
for k, v in c.get_reservations().items():
68+
reservations[k] = max(reservations[k], v)
69+
for t in self.this_level:
70+
if isinstance(t, TensorStorage):
71+
reservations[t.backer_id] += t.tile_size
72+
return reservations
6273

63-
64-
def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], tensors: dict[str, list[TensorStorage]], partial_stats: dict[str, Any]):
74+
def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], tensors: dict[str, list[TensorStorage]], partial_stats: dict[str, Any], skip_backing_tensors: Iterable[str] = ()):
6575
prev_tiling = None
6676
root = Node()
6777
einsum_ids = list(mappings.keys())
@@ -79,27 +89,36 @@ def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], tensors
7989
n.children.append(Node())
8090
n = n.children[-1]
8191
n.children.append(Node()) # Leaf node
82-
for tensor in tiling.tensors:
83-
root.access_level(tensor.above_loop_index).this_level.append(tensor)
92+
id2tensor = defaultdict(lambda: [])
93+
for t in tiling.tensors:
94+
id2tensor[t.tensor_id].append(t)
95+
id2tensor = {k: sorted(v, key=lambda x: (x.above_loop_index, x.backer_id)) for k, v in id2tensor.items()}
96+
for tensor_id, storages in id2tensor.items():
97+
if tensor_id in skip_backing_tensors:
98+
storages = storages[1:]
99+
for tensor in storages:
100+
if tensor not in n.this_level:
101+
root.access_level(tensor.above_loop_index).this_level.append(tensor)
84102
for i, l in enumerate(loops):
85103
root.access_level(index + i + 1).this_level.append(l)
86-
root.add_stats(stats[einsum_id])
87104
last_level = root.access_level(None).this_level
88-
for tensor in tiling.tensors:
89-
if tensor not in last_level:
90-
last_level.append(tensor)
91-
total_resources[tensor.backer_id] += tensor.tile_size
105+
first_level = root.access_level(0).this_level
92106
for tensor in tensors[einsum_id]:
93-
if tensor not in last_level:
94-
last_level.append(tensor.pydot_str() + "**")
95-
total_resources[tensor.backer_id] += tensor.tile_size
107+
if tensor.tensor_id not in skip_backing_tensors:
108+
if tensor not in mappings[einsum_id].tensors:
109+
# tensor = TensorStorage(
110+
# f"*{tensor.tensor_id}",
111+
# tensor.backer_id,
112+
# tensor.above_loop_index,
113+
# tensor.tile_size
114+
# )
115+
first_level.append(tensor)
116+
total_resources[tensor.backer_id] += tensor.tile_size
96117
for k, v in total_resources.items():
97118
last_level.append(f"({k}) TOTAL: {expfmt(v)}")
98-
119+
root.add_stats(stats[einsum_id])
99120
for k, v in partial_stats[einsum_id].items():
100121
last_level.append(f"_PARTIAL {k}: {expfmt(v)}")
101-
102-
103122
prev_tiling = tiling
104123
return root
105124

pytimeloop/fastfusion/sim.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ def __str__(self):
5353
return ("S-" if self.is_spatial else "") + f"{self.rank_id}-{self.bound}"
5454

5555
def pydot_str(self):
56-
return f"{self.rank_id} sz {expfmt(self.bound)} {'S' if self.is_spatial else ''} * {expfmt(self.n_repititions)}"
57-
56+
if self.is_spatial:
57+
return f"S-for R{self.rank_id} size {expfmt(self.bound)}"
58+
return f"for {self.rank_id} size {expfmt(self.bound)}"
59+
5860
def rename(self, rank_renaming: dict[str, str], tensor_renaming: dict[str, str]) -> "Loop":
5961
return Loop(rank_renaming[self.rank_id], self.bound, self.is_spatial)
6062

@@ -85,14 +87,14 @@ def ts(self):
8587
return self.tile_size
8688

8789
def __str__(self):
88-
return f"({self.backer_id}) {self.tensor_id} sz {expfmt(self.tile_size)} above {self.above_loop_index}"# x{expfmt(self.n_repititions)}"
90+
return f"[{self.backer_id}] {self.tensor_id} sz {expfmt(self.tile_size)} above {self.above_loop_index}"# x{expfmt(self.n_repititions)}"
8991

9092
def __repr__(self):
9193
return f"TensorStorage({self.tensor_id}, {self.backer_id}, {self.above_loop_index}, {self.tile_size})"#, {self.n_repititions})"
9294

9395
def pydot_str(self):
94-
return f"({self.backer_id}) {self.tensor_id} size " \
95-
f"{expfmt(self.tile_size)}"#*{expfmt(self.n_repititions)}={expfmt(self.tile_size)}"# * self.n_repititions)}"
96+
return f"[{self.backer_id}] T{self.tensor_id} size {expfmt(self.tile_size)}"
97+
#*{expfmt(self.n_repititions)}={expfmt(self.tile_size)}"# * self.n_repititions)}"
9698

9799
def rename(self, rank_renaming: dict[str, str], tensor_renaming: dict[str, str]) -> "TensorStorage":
98100
return TensorStorage(
@@ -111,6 +113,10 @@ def to_yaml(self):
111113
"above_loop_index": self.above_loop_index,
112114
"tile_size": self.tile_size,
113115
}
116+
117+
class TensorStorage2(TensorStorage):
118+
def __repr__(self):
119+
return f"TensorStorage2({self.tensor_id}, {self.backer_id}, {self.above_loop_index}, {self.tile_size})"
114120

115121

116122
@dataclass(frozen=True)
@@ -209,6 +215,8 @@ def merge_next(self, n: "SIM", next_live_tensors: set[str], resource2capacity: d
209215
shared_loop_index = self.tiling.shared_loop_index(n.tiling.tensor_names)
210216
tiling = n.tiling.absorb_tensors(self.tiling, next_live_tensors)
211217
next_shared_loop_index = tiling.shared_loop_index(next_live_tensors)
218+
# assert all(t.tensor_id in next_live_tensors for t in tiling.tensors), f"Did not free all dead tensors {tiling.tensors} {next_live_tensors}"
219+
# assert all
212220
mapping = self.mapping.merge(n.mapping, shared_loop_index, next_shared_loop_index, resource2capacity, next_live_tensors, delay=delay)
213221
s = SIM(tiling, mapping)
214222
assert len(tiling.loops) == next_shared_loop_index + 1, f"{self.tiling} {n.tiling} {next_shared_loop_index + 1} -> {tiling} {len(tiling.loops)}"
@@ -220,8 +228,10 @@ def get_shared_loop_index(self, next_live_tensors: set[str]) -> int:
220228
live_tensors = list(self.tiling.tensor_names) + [next_live_tensors]
221229
return self.tiling.shared_loop_index(live_tensors)
222230

223-
def consolidate(self, next_live_tensors: set[str] = None, resource2capacity: dict[str, int] = None):
231+
def consolidate(self, next_live_tensors: set[str] = None, resource2capacity: dict[str, int] = None, shared_tensors: set[str] = None):
224232
dead_tensors = set(self.tensors) - (next_live_tensors or set())
233+
shared_tensors = shared_tensors or set()
234+
shared_loop_index = self.tiling.shared_loop_index(shared_tensors | next_live_tensors)
225235
for t in dead_tensors:
226236
self._free_tensor(t)
227237
if next_live_tensors is None:
@@ -231,9 +241,8 @@ def consolidate(self, next_live_tensors: set[str] = None, resource2capacity: dic
231241
# Can free the deepest of:
232242
# - The shared loop with the next SIM
233243
# - My deepest loop that hasn't yet been freed
234-
shared_loop_index = self.tiling.shared_loop_index(next_live_tensors)
235-
if self.tensors:
236-
shared_loop_index = max(shared_loop_index, max(t.above_loop_index for t in self.tensors.values()))
244+
# if self.tensors:
245+
# shared_loop_index = max(shared_loop_index, max(t.above_loop_index for t in self.tensors.values()))
237246
self.mapping.free_to_loop_index(shared_loop_index+1, resource2capacity)
238247

239248
def __eq__(self, other):

0 commit comments

Comments
 (0)