Skip to content

Commit d9ce025

Browse files
committed
Add flag to turn off automatic storage node lowering and use it in snowcat
1 parent afda2af commit d9ce025

2 files changed

Lines changed: 16 additions & 12 deletions

File tree

pytimeloop/fastfusion/mapper/per_einsum_subspaces/snowcat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def glb_storage(mapping, unfused_tensors):
5050
explore_uneven=True,
5151
add_split_at_tensors=glb_fused_tensors,
5252
must_have_terminal_storage=True,
53+
automatically_lower_below_relevant_ranks=False,
5354
)
5455

5556
def tile_shape_optimization(mapping):

pytimeloop/fastfusion/mapper/per_einsum_subspaces/subspaces/storage.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ def make_storage(
1616
add_split_at_tensors: Set=None,
1717
must_have_terminal_storage: bool=False,
1818
logfunc: Callable=None,
19-
return_retained_tensors: bool=False
19+
return_retained_tensors: bool=False,
20+
automatically_lower_below_relevant_ranks: bool = False
2021
):
2122
if logfunc is None:
2223
logfunc = lambda msg: None # do nothing
@@ -59,17 +60,19 @@ def make_storage(
5960
last_is_relevant = True
6061

6162
min_i = get_last_storage_node(mapping, tensor_id)
62-
for i, node in enumerate(mapping[min_i+1:]):
63-
i += min_i+1
64-
if node["type"] == "temporal":
65-
rank_id = node["rank"]
66-
is_relevant = rank_id in relevant_ranks
67-
if last_is_relevant and not is_relevant:
68-
# Choice 1: fused
69-
tensor_choices.append(i)
70-
if tensor_must_be_fully_reused:
71-
break
72-
last_is_relevant = is_relevant
63+
if automatically_lower_below_relevant_ranks:
64+
for i, node in enumerate(mapping[min_i+1:]):
65+
i += min_i+1
66+
if node["type"] == "temporal":
67+
rank_id = node["rank"]
68+
is_relevant = rank_id in relevant_ranks
69+
if ((last_is_relevant and not is_relevant)
70+
or not automatically_lower_below_relevant_ranks):
71+
# Choice 1: fused
72+
tensor_choices.append(i)
73+
if tensor_must_be_fully_reused:
74+
break
75+
last_is_relevant = is_relevant
7376

7477
# There has not been a single irrelevant loop
7578
if last_is_relevant and (not tensor_must_be_fully_reused

0 commit comments

Comments (0)