Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@ def _get_node_type(n: node.Node) -> str:
else:
return "function"

def _is_async_node(n: node.Node) -> bool:
"""Returns whether a DAG node is backed by an async callable."""
return n.callable is not None and inspect.iscoroutinefunction(n.callable)

def _get_node_style(node_type: str) -> dict[str, str]:
"""Get the style of a node type.
Graphviz needs values to be strings.
Expand Down Expand Up @@ -408,6 +412,8 @@ def _get_function_modifier_style(modifier: str) -> dict[str, str]:
modifier_style = dict(style="filled,diagonals")
elif modifier == "materializer":
modifier_style = dict(shape="cylinder")
elif modifier == "async":
modifier_style = dict(fillcolor="#CDB4DB", style="rounded,filled,bold")
elif modifier == "field":
modifier_style = dict(fillcolor="#c8dae0", fontname="Courier")
elif modifier == "cluster":
Expand Down Expand Up @@ -457,6 +463,7 @@ def _get_legend(
"config",
"input",
"function",
"async",
"cluster",
"field",
"output",
Expand Down Expand Up @@ -565,6 +572,11 @@ def _get_legend(
node_style.update(**modifier_style)
seen_node_types.add("materializer")

if _is_async_node(n):
modifier_style = _get_function_modifier_style("async")
node_style.update(**modifier_style)
seen_node_types.add("async")

# apply custom styles before node modifiers
seen_node_type = None
if custom_style_function:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,37 @@ def test_create_graphviz_graph():
assert dot_set == expected_set


def test_create_graphviz_graph_styles_async_nodes():
async def async_node() -> int:
return 1

def sync_node(async_node: int) -> int:
return async_node + 1

module = ad_hoc_utils.create_temporary_module(async_node, sync_node)
fg = graph.FunctionGraph.from_modules(module, config={})

digraph = graph.create_graphviz_graph(
set(fg.get_nodes()),
"Dependency Graph\n",
graphviz_kwargs={},
node_modifiers={},
strictly_display_only_nodes_passed_in=False,
config={},
)
dot_source = str(digraph)

assert (
"\tasync_node [label=<<b>async_node</b><br /><br /><i>int</i>> "
'fillcolor="#CDB4DB" fontname=Helvetica margin=0.15 shape=rectangle '
'style="rounded,filled,bold"]'
) in dot_source
assert (
'\t\tasync [fillcolor="#CDB4DB" fontname=Helvetica margin=0.15 '
'shape=rectangle style="rounded,filled,bold"]'
) in dot_source


def test_create_networkx_graph():
"""Tests that we create a networkx graph"""
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
Expand Down
Loading