Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import abc
import functools
import importlib
import importlib.util
import json
Expand Down Expand Up @@ -792,6 +793,26 @@ def list_available_variables(
results = [Variable.from_node(n) for n in all_nodes]
return results

@functools.cached_property
def variables(self) -> dict[str, Variable]:
"""Returns all variables in the graph keyed by name."""
return {
node_name: Variable.from_node(node_) for node_name, node_ in self.graph.nodes.items()
}

def get_variable(self, name: str) -> Variable:
"""Returns a variable by name.

:param name: Name of the variable to return.
:return: Matching HamiltonNode.
:raises KeyError: If the variable does not exist in this Driver's graph.
"""
return self.variables[name]

def get_graph(self) -> graph_types.HamiltonGraph:
"""Returns the public HamiltonGraph representation for this Driver."""
return graph_types.HamiltonGraph.from_graph(self.graph)

@capture_function_usage
def display_all_functions(
self,
Expand Down
16 changes: 16 additions & 0 deletions tests/test_hamilton_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,22 @@ def test_driver_variables_exposes_original_function():
assert originating_functions["a"] == (tests.resources.very_simple_dag.b,) # a is an input


def test_driver_variable_lookup():
dr = Driver({}, tests.resources.very_simple_dag)

assert set(dr.variables) == {"a", "b"}
assert dr.variables["b"].name == "b"
assert dr.get_variable("a").is_external_input is True


def test_driver_get_graph_returns_hamilton_graph():
dr = Driver({}, tests.resources.very_simple_dag)

hamilton_graph = dr.get_graph()

assert hamilton_graph["b"].name == "b"


@pytest.mark.parametrize(
"driver_factory",
[
Expand Down