Skip to content

Commit 29f3610

Browse files
ENT-13075: Error if user tries to call bundle inside custom promise
Ticket: ENT-13075 Signed-off-by: Simon Halvorsen <simon.halvorsen@northern.tech>
1 parent dbad057 commit 29f3610

File tree

1 file changed

+64
-16
lines changed

1 file changed

+64
-16
lines changed

src/cfengine_cli/lint.py

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515
import itertools
1616
import tree_sitter_cfengine as tscfengine
17+
from dataclasses import dataclass
1718
from tree_sitter import Language, Parser
1819
from cfbs.validate import validate_config
1920
from cfbs.cfbs_config import CFBSConfig
@@ -26,6 +27,39 @@
2627
)
2728

2829

30+
@dataclass
31+
class _State:
32+
block_type: str | None = None # "bundle" | "body" | "promise" | None
33+
promise_type: str | None = None # "vars" | "files" | "classes" | ... | None
34+
attribute_name: str | None = None # "if" | "string" | "slist" | ... | None
35+
36+
def update(self, node) -> "_State":
37+
"""Updates and returns the state that should apply to the children of `node`."""
38+
if node.type == "bundle_block":
39+
return _State(block_type="bundle")
40+
if node.type == "body_block":
41+
return _State(block_type="body")
42+
if node.type == "promise_block":
43+
return _State(block_type="promise")
44+
if node.type == "bundle_section":
45+
for child in node.children:
46+
if child.type == "promise_guard":
47+
return _State(
48+
block_type=self.block_type,
49+
promise_type=_text(child)[:-1], # strip trailing ':'
50+
)
51+
return _State(block_type=self.block_type)
52+
if node.type == "attribute":
53+
for child in node.children:
54+
if child.type == "attribute_name":
55+
return _State(
56+
block_type=self.block_type,
57+
promise_type=self.promise_type,
58+
attribute_name=_text(child),
59+
)
60+
return self
61+
62+
2963
def lint_cfbs_json(filename) -> int:
3064
assert os.path.isfile(filename)
3165
assert filename.endswith("cfbs.json")
@@ -93,16 +127,9 @@ def _find_node_type(filename, lines, node, node_type):
93127
return matches
94128

95129

96-
def _find_nodes(filename, lines, node):
97-
matches = []
98-
visitor = lambda x: matches.append(x)
99-
_walk_generic(filename, lines, node, visitor)
100-
return matches
101-
102-
103-
def _single_node_checks(filename, lines, node, user_definition, strict):
104-
"""Things which can be checked by only looking at one node,
105-
not needing to recurse into children."""
130+
def _node_checks(filename, lines, node, user_definition, strict, state: _State):
131+
"""Checks we run on each node in the syntax tree,
132+
utilizes state for checks which require context."""
106133
line = node.range.start_point[0] + 1
107134
column = node.range.start_point[1] + 1
108135
if node.type == "attribute_name" and _text(node) == "ifvarclass":
@@ -133,7 +160,6 @@ def _single_node_checks(filename, lines, node, user_definition, strict):
133160
f"Error: Undefined promise type '{promise_type}' at {filename}:{line}:{column}"
134161
)
135162
return 1
136-
137163
if node.type == "bundle_block_name":
138164
if _text(node) != _text(node).lower():
139165
_highlight_range(node, lines)
@@ -156,6 +182,16 @@ def _single_node_checks(filename, lines, node, user_definition, strict):
156182
)
157183
return 1
158184
if node.type == "calling_identifier":
185+
if (
186+
strict
187+
and _text(node) in user_definition.get("all_bundle_names", set())
188+
and state.promise_type in user_definition.get("custom_promise_types", set())
189+
):
190+
_highlight_range(node, lines)
191+
print(
192+
f"Error: Call to bundle '{_text(node)}' inside custom promise: '{state.promise_type}' at {filename}:{line}:{column}"
193+
)
194+
return 1
159195
if strict and (
160196
_text(node)
161197
not in BUILTIN_FUNCTIONS.union(
@@ -171,6 +207,22 @@ def _single_node_checks(filename, lines, node, user_definition, strict):
171207
return 0
172208

173209

210+
def _stateful_walk(
211+
filename, lines, node, user_definition, strict, state: _State | None = None
212+
) -> int:
213+
if state is None:
214+
state = _State()
215+
216+
errors = _node_checks(filename, lines, node, user_definition, strict, state)
217+
218+
child_state = state.update(node)
219+
for child in node.children:
220+
errors += _stateful_walk(
221+
filename, lines, child, user_definition, strict, child_state
222+
)
223+
return errors
224+
225+
174226
def _walk(filename, lines, node, user_definition=None, strict=True) -> int:
175227
if user_definition is None:
176228
user_definition = {}
@@ -187,11 +239,7 @@ def _walk(filename, lines, node, user_definition=None, strict=True) -> int:
187239
line = node.range.start_point[0] + 1
188240
column = node.range.start_point[1] + 1
189241

190-
errors = 0
191-
for node in _find_nodes(filename, lines, node):
192-
errors += _single_node_checks(filename, lines, node, user_definition, strict)
193-
194-
return errors
242+
return _stateful_walk(filename, lines, node, user_definition, strict)
195243

196244

197245
def _parse_user_definition(filename, lines, root_node):

0 commit comments

Comments
 (0)