Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def list_available_variables(
is non-empty we will return all nodes with that tag and that value.
:return: list of available variables (i.e. outputs).
"""
all_nodes = self.graph.get_nodes()
all_nodes = self.graph.get_nodes_in_topological_order()
if tag_filter:
valid_filter_values = all(
map(
Expand Down
34 changes: 34 additions & 0 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,40 @@ def decorator_counter(self) -> dict[str, int]:
def get_nodes(self) -> list[node.Node]:
return list(self.nodes.values())

def get_nodes_in_topological_order(self) -> list[node.Node]:
"""Returns nodes in dependency-first topological order.

This preserves the graph's existing insertion order for otherwise independent nodes.
"""
visited = set()
visiting = set()
ordered_nodes = []

for start_node in self.nodes.values():
stack = [(start_node, False)]
while stack:
node_, expanded = stack.pop()
node_name = node_.name
if node_name in visited:
continue

if expanded:
visiting.discard(node_name)
visited.add(node_name)
ordered_nodes.append(node_)
continue

if node_name in visiting:
continue

visiting.add(node_name)
stack.append((node_, True))
for dependency in reversed(node_.dependencies):
if dependency.name in self.nodes and dependency.name not in visited:
stack.append((dependency, False))

return ordered_nodes

def display_all(
self,
output_file_path: str = "test-output/graph-all.gv",
Expand Down
15 changes: 14 additions & 1 deletion tests/test_hamilton_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pytest

from hamilton import base, node
from hamilton import ad_hoc_utils, base, node
from hamilton.caching.adapter import HamiltonCacheAdapter
from hamilton.driver import (
Builder,
Expand Down Expand Up @@ -149,6 +149,19 @@ def test_driver_variables_exposes_tags():
assert tags["d"] == {"module": "tests.resources.tagging", "test_list": ["us", "uk"]}


def test_driver_variables_are_topologically_sorted():
def z_dependency() -> int:
return 1

def a_final(z_dependency: int) -> int:
return z_dependency + 1

module = ad_hoc_utils.create_temporary_module(z_dependency, a_final)
dr = Driver({}, module)

assert [var.name for var in dr.list_available_variables()] == ["z_dependency", "a_final"]


@pytest.mark.parametrize(
("filter", "expected"),
[
Expand Down