1313 RESERVED_COLUMNS ,
1414 TENSORS ,
1515 IN_PROGRESS_STATS ,
16+ TAGS ,
1617)
1718from pytimeloop .fastfusion .sim import TensorStorage , Tiling , Loop
1819
@@ -37,7 +38,6 @@ def all_metrics(cls):
3738
3839# DEBUG_VISUALIZATION = Metrics.ALL_TENSORS | METRICS.PARTIAL_STATS
3940
40-
4141def process_result (
4242 result ,
4343 shape ,
@@ -54,6 +54,9 @@ def process_result(
5454 einsum_id_to_name ,
5555 rank_id_to_name ,
5656 tensor_id_to_name ,
57+ rank_name_to_shared_name ,
58+ input_tensors : set [str ],
59+ output_tensors : set [str ],
5760 logfunc = None ,
5861 metrics = Metrics .all_metrics (),
5962 tag_with : tuple [callable ] = (),
@@ -75,7 +78,7 @@ def process_result(
7578 )
7679
7780 cur_idx = 0
78- all_backing_storages = []
81+ backing_storages = []
7982 all_storages = []
8083 intermediates_to_find = set (intermediate_tensors )
8184 found_tensors = set ()
@@ -95,7 +98,7 @@ def record_storage(node):
9598 intermediates_to_find .remove (storage .tensor_id )
9699 if storage .tensor_id not in found_tensors :
97100 found_tensors .add (storage .tensor_id )
98- all_backing_storages .append (storage )
101+ backing_storages .append (storage )
99102
100103 logstring .append (f"Strg({ node ['dspace' ]} in { node ['target' ]} )" )
101104
@@ -107,6 +110,7 @@ def record_loop(node):
107110 tile_shape = shape [cur_idx ]
108111 cur_idx += 1
109112 rank_id = equiv_groups .rank_to_group_id [node ["rank" ]]
113+ # rank_id = node["rank"]
110114 loop = Loop (
111115 rank_id_to_name [rank_id ],
112116 tile_shape ,
@@ -124,21 +128,33 @@ def record_loop(node):
124128 elif node ["type" ] == "spatial" or node ["type" ] == "temporal" :
125129 record_loop (node )
126130
127- n_fused_loops = max (t .above_loop_index for t in all_backing_storages )
131+ n_fused_loops = max (t .above_loop_index for t in backing_storages )
128132 tiling_full = Tiling (
129133 loops = tuple (full_tiling ),
130134 tensors = frozenset (all_storages ),
131135 )
136+
137+ tagger_args = dict (
138+ einsum_id = einsum_id ,
139+ backing_storages = backing_storages ,
140+ input_tensors = input_tensors ,
141+ output_tensors = output_tensors ,
142+ tiling = tiling_full ,
143+ rank_name_to_shared_name = rank_name_to_shared_name ,
144+ )
145+ # print(tiling_full)
132146
133147 tiling_compatibility = Tiling (
134148 loops = tuple (full_tiling [:n_fused_loops ]),
135- tensors = frozenset (all_backing_storages ),
136- # tags=fzs().union(*([set()] + [set(t(einsum_id, tiling_full )) for t in tag_with]))
149+ tensors = frozenset (backing_storages ),
150+ tags = fzs ().union (* ([set ()] + [set (t (** tagger_args )) for t in tag_with ]))
137151 )
138-
139- # assert max(t.above_loop_index for t in all_backing_storages) == len(tiling_compatibility.loops), (
152+
153+ if "FFMT_VALID" in tiling_compatibility .tags :
154+ print (tiling_compatibility )
155+ # assert max(t.above_loop_index for t in backing_storages) == len(tiling_compatibility.loops), (
140156 # f"\n\ttiling_compatibility: {tiling_compatibility} "
141- # f"\n\tall_backing_storages : {all_backing_storages } "
157+ # f"\n\tbacking_storages : {backing_storages } "
142158 # f"\n\ttiling_full: {tiling_full}"
143159 # )
144160
@@ -150,14 +166,14 @@ def record_loop(node):
150166 if Metrics .ENERGY in metrics :
151167 results ["Energy" ] = energy
152168
153- offchip_accesses = 0
169+ offchip_ac = 0
154170 for (level , tensor , einsum ), count in accesses .items ():
155171 if level == 0 :
156- offchip_accesses += count
172+ offchip_ac += count
157173 logstring .append (f"Ac_{ level } _{ tensor } ={ count :.2e} " )
158174
159175 if Metrics .OFF_CHIP_ACCESSES in metrics :
160- results ["Offchip_Ac " ] = offchip_accesses
176+ results ["Offchip Accesses " ] = offchip_ac
161177
162178 logstring .append (f"{ result .fanout } " )
163179
@@ -166,7 +182,7 @@ def record_loop(node):
166182 # be backed
167183 for r in all_storages :
168184 r : TensorStorage
169- if r not in all_backing_storages :
185+ if r not in backing_storages :
170186 key = nameloop2col (r .backer_id , r .above_loop_index )
171187 results .setdefault (key , 0 )
172188 results [key ] += r .tile_size
@@ -184,15 +200,18 @@ def record_loop(node):
184200 logstring .append (f"Results: { results } " )
185201 results [LOGSTRING ] = {einsum_id : str (logstring )}
186202 results [MAPPING ] = {einsum_id : tiling_full }
187- results [TENSORS ] = {einsum_id : all_backing_storages }
203+ results [TENSORS ] = {einsum_id : backing_storages }
188204 results [STATS ] = {
189205 einsum_id : {k : v for k , v in results .items () if k not in RESERVED_COLUMNS }
190206 }
191207 results [IN_PROGRESS_STATS ] = {einsum_id : {}}
192208 results [MAPPING_HASH ] = {einsum_id : hash ((einsum_id , tiling_compatibility ))}
209+ results [TAGS ] = {einsum_id : tiling_compatibility .tags }
210+
211+ key = (tiling_compatibility , fzs (results .keys ()))
193212
194213 is_pareto = True
195- for prev_stats in compatibility_to_df [tiling_compatibility ]:
214+ for prev_stats in compatibility_to_df [key ]:
196215 keys = [k for k in results if k not in DICT_COLUMNS ]
197216 if (
198217 fzs (prev_stats .keys ()) == fzs (results .keys ())
@@ -204,6 +223,6 @@ def record_loop(node):
204223 # TO DO: Index into the DF with both tiling compatibility and
205224 # the result keys
206225 if is_pareto :
207- compatibility_to_df [tiling_compatibility ].append (results )
226+ compatibility_to_df [key ].append (results )
208227 results_return = {k : v for k , v in results .items () if k != LOGSTRING }
209228 return is_pareto , results_return , logstring
0 commit comments