@@ -1,10 +1,12 @@
-from pytimeloop.fastfusion.mapper.per_einsum_mapper import LinearMapping, make_storage, make_temporal_fors, make_temporal_fors_with_smallest_tile
+from pytimeloop.fastfusion.mapper.per_einsum_mapper import LinearMapping, make_storage, make_temporal_fors, make_temporal_fors_with_smallest_tile, make_temporal_fors_in_order
 
 def make_ffmt_subspaces(tensors,
                         intermediate_tensors,
                         tensor_to_relevant_ranks,
                         einsum_id,
-                        workload):
+                        workload,
+                        refetch_weights: bool = True):
+
     def off_chip_storage(mapping):
         off_chip_must_retain = tensors - intermediate_tensors
         off_chip_can_retain = intermediate_tensors
@@ -25,44 +27,80 @@ def off_chip_storage(mapping):
     M = all_ranks[0]
     N = all_ranks[1]
     K = all_ranks[2]
-
-    if einsum_id == 0:
-        allowed_fused_ranks = all_ranks
-    elif einsum_id == 1:
-        allowed_fused_ranks = {M, K}
-    elif einsum_id == max(workload.einsum_id_to_name().keys()):
-        allowed_fused_ranks = {M, N}
-    else:
-        allowed_fused_ranks = {M}
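+    # Classify the matmul operands by their relevant ranks: the weight
+    # tensor touches {K, N}, the input touches {M, K}, and the output is
+    # the single tensor this einsum writes.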
+    weight_tensor = None
+    input_tensor = None
+    for tensor_id in workload.tensors_read_by_einsum(einsum_id):
+        if tensor_to_relevant_ranks[tensor_id] == {K, N}:
+            weight_tensor = tensor_id
+        elif tensor_to_relevant_ranks[tensor_id] == {M, K}:
+            input_tensor = tensor_id
+    assert weight_tensor is not None
+    assert input_tensor is not None
+    output_tensor = next(iter(workload.tensors_written_by_einsum(einsum_id)))
+    non_weight_tensor = tensors - {weight_tensor}
 
     def fused_temporal_fors(mapping, unfused_tensors):
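+        # Choose which ranks may appear as fused temporal loops: all of
+        # M, N, K when the input is unfused, only M and N when just the
+        # output is unfused, and M and K when both are fused.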
-        for partial_mapping in make_temporal_fors(mapping, allowed_fused_ranks):
-            # for partial_mapping in make_temporal_fors(mapping, all_ranks):
-            for partial_mapping in make_temporal_fors_with_smallest_tile(mapping, all_ranks):
-                yield partial_mapping, unfused_tensors
+        if input_tensor in unfused_tensors:
+            allowed_fused_ranks = [M, N, K]
+        elif output_tensor in unfused_tensors:
+            allowed_fused_ranks = [M, N]
+        else:
+            allowed_fused_ranks = [M, K]
+        for partial_mapping in make_temporal_fors_in_order(mapping, allowed_fused_ranks):
+            yield partial_mapping, unfused_tensors
 
 
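+    # GLB storage for the non-weight (input/output) tensors; fused
+    # intermediates must be fully reused and each adds a split point.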
-    def glb_storage(mapping, unfused_tensors):
+    def glb_storage_io(mapping, unfused_tensors):
         glb_fused_tensors = intermediate_tensors - unfused_tensors
         yield from make_storage(
             mapping,
             level=1,
-            must_retain_tensors=tensors,
+            must_retain_tensors=non_weight_tensor,
             can_retain_tensors=set(),
             must_fully_reuse_tensors=glb_fused_tensors,
             tensor_to_relevant_ranks=tensor_to_relevant_ranks,
-            explore_uneven=True,
-            add_split_at_tensors=glb_fused_tensors
+            explore_uneven=False,
+            add_split_at_tensors=glb_fused_tensors,
+            return_retained_tensors=True,
         )
 
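+    # Intra-einsum temporal loops over the K and N ranks.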
-    def mac(mapping):
+    def intra_temporal_fors(mapping, _):
+        for partial_mapping in make_temporal_fors_with_smallest_tile(mapping,
+                                                                     {K, N}):
+            yield partial_mapping, _
+
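+    # GLB storage for the weight tensor alone; kept as a separate stage so
+    # its position relative to the loop nests can differ (see below).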
+    def glb_storage_weights(mapping, _):
+        yield from make_storage(
+            mapping,
+            level=1,
+            must_retain_tensors={weight_tensor},
+            can_retain_tensors=set(),
+            tensor_to_relevant_ranks=tensor_to_relevant_ranks,
+            explore_uneven=False,
+            return_retained_tensors=True,
+        )
+
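+    # Terminate the mapping with this einsum's compute at level 2.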
+    def mac(mapping, _):
         mapping.add_compute(einsum_id, 2)
         yield mapping
 
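+    # The stage order decides where the weight storage lands: below the
+    # fused and intra-einsum loops when refetch_weights is set (weight tiles
+    # are presumably re-fetched per fused iteration), above them otherwise.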
-    return [
-        lambda: [LinearMapping()],
-        off_chip_storage,
-        fused_temporal_fors,
-        glb_storage,
-        mac
-    ]
+    if refetch_weights:
+        return [
+            lambda: [LinearMapping()],
+            off_chip_storage,
+            fused_temporal_fors,
+            glb_storage_io,
+            intra_temporal_fors,
+            glb_storage_weights,
+            mac
+        ]
+    else:
+        return [
+            lambda: [LinearMapping()],
+            off_chip_storage,
+            glb_storage_weights,
+            fused_temporal_fors,
+            glb_storage_io,
+            intra_temporal_fors,
+            mac
+        ]