Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions comfy_extras/nodes_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Math expression node using simpleeval for safe evaluation.

Provides a ComfyMathExpression node that evaluates math expressions
against dynamically-grown numeric inputs.
"""

from __future__ import annotations

import math
import string

from simpleeval import simple_eval
from typing_extensions import override

from comfy_api.latest import ComfyExtension, io


MAX_EXPONENT = 4000


def _variadic_sum(*args):
"""Support both sum(values) and sum(a, b, c)."""
if len(args) == 1 and hasattr(args[0], "__iter__"):
return sum(args[0])
return sum(args)


def _safe_pow(base, exp):
"""Wrap pow() with an exponent cap to prevent DoS via huge exponents.

The ** operator is already guarded by simpleeval's safe_power, but
pow() as a callable bypasses that guard.
"""
if abs(exp) > MAX_EXPONENT:
raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})")
return pow(base, exp)


MATH_FUNCTIONS = {
"sum": _variadic_sum,
"min": min,
"max": max,
"abs": abs,
"round": round,
"pow": _safe_pow,
"sqrt": math.sqrt,
"ceil": math.ceil,
"floor": math.floor,
"log": math.log,
"log2": math.log2,
"log10": math.log10,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"int": int,
"float": float,
}


class MathExpressionNode(io.ComfyNode):
"""Evaluates a math expression against dynamically-grown inputs."""

@classmethod
def define_schema(cls) -> io.Schema:
autogrow = io.Autogrow.TemplateNames(
input=io.MultiType.Input("value", [io.Float, io.Int]),
names=list(string.ascii_lowercase),
min=1,
)
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="math",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",
],
inputs=[
io.String.Input("expression", default="a + b", multiline=True),
io.Autogrow.Input("values", template=autogrow),
],
outputs=[
io.Float.Output(display_name="FLOAT"),
io.Int.Output(display_name="INT"),
],
)

@classmethod
def execute(
cls, expression: str, values: io.Autogrow.Type
) -> io.NodeOutput:
if not expression.strip():
raise ValueError("Expression cannot be empty.")

context: dict = dict(values)
context["values"] = list(values.values())

result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS)
# bool check must come first because bool is a subclass of int in Python
if isinstance(result, bool) or not isinstance(result, (int, float)):
raise ValueError(
f"Math Expression '{expression}' must evaluate to a numeric result, "
f"got {type(result).__name__}: {result!r}"
)
if not math.isfinite(result):
raise ValueError(
f"Math Expression '{expression}' produced a non-finite result: {result}"
)
return io.NodeOutput(float(result), int(result))


class MathExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [MathExpressionNode]


async def comfy_entrypoint() -> MathExtension:
return MathExtension()
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2449,6 +2449,7 @@ async def init_builtin_extra_nodes():
"nodes_replacements.py",
"nodes_nag.py",
"nodes_sdpose.py",
"nodes_math.py",
]

import_failed = []
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.7
requests
simpleeval>=1.0

#non essential dependencies:
kornia>=0.7.1
Expand Down
197 changes: 197 additions & 0 deletions tests-unit/comfy_extras_test/nodes_math_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import math

import pytest
from collections import OrderedDict
from unittest.mock import patch, MagicMock

mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
mock_server = MagicMock()

with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
from comfy_extras.nodes_math import MathExpressionNode


class TestMathExpressionExecute:
@staticmethod
def _exec(expression: str, **kwargs) -> object:
values = OrderedDict(kwargs)
return MathExpressionNode.execute(expression, values)

def test_addition(self):
result = self._exec("a + b", a=3, b=4)
assert result[0] == 7.0
assert result[1] == 7

def test_subtraction(self):
result = self._exec("a - b", a=10, b=3)
assert result[0] == 7.0
assert result[1] == 7

def test_multiplication(self):
result = self._exec("a * b", a=3, b=5)
assert result[0] == 15.0
assert result[1] == 15

def test_division(self):
result = self._exec("a / b", a=10, b=4)
assert result[0] == 2.5
assert result[1] == 2

def test_single_input(self):
result = self._exec("a * 2", a=5)
assert result[0] == 10.0
assert result[1] == 10

def test_three_inputs(self):
result = self._exec("a + b + c", a=1, b=2, c=3)
assert result[0] == 6.0
assert result[1] == 6

def test_float_inputs(self):
result = self._exec("a + b", a=1.5, b=2.5)
assert result[0] == 4.0
assert result[1] == 4

def test_mixed_int_float_inputs(self):
result = self._exec("a * b", a=1024, b=1.5)
assert result[0] == 1536.0
assert result[1] == 1536

def test_mixed_resolution_scale(self):
result = self._exec("a * b", a=512, b=0.75)
assert result[0] == 384.0
assert result[1] == 384

def test_sum_values_array(self):
result = self._exec("sum(values)", a=1, b=2, c=3)
assert result[0] == 6.0

def test_sum_variadic(self):
result = self._exec("sum(a, b, c)", a=1, b=2, c=3)
assert result[0] == 6.0

def test_min_values(self):
result = self._exec("min(values)", a=5, b=2, c=8)
assert result[0] == 2.0

def test_max_values(self):
result = self._exec("max(values)", a=5, b=2, c=8)
assert result[0] == 8.0

def test_abs_function(self):
result = self._exec("abs(a)", a=-7)
assert result[0] == 7.0
assert result[1] == 7

def test_sqrt(self):
result = self._exec("sqrt(a)", a=16)
assert result[0] == 4.0
assert result[1] == 4

def test_ceil(self):
result = self._exec("ceil(a)", a=2.3)
assert result[0] == 3.0
assert result[1] == 3

def test_floor(self):
result = self._exec("floor(a)", a=2.7)
assert result[0] == 2.0
assert result[1] == 2

def test_sin(self):
result = self._exec("sin(a)", a=0)
assert result[0] == 0.0

def test_log10(self):
result = self._exec("log10(a)", a=100)
assert result[0] == 2.0
assert result[1] == 2

def test_float_output_type(self):
result = self._exec("a + b", a=1, b=2)
assert isinstance(result[0], float)

def test_int_output_type(self):
result = self._exec("a + b", a=1, b=2)
assert isinstance(result[1], int)

def test_non_numeric_result_raises(self):
with pytest.raises(ValueError, match="must evaluate to a numeric result"):
self._exec("'hello'", a=42)

def test_undefined_function_raises(self):
with pytest.raises(Exception, match="not defined"):
self._exec("str(a)", a=42)

def test_boolean_result_raises(self):
with pytest.raises(ValueError, match="got bool"):
self._exec("a > b", a=5, b=3)

def test_empty_expression_raises(self):
with pytest.raises(ValueError, match="Expression cannot be empty"):
self._exec("", a=1)

def test_whitespace_only_expression_raises(self):
with pytest.raises(ValueError, match="Expression cannot be empty"):
self._exec(" ", a=1)

# --- Missing function coverage (round, pow, log, log2, cos, tan) ---

def test_round(self):
result = self._exec("round(a)", a=2.7)
assert result[0] == 3.0
assert result[1] == 3

def test_round_with_ndigits(self):
result = self._exec("round(a, 2)", a=3.14159)
assert result[0] == pytest.approx(3.14)

def test_pow(self):
result = self._exec("pow(a, b)", a=2, b=10)
assert result[0] == 1024.0
assert result[1] == 1024

def test_log(self):
result = self._exec("log(a)", a=math.e)
assert result[0] == pytest.approx(1.0)

def test_log2(self):
result = self._exec("log2(a)", a=8)
assert result[0] == pytest.approx(3.0)

def test_cos(self):
result = self._exec("cos(a)", a=0)
assert result[0] == 1.0

def test_tan(self):
result = self._exec("tan(a)", a=0)
assert result[0] == 0.0

# --- int/float converter functions ---

def test_int_converter(self):
result = self._exec("int(a / b)", a=7, b=2)
assert result[1] == 3

def test_float_converter(self):
result = self._exec("float(a)", a=5)
assert result[0] == 5.0

# --- Error path tests ---

def test_division_by_zero_raises(self):
with pytest.raises(ZeroDivisionError):
self._exec("a / b", a=1, b=0)

def test_sqrt_negative_raises(self):
with pytest.raises(ValueError, match="math domain error"):
self._exec("sqrt(a)", a=-1)

def test_overflow_inf_raises(self):
with pytest.raises(ValueError, match="non-finite result"):
self._exec("a * b", a=1e308, b=10)

def test_pow_huge_exponent_raises(self):
with pytest.raises(ValueError, match="Exponent .* exceeds maximum"):
self._exec("pow(a, b)", a=10, b=10000000)
Loading