Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,56 @@
from typing import Self

from burr.core.state import State

def _validate_declared_reads(fn: Callable, declared_reads: list[str]) -> None:
try:
source = inspect.getsource(fn)
except OSError:
return # skip if source unavailable

# detect actual state parameter name
sig = inspect.signature(fn)
state_param_name = None

for name, param in sig.parameters.items():
if param.annotation is State:
state_param_name = name
break

if state_param_name is None:
return


import textwrap
tree = ast.parse(textwrap.dedent(source))

declared = set(declared_reads)
violations = []

class Visitor(ast.NodeVisitor):
def visit_Subscript(self, node):
if (
isinstance(node.value, ast.Name)
and node.value.id == state_param_name

and isinstance(node.slice, ast.Constant)
and isinstance(node.slice.value, str)
):
key = node.slice.value
if key not in declared:
violations.append(key)
self.generic_visit(node)

Visitor().visit(tree)

if violations:
raise ValueError(
f"Action reads undeclared state keys: {violations}. "
f"Declared reads: {declared_reads}"
)



from burr.core.typing import ActionSchema

# This is here to make accessing the pydantic actions easier
Expand Down Expand Up @@ -628,6 +678,8 @@ def __init__(
self._fn = fn
self._reads = reads
self._writes = writes
_validate_declared_reads(self._originating_fn, self._reads)

self._bound_params = bound_params if bound_params is not None else {}
self._inputs = (
derive_inputs_from_fn(self._bound_params, self._fn)
Expand Down Expand Up @@ -1106,9 +1158,13 @@ def __init__(
:param writes:
"""
super(FunctionBasedStreamingAction, self).__init__()
self._originating_fn = originating_fn if originating_fn is not None else fn
self._fn = fn
self._reads = reads
self._writes = writes
_validate_declared_reads(self._originating_fn, self._reads)


self._bound_params = bound_params if bound_params is not None else {}
self._inputs = (
derive_inputs_from_fn(self._bound_params, self._fn)
Expand All @@ -1118,7 +1174,7 @@ def __init__(
[item for item in input_spec[1] if item not in self._bound_params],
)
)
self._originating_fn = originating_fn if originating_fn is not None else fn

self._schema = schema
self._tags = tags if tags is not None else []

Expand Down
23 changes: 23 additions & 0 deletions tests/core/test_action_reads_linter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements...


import pytest
from burr.core.action import action
from burr.core.state import State


def test_undeclared_state_read_raises_error():
with pytest.raises(ValueError):

@action(reads=["foo"], writes=[])
def bad_action(state: State):
x = state["bar"]
return {}, state


def test_declared_state_read_passes():
@action(reads=["foo"], writes=[])
def good_action(state: State):
x = state["foo"]
return {}, state