Skip to content

Commit 898f0b1

Browse files
WIP tagging
1 parent 35a3f80 commit 898f0b1

10 files changed

Lines changed: 249 additions & 61 deletions

File tree

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from collections import defaultdict
2+
from pytimeloop.fastfusion.sim import Tag, TensorStorage, Tiling
3+
4+
def get_ffmt_tag_mha(
    einsum_id: str,
    backing_storages: set[TensorStorage],
    input_tensors: set[str],
    output_tensors: set[str],
    tiling: Tiling,
    rank_name_to_shared_name: dict[str, str]
):
    """Classify *tiling* for one MHA Einsum under the FFMT fusion model.

    Returns a tuple of tag strings. The first element is ``"FFMT_VALID"`` or
    ``"FFMT_INVALID"``; a valid tiling also carries a weight-tiling tag
    (``FFMT_WEIGHTS_{UNTILED,TILED,INVALID}``) and a chain-position tag
    (``FFMT_{UNFUSED,FIRST,MIDDLE,LAST}``).

    NOTE(review): this is WIP code (see the inline NOTEs below) — several of
    the tensor-name tables are acknowledged incomplete by their authors.
    """
    # Per-Einsum rank names, e.g. "BQK" for the B rank of Einsum "QK".
    B, H, M, F, P, G, E, D = (x + einsum_id for x in "BHMFPGED")

    # Fused input/output tensor names for each Einsum in the MHA chain.
    # ``None`` means the Einsum has no fused tensor on that side.
    einsum_id_to_input_output = {
        "Q": ["I_I_to_Q_K_V", "Q_Q_to_QK"],
        "K": ["I_I_to_Q_K_V", None],  # NOTE: TANNER ADDED THESE
        "V": ["I_I_to_Q_K_V", None],  # NOTE: TANNER ADDED THESE
        "QK": ["Q_Q_to_QK", "QK_QK_to_AV"],  # NOTE: K IS MISSING
        "AV": ["QK_QK_to_AV", "AV_AV_to_Z"],  # NOTE: AV IS MISSING
        "Z": ["AV_AV_to_Z", "Z_Z_to_n"],
    }
    a, b = einsum_id_to_input_output[einsum_id]

    tags = []

    # Scan level-1 tensors to find how deep weights vs. fused tensors are
    # tiled, and whether this Einsum is first/last in its fusion chain.
    min_weight_index = None
    max_non_weight_index = 0
    first, last = True, True
    for t in tiling.tensors:
        if t.backer_id != 1:
            continue
        # A fused input/output tensor that is *backed* here means the chain
        # does not extend past this Einsum on that side.
        if t.tensor_id in input_tensors and t in backing_storages:
            first = False
        if t.tensor_id in output_tensors and t in backing_storages:
            last = False
        if t.tensor_id != a and t.tensor_id != b:  # Weights!
            if min_weight_index is None:
                min_weight_index = t.above_loop_index
            else:
                min_weight_index = min(min_weight_index, t.above_loop_index)
        else:
            max_non_weight_index = max(max_non_weight_index, t.above_loop_index)

    # Weights must not sit below (deeper than) any fused tensor.
    if min_weight_index == 2:
        tags.append("FFMT_WEIGHTS_UNTILED")
    elif min_weight_index is None or min_weight_index < max_non_weight_index:
        tags.append("FFMT_WEIGHTS_INVALID")
    else:
        tags.append("FFMT_WEIGHTS_TILED")

    # Candidate (rank-permutation-prefix, (input-loops, output-loops)) pairs.
    to_try = [([B, H], (2, 2)), ([B, H, M], (3, 3))]
    other_ranks = {
        "Q": [B, H, M, E, D],
        "K": [B, H, M, E, D],
        "V": [B, H, M, E, D],
        "QK": [B, H, M, P, E],
        "AV": [B, H, M, F, P],
        "Z": [B, H, M, G],
    }[einsum_id]

    valid = False
    if first and last:  # Unfused
        to_try = []
        valid = True
        tags.append("FFMT_UNFUSED")
    elif first:  # First Einsum in a chain
        to_try += [(other_ranks[:4], (3, 4)), (other_ranks, (5, 4))]
        tags.append("FFMT_FIRST")
    elif last:  # Last Einsum in a chain
        # NOTE(review): ``other_ranks[4:]`` yields at most one rank here (and
        # none for "Z") — confirm this slice is intended vs. ``[:4]``.
        to_try += [(other_ranks[4:], (3, 4))]
        tags.append("FFMT_LAST")
    else:  # Middle Einsum in a chain
        if einsum_id == "AV":
            # AV consumes its fused tensor on the opposite side; swap roles.
            a, b = b, a
            other_ranks[-2], other_ranks[-1] = other_ranks[-1], other_ranks[-2]
        to_try += [(other_ranks[:4], (3, 4))]
        tags.append("FFMT_MIDDLE")

    for c, (a_loops, b_loops) in to_try:
        # Translate per-Einsum rank names to shared names; "*" is a wildcard.
        perm = [rank_name_to_shared_name[x] for x in c] + ["*"]
        check_tensors = [TensorStorage(a, a_loops, 1, "*")]
        if b is not None:
            check_tensors.append(TensorStorage(b, b_loops, 1, "*"))
        # Evaluate the permutation match once (the original WIP code called
        # matches_permutation twice, with a dead ``valid = valid`` no-op).
        if tiling.matches_permutation(perm) and tiling.has_tensor(*check_tensors):
            valid = True

    if valid:
        return ("FFMT_VALID", *tags)
    return ("FFMT_INVALID",)
95+
96+
def get_tileflow_tag_mha(
    einsum_id: str,
    backing_storages: set[TensorStorage],
    input_tensors: set[str],
    output_tensors: set[str],
    tiling: Tiling,
    rank_name_to_shared_name: dict[str, str]
):
    """Tag *tiling* as TileFlow-valid iff it is an "even" mapping.

    A mapping is even when every storage level holds all of its tensors at a
    single loop depth — i.e. each ``backer_id`` maps to exactly one distinct
    ``above_loop_index``. Only *tiling* is inspected; the other parameters are
    accepted for signature parity with the other tagger functions.
    """
    levels_per_storage: dict = {}
    for storage in tiling.tensors:
        levels_per_storage.setdefault(storage.backer_id, set()).add(
            storage.above_loop_index
        )
    is_even = all(len(levels) == 1 for levels in levels_per_storage.values())
    return ("TILEFLOW_VALID",) if is_even else ("TILEFLOW_INVALID",)

pytimeloop/fastfusion/mapper/mapper.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,7 @@ def generate_data(from_einsum: int, to_einsum: int, data, rank_renaming, tensor_
146146

147147

148148
def _convert_tiling(tiling: Tiling, rank_renaming, tensor_renaming):
149-
return Tiling(
150-
loops=tuple(l.rename(rank_renaming, tensor_renaming) for l in tiling.loops),
151-
tensors=frozenset(ts.rename(rank_renaming, tensor_renaming) for ts in tiling.tensors),
152-
)
149+
return tiling.rename(rank_renaming, tensor_renaming)
153150

154151

155152
def _convert_stats(from_einsum: int, to_einsum: int, stats, rank_renaming, tensor_renaming):

pytimeloop/fastfusion/mapper/mapper_snowcat.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,13 @@ def mapper(
6363
separated_einsums = None
6464
else:
6565
separated_einsums = get_ffmt_separated_einsums(workload)
66-
grouped_similar_einsums = convert_rank_to_group_renaming(
67-
detect_similar_einsums(workload, analyzer, separated_einsums),
68-
equivalent_groups
69-
)
66+
if not tag_with:
67+
grouped_similar_einsums = convert_rank_to_group_renaming(
68+
detect_similar_einsums(workload, analyzer, separated_einsums),
69+
equivalent_groups
70+
)
71+
else:
72+
grouped_similar_einsums = {einsum: {} for einsum in workload.einsum_id_to_name()}
7073
logger.info(f"Found {len(grouped_similar_einsums)} unique Einsums\n"
7174
+ f"\tConverter: {grouped_similar_einsums}")
7275

@@ -128,15 +131,7 @@ def generate_data(from_einsum: int, to_einsum: int, data, rank_renaming, tensor_
128131

129132

130133
def _convert_tiling(tiling: Tiling, rank_renaming, tensor_renaming):
131-
return Tiling(
132-
loops=tuple(Loop(rank_renaming[l.rank_id], l.bound, l.is_spatial)
133-
for l in tiling.loops),
134-
tensors=frozenset(TensorStorage(tensor_renaming[ts.tensor_id],
135-
ts.above_loop_index,
136-
ts.backer_id,
137-
ts.tile_size)
138-
for ts in tiling.tensors)
139-
)
134+
return tiling.rename(rank_renaming, tensor_renaming)
140135

141136

142137
def _convert_stats(from_einsum: int, to_einsum: int, stats, rank_renaming, tensor_renaming):

pytimeloop/fastfusion/mapper/per_einsum_mapper_snowcat.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ def per_einsum_mapper_snowcat(
8686

8787
partial_mappings = list(dependent_product(parallelized_spaces))
8888
partial_mappings = [x if isinstance(x, tuple) else (x,) for x in partial_mappings]
89+
rank_id_to_name = {v: k for k, v in rank_name_to_id.items()}
90+
tensor_id_to_name = {v: k for k, v in tensor_name_to_id.items()}
91+
input_tensors = set(tensor_id_to_name[t] for t in workload.tensors_read_by_einsum(einsum_id))
92+
output_tensors = set(tensor_id_to_name[t] for t in workload.tensors_written_by_einsum(einsum_id))
93+
rank_name_to_shared_name = {
94+
rank_id_to_name[k]: rank_id_to_name[v] for k, v in equivalent_groups.rank_to_group_id.items()
95+
}
8996

9097
# successful_partial_mappings = []
9198
# for p in partial_mappings:
@@ -148,8 +155,11 @@ def per_worker_exploration(*args):
148155
einsum_shape=einsum_shape,
149156
metrics=metrics,
150157
einsum_id_to_name=einsum_id_to_name,
151-
rank_id_to_name={v: k for k, v in rank_name_to_id.items()},
152-
tensor_id_to_name={v: k for k, v in tensor_name_to_id.items()},
158+
rank_id_to_name=rank_id_to_name,
159+
tensor_id_to_name=tensor_id_to_name,
160+
rank_name_to_shared_name=rank_name_to_shared_name,
161+
input_tensors=input_tensors,
162+
output_tensors=output_tensors,
153163
tag_with=tag_with,
154164
)
155165
return result
@@ -161,7 +171,7 @@ def per_worker_exploration(*args):
161171
data[einsum_id] = defaultdict(list)
162172
for res in results:
163173
for k, v in res.items():
164-
data[einsum_id][k] += v
174+
data[einsum_id][k[0]] += v
165175

166176
return data
167177

pytimeloop/fastfusion/mapper/process_results.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
RESERVED_COLUMNS,
1414
TENSORS,
1515
IN_PROGRESS_STATS,
16+
TAGS,
1617
)
1718
from pytimeloop.fastfusion.sim import TensorStorage, Tiling, Loop
1819

@@ -37,7 +38,6 @@ def all_metrics(cls):
3738

3839
# DEBUG_VISUALIZATION = Metrics.ALL_TENSORS | METRICS.PARTIAL_STATS
3940

40-
4141
def process_result(
4242
result,
4343
shape,
@@ -54,6 +54,9 @@ def process_result(
5454
einsum_id_to_name,
5555
rank_id_to_name,
5656
tensor_id_to_name,
57+
rank_name_to_shared_name,
58+
input_tensors: set[str],
59+
output_tensors: set[str],
5760
logfunc=None,
5861
metrics=Metrics.all_metrics(),
5962
tag_with: tuple[callable] = (),
@@ -75,7 +78,7 @@ def process_result(
7578
)
7679

7780
cur_idx = 0
78-
all_backing_storages = []
81+
backing_storages = []
7982
all_storages = []
8083
intermediates_to_find = set(intermediate_tensors)
8184
found_tensors = set()
@@ -95,7 +98,7 @@ def record_storage(node):
9598
intermediates_to_find.remove(storage.tensor_id)
9699
if storage.tensor_id not in found_tensors:
97100
found_tensors.add(storage.tensor_id)
98-
all_backing_storages.append(storage)
101+
backing_storages.append(storage)
99102

100103
logstring.append(f"Strg({node['dspace']} in {node['target']})")
101104

@@ -107,6 +110,7 @@ def record_loop(node):
107110
tile_shape = shape[cur_idx]
108111
cur_idx += 1
109112
rank_id = equiv_groups.rank_to_group_id[node["rank"]]
113+
# rank_id = node["rank"]
110114
loop = Loop(
111115
rank_id_to_name[rank_id],
112116
tile_shape,
@@ -124,21 +128,33 @@ def record_loop(node):
124128
elif node["type"] == "spatial" or node["type"] == "temporal":
125129
record_loop(node)
126130

127-
n_fused_loops = max(t.above_loop_index for t in all_backing_storages)
131+
n_fused_loops = max(t.above_loop_index for t in backing_storages)
128132
tiling_full = Tiling(
129133
loops=tuple(full_tiling),
130134
tensors=frozenset(all_storages),
131135
)
136+
137+
tagger_args = dict(
138+
einsum_id=einsum_id,
139+
backing_storages=backing_storages,
140+
input_tensors=input_tensors,
141+
output_tensors=output_tensors,
142+
tiling=tiling_full,
143+
rank_name_to_shared_name=rank_name_to_shared_name,
144+
)
145+
# print(tiling_full)
132146

133147
tiling_compatibility = Tiling(
134148
loops=tuple(full_tiling[:n_fused_loops]),
135-
tensors=frozenset(all_backing_storages),
136-
# tags=fzs().union(*([set()] + [set(t(einsum_id, tiling_full)) for t in tag_with]))
149+
tensors=frozenset(backing_storages),
150+
tags=fzs().union(*([set()] + [set(t(**tagger_args)) for t in tag_with]))
137151
)
138-
139-
# assert max(t.above_loop_index for t in all_backing_storages) == len(tiling_compatibility.loops), (
152+
153+
if "FFMT_VALID" in tiling_compatibility.tags:
154+
print(tiling_compatibility)
155+
# assert max(t.above_loop_index for t in backing_storages) == len(tiling_compatibility.loops), (
140156
# f"\n\ttiling_compatibility: {tiling_compatibility} "
141-
# f"\n\tall_backing_storages: {all_backing_storages} "
157+
# f"\n\tbacking_storages: {backing_storages} "
142158
# f"\n\ttiling_full: {tiling_full}"
143159
# )
144160

@@ -150,14 +166,14 @@ def record_loop(node):
150166
if Metrics.ENERGY in metrics:
151167
results["Energy"] = energy
152168

153-
offchip_accesses = 0
169+
offchip_ac = 0
154170
for (level, tensor, einsum), count in accesses.items():
155171
if level == 0:
156-
offchip_accesses += count
172+
offchip_ac += count
157173
logstring.append(f"Ac_{level}_{tensor}={count:.2e}")
158174

159175
if Metrics.OFF_CHIP_ACCESSES in metrics:
160-
results["Offchip_Ac"] = offchip_accesses
176+
results["Offchip Accesses"] = offchip_ac
161177

162178
logstring.append(f"{result.fanout}")
163179

@@ -166,7 +182,7 @@ def record_loop(node):
166182
# be backed
167183
for r in all_storages:
168184
r: TensorStorage
169-
if r not in all_backing_storages:
185+
if r not in backing_storages:
170186
key = nameloop2col(r.backer_id, r.above_loop_index)
171187
results.setdefault(key, 0)
172188
results[key] += r.tile_size
@@ -184,15 +200,18 @@ def record_loop(node):
184200
logstring.append(f"Results: {results}")
185201
results[LOGSTRING] = {einsum_id: str(logstring)}
186202
results[MAPPING] = {einsum_id: tiling_full}
187-
results[TENSORS] = {einsum_id: all_backing_storages}
203+
results[TENSORS] = {einsum_id: backing_storages}
188204
results[STATS] = {
189205
einsum_id: {k: v for k, v in results.items() if k not in RESERVED_COLUMNS}
190206
}
191207
results[IN_PROGRESS_STATS] = {einsum_id: {}}
192208
results[MAPPING_HASH] = {einsum_id: hash((einsum_id, tiling_compatibility))}
209+
results[TAGS] = {einsum_id: tiling_compatibility.tags}
210+
211+
key = (tiling_compatibility, fzs(results.keys()))
193212

194213
is_pareto = True
195-
for prev_stats in compatibility_to_df[tiling_compatibility]:
214+
for prev_stats in compatibility_to_df[key]:
196215
keys = [k for k in results if k not in DICT_COLUMNS]
197216
if (
198217
fzs(prev_stats.keys()) == fzs(results.keys())
@@ -204,6 +223,6 @@ def record_loop(node):
204223
# TO DO: Index into the DF with both tiling compatibility and
205224
# the result keys
206225
if is_pareto:
207-
compatibility_to_df[tiling_compatibility].append(results)
226+
compatibility_to_df[key].append(results)
208227
results_return = {k: v for k, v in results.items() if k != LOGSTRING}
209228
return is_pareto, results_return, logstring

pytimeloop/fastfusion/mapper/simexplore.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def fuse_sims(
100100
resource2capacity=resource2capacity,
101101
shared_tensors=set(),
102102
)
103+
104+
# TODO: Lookahead by one SIM. If we're going to create a tiling that has loops
105+
# that are not in the ranks of the next SIM, we should drop that tiling.
103106

104107
while sims:
105108
nbuckets.append(len(left))

pytimeloop/fastfusion/pareto.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626
TENSORS = "__TENSORS"
2727
IN_PROGRESS_STATS = "__IN_PROGRESS_STATS"
2828
MAPPING_HASH = "__MAPPING_HASH"
29+
TAGS = "__TAGS"
2930

3031
RESERVED_COLUMNS = set(
31-
[LOGSTRING, MAPPING, STATS, TENSORS, IN_PROGRESS_STATS, MAPPING_HASH]
32+
[LOGSTRING, MAPPING, STATS, TENSORS, IN_PROGRESS_STATS, MAPPING_HASH, TAGS]
3233
)
3334
DICT_COLUMNS = set(
34-
[LOGSTRING, MAPPING, STATS, TENSORS, IN_PROGRESS_STATS, MAPPING_HASH]
35+
[LOGSTRING, MAPPING, STATS, TENSORS, IN_PROGRESS_STATS, MAPPING_HASH, TAGS]
3536
)
3637

3738
_resource_name_nloops_reg = re.compile(r"RESOURCE_(.+?)(?:_LEFT)?_LEVEL_(-?\d+)")

0 commit comments

Comments
 (0)