Add edge counter functionality and tests #535
kajalpatelinfo wants to merge 39 commits into inducer:main from
Co-authored-by: Matt Smith <mjsmith1@gmail.com>
Co-authored-by: Andreas Klöckner <inform@tiker.net>
majosm left a comment
Thanks! I added a few suggestions.
Also, there's an edge (ha) case that we need to think about here: what happens if the same array appears twice as a dependency? For example, y = x + x; should there be one edge between x and y, or two? The behavior should ideally match whatever the visualization does (I don't remember offhand if it emits a single edge or multiple). If it needs to be multiple edges, I think one way to do that would be to tweak DirectPredecessorsGetter to return either a list or a frozenset based on an argument to __init__.
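To make the y = x + x question concrete, here is a minimal sketch on a toy expression type, not on pytato's actual classes. The names Placeholder, Sum, direct_predecessors, and the include_duplicates flag are all hypothetical; the flag only illustrates the "list or frozenset" idea and is not an existing DirectPredecessorsGetter argument.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Placeholder:
    """Toy leaf node (hypothetical stand-in for an input array)."""
    name: str


@dataclass(frozen=True)
class Sum:
    """Toy binary node (hypothetical stand-in for an elementwise sum)."""
    left: object
    right: object


def direct_predecessors(expr, include_duplicates=False):
    """Return the direct predecessors of expr.

    With include_duplicates=True a list is returned, so a repeated operand
    appears once per use; otherwise a frozenset collapses the repeats.
    """
    preds = [expr.left, expr.right] if isinstance(expr, Sum) else []
    return preds if include_duplicates else frozenset(preds)


x = Placeholder("x")
y = Sum(x, x)  # y = x + x

# One edge if repeated operands are collapsed, two if they are kept.
single_edge_count = len(direct_predecessors(y))
multi_edge_count = len(direct_predecessors(y, include_duplicates=True))
```

Whichever of the two counts the visualization produces is the one the edge counter should report for this case.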
    AxisPermutation,
    BasicIndex,
    Concatenate,
    DataWrapper as DataWrapper,
Suggested change:
- DataWrapper as DataWrapper,
+ DataWrapper,
from pytato.transform import (
    ArrayOrNames,
    CachedWalkMapper,
    DependencyMapper as DependencyMapper,
Suggested change:
- DependencyMapper as DependencyMapper,
+ DependencyMapper,
        # Each dependency is connected by an edge
        self.edge_count += len(self.get_dependencies(expr))

    def get_dependencies(self, expr: Any) -> frozenset[Any]:
        # Retrieve dependencies based on the type of the expression
        if hasattr(expr, "bindings") or isinstance(expr, IndexLambda):
            return frozenset(expr.bindings.values())
        elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)):
            return frozenset([expr.array])
        elif isinstance(expr, Einsum):
            return frozenset(expr.args)
        return frozenset()
Tempted to say that this and the DirectPredecessorsGetter implementation in the tests should be swapped. DirectPredecessorsGetter seems like the more "proper" way to do this, and get_dependencies makes sense as an alternate implementation to check that it's working.
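As a rough illustration of that swap, the sketch below counts edges with a pluggable predecessor function and uses a second, attribute-based function only as a cross-check. It runs on toy node types; Leaf, Index, Call, preds_by_type, preds_by_attribute, and count_edges are hypothetical names, not pytato's API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Leaf:
    name: str


@dataclass(frozen=True)
class Index:
    array: object


@dataclass(frozen=True)
class Call:
    bindings: tuple  # tuple of (name, expr) pairs, kept hashable


def preds_by_type(expr):
    """Primary implementation: dispatch on the node type."""
    if isinstance(expr, Index):
        return frozenset([expr.array])
    if isinstance(expr, Call):
        return frozenset(e for _, e in expr.bindings)
    return frozenset()


def preds_by_attribute(expr):
    """Test-only cross-check: inspect attributes, get_dependencies style."""
    if hasattr(expr, "bindings"):
        return frozenset(dict(expr.bindings).values())
    if hasattr(expr, "array"):
        return frozenset([expr.array])
    return frozenset()


def count_edges(roots, get_preds):
    """Visit each node once and add up its number of direct predecessors."""
    seen, stack, edges = set(), list(roots), 0
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        preds = get_preds(node)
        edges += len(preds)
        stack.extend(preds)
    return edges


a = Leaf("a")
b = Index(a)
c = Call((("u", a), ("v", b)))

# Edges are b->a, c->a and c->b; both implementations should agree.
primary_count = count_edges([c], preds_by_type)
reference_count = count_edges([c], preds_by_attribute)
```

In a test, asserting that the two counts match gives a check that does not depend on hand-computed totals.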
def get_num_edges(outputs: Array | DictOfNamedArrays,
                  count_duplicates: bool | None = None) -> int:
Suggested change:
- count_duplicates: bool | None = None) -> int:
+ count_duplicates: bool = False) -> int:
(Since get_num_edges is a new function, we don't have to keep the deprecation stuff that get_num_nodes has.)
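The difference between the two signatures can be sketched as follows. This is only an illustration of the pattern: legacy_count and new_count are hypothetical helpers that take precomputed totals, and the warning text is shortened; neither is the actual get_num_nodes or get_num_edges.

```python
import warnings
from typing import Optional


def legacy_count(n_unique: int, n_total: int,
                 count_duplicates: Optional[bool] = None) -> int:
    """Existing-function pattern: None is a sentinel for 'not passed',
    so callers relying on the old default get a deprecation warning."""
    if count_duplicates is None:
        warnings.warn(
            "pass 'count_duplicates' explicitly; the default will change",
            DeprecationWarning, stacklevel=2)
        count_duplicates = True
    return n_total if count_duplicates else n_unique


def new_count(n_unique: int, n_total: int,
              count_duplicates: bool = False) -> int:
    """New-function pattern: no callers depend on an old default,
    so a plain bool default is enough."""
    return n_total if count_duplicates else n_unique


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_result = legacy_count(3, 5)  # warns, then uses the old default
    new_result = new_count(3, 5)        # silent, uses False
```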
    if count_duplicates is None:
        from warnings import warn
        warn(
            "The default value of 'count_duplicates' will change "
            "from True to False in 2025. "
            "For now, pass the desired value explicitly.",
            DeprecationWarning, stacklevel=2)
        count_duplicates = True
Suggested change:
-     if count_duplicates is None:
-         from warnings import warn
-         warn(
-             "The default value of 'count_duplicates' will change "
-             "from True to False in 2025. "
-             "For now, pass the desired value explicitly.",
-             DeprecationWarning, stacklevel=2)
-         count_duplicates = True
    def post_visit(self, expr: Any) -> None:
        # Each dependency is connected by an edge
        self.edge_count += len(self.get_dependencies(expr))
Probably will also want an if not isinstance(expr, DictOfNamedArrays): check here if switching to DirectPredecessorsGetter.
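A minimal sketch of that guard on toy types follows. It assumes the predecessor lookup reports a container's entries as its predecessors, which would otherwise be counted as edges into the container. Leaf, Neg, NamedOutputs, direct_predecessors, and EdgeCounter are hypothetical names, not pytato classes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Leaf:
    name: str


@dataclass(frozen=True)
class Neg:
    operand: object


@dataclass(frozen=True)
class NamedOutputs:
    """Toy container of named results; not itself an array node."""
    entries: tuple  # tuple of (name, expr) pairs


def direct_predecessors(expr):
    if isinstance(expr, Neg):
        return [expr.operand]
    if isinstance(expr, NamedOutputs):
        return [e for _, e in expr.entries]
    return []


class EdgeCounter:
    def __init__(self):
        self.edge_count = 0
        self._seen = set()

    def visit(self, expr):
        # Visit each node once, children first.
        if expr in self._seen:
            return
        self._seen.add(expr)
        for pred in direct_predecessors(expr):
            self.visit(pred)
        self.post_visit(expr)

    def post_visit(self, expr):
        # Skip the container so its entries are not counted as edges.
        if not isinstance(expr, NamedOutputs):
            self.edge_count += len(direct_predecessors(expr))


x = Leaf("x")
y = Neg(x)
counter = EdgeCounter()
counter.visit(NamedOutputs((("y", y),)))
```

Here only the y -> x edge is counted; without the isinstance check the container's single entry would add a second one.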
        for dep in dependencies:
            self.edge_multiplicity_counts[dep, expr] += 1

    def get_dependencies(self, expr: Any) -> frozenset[Any]:
        # Retrieve dependencies based on the type of the expression
        if hasattr(expr, "bindings") or isinstance(expr, IndexLambda):
            return frozenset(expr.bindings.values())
        elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)):
            return frozenset([expr.array])
        elif isinstance(expr, Einsum):
            return frozenset(expr.args)
        return frozenset()
(Same deal here with DirectPredecessorsGetter.)
    empty_dag = pt.make_dict_of_named_arrays({})

    # Verify that get_num_edges returns 0 for an empty DAG
    assert get_num_edges(empty_dag, count_duplicates=False) == 0
Suggested change:
- assert get_num_edges(empty_dag, count_duplicates=False) == 0
+ assert get_num_edges(empty_dag) == 0
(And same for all the rest.)
This PR adds functionality to count the number of edges in a DAG, with or without duplicates. It also adds tests for both modes.