@@ -16,7 +16,8 @@ def make_storage(
1616 add_split_at_tensors : Set = None ,
1717 must_have_terminal_storage : bool = False ,
1818 logfunc : Callable = None ,
19- return_retained_tensors : bool = False
19+ return_retained_tensors : bool = False ,
20+ automatically_lower_below_relevant_ranks : bool = False
2021):
2122 if logfunc is None :
2223 logfunc = lambda msg : None # do nothing
@@ -59,17 +60,19 @@ def make_storage(
5960 last_is_relevant = True
6061
6162 min_i = get_last_storage_node (mapping , tensor_id )
62- for i , node in enumerate (mapping [min_i + 1 :]):
63- i += min_i + 1
64- if node ["type" ] == "temporal" :
65- rank_id = node ["rank" ]
66- is_relevant = rank_id in relevant_ranks
67- if last_is_relevant and not is_relevant :
68- # Choice 1: fused
69- tensor_choices .append (i )
70- if tensor_must_be_fully_reused :
71- break
72- last_is_relevant = is_relevant
63+ if automatically_lower_below_relevant_ranks :
64+ for i , node in enumerate (mapping [min_i + 1 :]):
65+ i += min_i + 1
66+ if node ["type" ] == "temporal" :
67+ rank_id = node ["rank" ]
68+ is_relevant = rank_id in relevant_ranks
69+ if ((last_is_relevant and not is_relevant )
70+ or not automatically_lower_below_relevant_ranks ):
71+ # Choice 1: fused
72+ tensor_choices .append (i )
73+ if tensor_must_be_fully_reused :
74+ break
75+ last_is_relevant = is_relevant
7376
7477 # There has not been a single irrelevant loop
7578 if last_is_relevant and (not tensor_must_be_fully_reused
0 commit comments