Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
ReCallingOperatorResultKeys,
ProcessBoundOperatorMixin,
is_process_bound,
NestedInterpretationMixin,
)
from octobot_commons.dsl_interpreter.interpreter_dependency import (
InterpreterDependency,
Expand Down Expand Up @@ -83,6 +84,7 @@
"OperatorSignals",
"ProcessBoundOperatorMixin",
"is_process_bound",
"NestedInterpretationMixin",
"InterpreterDependency",
"format_parameter_value",
"resove_operator_params",
Expand Down
44 changes: 32 additions & 12 deletions packages/commons/octobot_commons/dsl_interpreter/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ def extend(
{operator_class.get_name(): operator_class for operator_class in operators}
)

def create_nested(self) -> "Interpreter":
"""
Create a child interpreter with the same operator classes as this one.
"""
return Interpreter(list(self.operators_by_name.values()))

def _instantiate_operator(
self,
operator_class: typing.Type[dsl_interpreter_operator.Operator],
*args: typing.Any,
**kwargs: typing.Any,
) -> dsl_interpreter_operator.Operator:
operator_instance = operator_class(*args, **kwargs)
operator_instance.interpreter = self
return operator_instance

async def interprete(
self, expression: str
) -> dsl_interpreter_operator.ComputedOperatorParameterType:
Expand Down Expand Up @@ -217,7 +233,7 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
args, kwargs = parameters_util.resolve_operator_args_and_kwargs(
operator_class, args, kwargs
)
return operator_class(*args, **kwargs)
return self._instantiate_operator(operator_class, *args, **kwargs)
raise octobot_commons.errors.UnsupportedOperatorError(
f"Unknown operator: {func_name}"
)
Expand All @@ -229,7 +245,7 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
operator_class = self.operators_by_name[op_name]
left = self._visit_node(node.left)
right = self._visit_node(node.right)
return operator_class(left, right)
return self._instantiate_operator(operator_class, left, right)
raise octobot_commons.errors.UnsupportedOperatorError(
f"Unknown binary operator: {op_name}"
)
Expand All @@ -240,7 +256,7 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
if op_name in self.operators_by_name:
operator_class = self.operators_by_name[op_name]
operand = self._visit_node(node.operand)
return operator_class(operand)
return self._instantiate_operator(operator_class, operand)
raise octobot_commons.errors.UnsupportedOperatorError(
f"Unknown unary operator: {op_name}"
)
Expand All @@ -259,7 +275,7 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
)
operator_class = self.operators_by_name[op_name]
right = self._visit_node(comparator)
comparisons.append(operator_class(left, right))
comparisons.append(self._instantiate_operator(operator_class, left, right))
left = right
if len(comparisons) == 1:
return comparisons[0]
Expand All @@ -268,7 +284,9 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
raise octobot_commons.errors.UnsupportedOperatorError(
f"Chained comparisons require the '{and_op_name}' operator"
)
return self.operators_by_name[and_op_name](*comparisons)
return self._instantiate_operator(
self.operators_by_name[and_op_name], *comparisons
)

if isinstance(node, (ast.Constant)):
# Literal values: numbers, strings, booleans, None
Expand All @@ -279,7 +297,7 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
name = node.id
if name in self.operators_by_name:
operator_class = self.operators_by_name[name]
return operator_class()
return self._instantiate_operator(operator_class)
raise octobot_commons.errors.UnsupportedOperatorError(
f"Unknown name: {name}"
)
Expand All @@ -290,7 +308,7 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
if op_name in self.operators_by_name:
operator_class = self.operators_by_name[op_name]
operands = [self._visit_node(operand) for operand in node.values]
return operator_class(*operands)
return self._instantiate_operator(operator_class, *operands)
raise octobot_commons.errors.UnsupportedOperatorError(
f"Unknown BoolOp operator: {op_name}"
)
Expand All @@ -303,7 +321,7 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
test = self._visit_node(node.test)
body = self._visit_node(node.body)
orelse = self._visit_node(node.orelse)
return operator_class(test, body, orelse)
return self._instantiate_operator(operator_class, test, body, orelse)
raise octobot_commons.errors.UnsupportedOperatorError(
f"Unknown IfExp operator: {op_name}"
)
Expand All @@ -316,15 +334,17 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
array_or_list = self._visit_node(node.value)
index_or_slice = self._visit_node(node.slice)
context = node.ctx
return operator_class(array_or_list, index_or_slice, context)
return self._instantiate_operator(
operator_class, array_or_list, index_or_slice, context
)

if isinstance(node, ast.List):
# List: [1, 2, 3]
op_name = ast.List.__name__
if op_name in self.operators_by_name:
operator_class = self.operators_by_name[op_name]
operands = [self._visit_node(operand) for operand in node.elts]
return operator_class(*operands)
return self._instantiate_operator(operator_class, *operands)

if isinstance(node, ast.Dict):
# Dict: {"a": 1, "b": 2} or {"a": 1, **other}
Expand All @@ -351,7 +371,7 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
lower = self._visit_node(node.lower)
upper = self._visit_node(node.upper)
step = self._visit_node(node.step)
return operator_class(lower, upper, step)
return self._instantiate_operator(operator_class, lower, upper, step)

if isinstance(node, ast.Raise):
# Raise statement: raise exc [from cause] - maps to RaiseOperator
Expand All @@ -374,7 +394,7 @@ def _visit_node(self, node: typing.Optional[ast.AST]) -> typing.Union[
args, kwargs = parameters_util.resolve_operator_args_and_kwargs(
operator_class, args, {}
)
return operator_class(*args, **kwargs)
return self._instantiate_operator(operator_class, *args, **kwargs)
raise octobot_commons.errors.UnsupportedOperatorError(
f"Unknown operator: {op_name}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(self, *parameters: OperatorParameterType, **kwargs: typing.Any):
self._validate_parameters(parameters, kwargs)
self.parameters = parameters
self.kwargs = kwargs
# Injected by Interpreter._instantiate_operator; not a DSL parameter.
self.interpreter: typing.Optional[typing.Any] = None

@staticmethod
def get_name() -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@
ProcessBoundOperatorMixin,
is_process_bound,
)
from octobot_commons.dsl_interpreter.operators.nested_interpretation_mixin import (
NestedInterpretationMixin,
)

__all__ = [
"BinaryOperator",
Expand All @@ -79,4 +82,5 @@
"OperatorSignals",
"ProcessBoundOperatorMixin",
"is_process_bound",
"NestedInterpretationMixin",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Drakkar-Software OctoBot-Commons
# Copyright (c) Drakkar-Software, All rights reserved.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library.
import octobot_commons.errors
import octobot_commons.dsl_interpreter.operator as dsl_interpreter_operator


class NestedInterpretationMixin:
"""
Mixin for operators that interpret nested DSL strings using the parent
interpreter's operator registry (see Operator.interpreter).
"""

async def interprete_in_nested_interpreter(
self, expression: str
) -> dsl_interpreter_operator.ComputedOperatorParameterType:
"""
Interprets the given expression in the nested interpreter.
"""
if self.interpreter is None: # type: ignore
raise octobot_commons.errors.DSLInterpreterError(
"Cannot interpret nested expression: no parent interpreter was provided"
)
nested_interpreter = self.interpreter.create_nested() # type: ignore
return await nested_interpreter.interprete(expression)
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Drakkar-Software OctoBot-Commons
# Copyright (c) Drakkar-Software, All rights reserved.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library.
import ast

import pytest

import octobot_commons.dsl_interpreter as dsl_interpreter
import octobot_commons.errors


class AddOperator(dsl_interpreter.BinaryOperator):
@staticmethod
def get_name() -> str:
return ast.Add.__name__

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
left, right = self.get_computed_left_and_right_parameters()
return left + right


class NestedEchoOperator(
dsl_interpreter.CallOperator,
dsl_interpreter.NestedInterpretationMixin,
):
@staticmethod
def get_name() -> str:
return "nested_echo"

def compute(self) -> dsl_interpreter.ComputedOperatorParameterType:
return "unused"


@pytest.fixture
def interpreter():
return dsl_interpreter.Interpreter(
dsl_interpreter.get_all_operators() + [AddOperator]
)


def test_operator_receives_interpreter_on_instantiation(interpreter):
interpreter.prepare("1 + (2 + 3)")
top_operator = interpreter.get_top_operator()
assert isinstance(top_operator, AddOperator)
assert top_operator.interpreter is interpreter
inner_add = top_operator.parameters[1]
assert isinstance(inner_add, AddOperator)
assert inner_add.interpreter is interpreter


def test_create_nested_reuses_operator_classes(interpreter):
interpreter.extend([NestedEchoOperator])
nested_interpreter = interpreter.create_nested()
assert set(nested_interpreter.operators_by_name.keys()) == set(
interpreter.operators_by_name.keys()
)
assert nested_interpreter is not interpreter


@pytest.mark.asyncio
async def test_interprete_nested_raises_without_interpreter():
operator = NestedEchoOperator()
with pytest.raises(octobot_commons.errors.DSLInterpreterError, match="no parent interpreter"):
await operator.interprete_in_nested_interpreter("1 + 1")
Loading
Loading