Skip to content

Commit 8e73678

Browse files
CURVE node (Comfy-Org#12757)
* CURVE node * remove curve to sigmas node * feat: add CurveInput ABC with MonotoneCubicCurve implementation (Comfy-Org#12986) CurveInput is an abstract base class so future curve representations (bezier, LUT-based, analytical functions) can be added without breaking downstream nodes that type-check against CurveInput. MonotoneCubicCurve is the concrete implementation that: - Mirrors frontend createMonotoneInterpolator (curveUtils.ts) exactly - Pre-computes slopes as numpy arrays at construction time - Provides vectorised interp_array() using numpy for batch evaluation - interp() for single-value evaluation - to_lut() for generating lookup tables CurveEditor node wraps raw widget points in MonotoneCubicCurve. * linear curve * refactor: move CurveEditor to comfy_extras/nodes_curve.py with V3 schema * feat: add HISTOGRAM type and histogram support to CurveEditor * code improve --------- Co-authored-by: Christian Byrne <cbyrne@comfy.org>
1 parent c2862b2 commit 8e73678

File tree

6 files changed

+292
-3
lines changed

6 files changed

+292
-3
lines changed

comfy_api/input/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
MaskInput,
66
LatentInput,
77
VideoInput,
8+
CurvePoint,
9+
CurveInput,
10+
MonotoneCubicCurve,
11+
LinearCurve,
812
)
913

1014
__all__ = [
@@ -13,4 +17,8 @@
1317
"MaskInput",
1418
"LatentInput",
1519
"VideoInput",
20+
"CurvePoint",
21+
"CurveInput",
22+
"MonotoneCubicCurve",
23+
"LinearCurve",
1624
]

comfy_api/latest/_input/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
2+
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
23
from .video_types import VideoInput
34

45
__all__ = [
@@ -7,4 +8,8 @@
78
"VideoInput",
89
"MaskInput",
910
"LatentInput",
11+
"CurvePoint",
12+
"CurveInput",
13+
"MonotoneCubicCurve",
14+
"LinearCurve",
1015
]
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import math
5+
from abc import ABC, abstractmethod
6+
import numpy as np
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
CurvePoint = tuple[float, float]
12+
13+
14+
class CurveInput(ABC):
15+
"""Abstract base class for curve inputs.
16+
17+
Subclasses represent different curve representations (control-point
18+
interpolation, analytical functions, LUT-based, etc.) while exposing a
19+
uniform evaluation interface to downstream nodes.
20+
"""
21+
22+
@property
23+
@abstractmethod
24+
def points(self) -> list[CurvePoint]:
25+
"""The control points that define this curve."""
26+
27+
@abstractmethod
28+
def interp(self, x: float) -> float:
29+
"""Evaluate the curve at a single *x* value in [0, 1]."""
30+
31+
def interp_array(self, xs: np.ndarray) -> np.ndarray:
32+
"""Vectorised evaluation over a numpy array of x values.
33+
34+
Subclasses should override this for better performance. The default
35+
falls back to scalar ``interp`` calls.
36+
"""
37+
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
38+
39+
def to_lut(self, size: int = 256) -> np.ndarray:
40+
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
41+
return self.interp_array(np.linspace(0.0, 1.0, size))
42+
43+
@staticmethod
44+
def from_raw(data) -> CurveInput:
45+
"""Convert raw curve data (dict or point list) to a CurveInput instance.
46+
47+
Accepts:
48+
- A ``CurveInput`` instance (returned as-is).
49+
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
50+
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
51+
"""
52+
if isinstance(data, CurveInput):
53+
return data
54+
if isinstance(data, dict):
55+
raw_points = data["points"]
56+
interpolation = data.get("interpolation", "monotone_cubic")
57+
else:
58+
raw_points = data
59+
interpolation = "monotone_cubic"
60+
points = [(float(x), float(y)) for x, y in raw_points]
61+
if interpolation == "linear":
62+
return LinearCurve(points)
63+
if interpolation != "monotone_cubic":
64+
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
65+
return MonotoneCubicCurve(points)
66+
67+
68+
class MonotoneCubicCurve(CurveInput):
69+
"""Monotone cubic Hermite interpolation over control points.
70+
71+
Mirrors the frontend ``createMonotoneInterpolator`` in
72+
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
73+
backend evaluation matches the editor preview exactly.
74+
75+
All heavy work (sorting, slope computation) happens once at construction.
76+
``interp_array`` is fully vectorised with numpy.
77+
"""
78+
79+
def __init__(self, control_points: list[CurvePoint]):
80+
sorted_pts = sorted(control_points, key=lambda p: p[0])
81+
self._points = [(float(x), float(y)) for x, y in sorted_pts]
82+
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
83+
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
84+
self._slopes = self._compute_slopes()
85+
86+
@property
87+
def points(self) -> list[CurvePoint]:
88+
return list(self._points)
89+
90+
def _compute_slopes(self) -> np.ndarray:
91+
xs, ys = self._xs, self._ys
92+
n = len(xs)
93+
if n < 2:
94+
return np.zeros(n, dtype=np.float64)
95+
96+
dx = np.diff(xs)
97+
dy = np.diff(ys)
98+
dx_safe = np.where(dx == 0, 1.0, dx)
99+
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
100+
101+
slopes = np.empty(n, dtype=np.float64)
102+
slopes[0] = deltas[0]
103+
slopes[-1] = deltas[-1]
104+
for i in range(1, n - 1):
105+
if deltas[i - 1] * deltas[i] <= 0:
106+
slopes[i] = 0.0
107+
else:
108+
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
109+
110+
for i in range(n - 1):
111+
if deltas[i] == 0:
112+
slopes[i] = 0.0
113+
slopes[i + 1] = 0.0
114+
else:
115+
alpha = slopes[i] / deltas[i]
116+
beta = slopes[i + 1] / deltas[i]
117+
s = alpha * alpha + beta * beta
118+
if s > 9:
119+
t = 3 / math.sqrt(s)
120+
slopes[i] = t * alpha * deltas[i]
121+
slopes[i + 1] = t * beta * deltas[i]
122+
return slopes
123+
124+
def interp(self, x: float) -> float:
125+
xs, ys, slopes = self._xs, self._ys, self._slopes
126+
n = len(xs)
127+
if n == 0:
128+
return 0.0
129+
if n == 1:
130+
return float(ys[0])
131+
if x <= xs[0]:
132+
return float(ys[0])
133+
if x >= xs[-1]:
134+
return float(ys[-1])
135+
136+
hi = int(np.searchsorted(xs, x, side='right'))
137+
hi = min(hi, n - 1)
138+
lo = hi - 1
139+
140+
dx = xs[hi] - xs[lo]
141+
if dx == 0:
142+
return float(ys[lo])
143+
144+
t = (x - xs[lo]) / dx
145+
t2 = t * t
146+
t3 = t2 * t
147+
h00 = 2 * t3 - 3 * t2 + 1
148+
h10 = t3 - 2 * t2 + t
149+
h01 = -2 * t3 + 3 * t2
150+
h11 = t3 - t2
151+
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
152+
153+
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
154+
"""Fully vectorised evaluation using numpy."""
155+
xs, ys, slopes = self._xs, self._ys, self._slopes
156+
n = len(xs)
157+
if n == 0:
158+
return np.zeros_like(xs_in, dtype=np.float64)
159+
if n == 1:
160+
return np.full_like(xs_in, ys[0], dtype=np.float64)
161+
162+
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
163+
lo = hi - 1
164+
165+
dx = xs[hi] - xs[lo]
166+
dx_safe = np.where(dx == 0, 1.0, dx)
167+
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
168+
t2 = t * t
169+
t3 = t2 * t
170+
171+
h00 = 2 * t3 - 3 * t2 + 1
172+
h10 = t3 - 2 * t2 + t
173+
h01 = -2 * t3 + 3 * t2
174+
h11 = t3 - t2
175+
176+
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
177+
result = np.where(xs_in <= xs[0], ys[0], result)
178+
result = np.where(xs_in >= xs[-1], ys[-1], result)
179+
return result
180+
181+
def __repr__(self) -> str:
182+
return f"MonotoneCubicCurve(points={self._points})"
183+
184+
185+
class LinearCurve(CurveInput):
186+
"""Piecewise linear interpolation over control points.
187+
188+
Mirrors the frontend ``createLinearInterpolator`` in
189+
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
190+
"""
191+
192+
def __init__(self, control_points: list[CurvePoint]):
193+
sorted_pts = sorted(control_points, key=lambda p: p[0])
194+
self._points = [(float(x), float(y)) for x, y in sorted_pts]
195+
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
196+
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
197+
198+
@property
199+
def points(self) -> list[CurvePoint]:
200+
return list(self._points)
201+
202+
def interp(self, x: float) -> float:
203+
xs, ys = self._xs, self._ys
204+
n = len(xs)
205+
if n == 0:
206+
return 0.0
207+
if n == 1:
208+
return float(ys[0])
209+
return float(np.interp(x, xs, ys))
210+
211+
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
212+
if len(self._xs) == 0:
213+
return np.zeros_like(xs_in, dtype=np.float64)
214+
if len(self._xs) == 1:
215+
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
216+
return np.interp(xs_in, self._xs, self._ys)
217+
218+
def __repr__(self) -> str:
219+
return f"LinearCurve(points={self._points})"

comfy_api/latest/_io.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from comfy.samplers import CFGGuider, Sampler
2424
from comfy.sd import CLIP, VAE
2525
from comfy.sd import StyleModel as StyleModel_
26-
from comfy_api.input import VideoInput
26+
from comfy_api.input import VideoInput, CurveInput as CurveInput_
2727
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
2828
prune_dict, shallow_clone_class)
2929
from comfy_execution.graph_utils import ExecutionBlocker
@@ -1242,8 +1242,9 @@ def as_dict(self):
12421242

12431243
@comfytype(io_type="CURVE")
12441244
class Curve(ComfyTypeIO):
1245-
CurvePoint = tuple[float, float]
1246-
Type = list[CurvePoint]
1245+
from comfy_api.input import CurvePoint
1246+
if TYPE_CHECKING:
1247+
Type = CurveInput_
12471248

12481249
class Input(WidgetInput):
12491250
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
@@ -1252,6 +1253,18 @@ def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str
12521253
if default is None:
12531254
self.default = [(0.0, 0.0), (1.0, 1.0)]
12541255

1256+
def as_dict(self):
1257+
d = super().as_dict()
1258+
if self.default is not None:
1259+
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
1260+
return d
1261+
1262+
1263+
@comfytype(io_type="HISTOGRAM")
1264+
class Histogram(ComfyTypeIO):
1265+
"""A histogram represented as a list of bin counts."""
1266+
Type = list[int]
1267+
12551268

12561269
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
12571270
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
@@ -2240,5 +2253,6 @@ def as_dict(self):
22402253
"PriceBadge",
22412254
"BoundingBox",
22422255
"Curve",
2256+
"Histogram",
22432257
"NodeReplace",
22442258
]

comfy_extras/nodes_curve.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from __future__ import annotations
2+
3+
from comfy_api.latest import ComfyExtension, io
4+
from comfy_api.input import CurveInput
5+
from typing_extensions import override
6+
7+
8+
class CurveEditor(io.ComfyNode):
9+
@classmethod
10+
def define_schema(cls):
11+
return io.Schema(
12+
node_id="CurveEditor",
13+
display_name="Curve Editor",
14+
category="utils",
15+
inputs=[
16+
io.Curve.Input("curve"),
17+
io.Histogram.Input("histogram", optional=True),
18+
],
19+
outputs=[
20+
io.Curve.Output("curve"),
21+
],
22+
)
23+
24+
@classmethod
25+
def execute(cls, curve, histogram=None) -> io.NodeOutput:
26+
result = CurveInput.from_raw(curve)
27+
28+
ui = {}
29+
if histogram is not None:
30+
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
31+
32+
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
33+
34+
35+
class CurveExtension(ComfyExtension):
36+
@override
37+
async def get_node_list(self):
38+
return [CurveEditor]
39+
40+
41+
async def comfy_entrypoint():
42+
return CurveExtension()

nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2455,6 +2455,7 @@ async def init_builtin_extra_nodes():
24552455
"nodes_sdpose.py",
24562456
"nodes_math.py",
24572457
"nodes_painter.py",
2458+
"nodes_curve.py",
24582459
]
24592460

24602461
import_failed = []

0 commit comments

Comments
 (0)