Skip to content

Commit 68f0d35

Browse files
authored
Add support for VIDEO as a built-in type (Comfy-Org#7844)
* Add basic support for videos as types This PR adds support for VIDEO as a first-class type. In order to avoid unnecessary costs, VIDEO outputs must implement the `VideoInput` ABC, but their implementation details can vary. Included are two implementations of this type which can be returned by other nodes: * `VideoFromFile` - Created with either a path on disk (as a string) or an `io.BytesIO` containing the contents of a file in a supported format (like .mp4). This implementation won't actually load the video unless necessary. It will also avoid re-encoding when saving if possible. * `VideoFromComponents` - Created from an image tensor and an optional audio tensor. Currently, only h264-encoded videos in .mp4 containers are supported for saving, but the plan is to add additional encodings/containers in the near future (particularly .webm). * Add optimization to avoid parsing entire video * Improve type declarations to reduce warnings * Make sure BytesIO objects can be read many times * Fix a potential issue when saving long videos * Fix incorrect type annotation * Add a `LoadVideo` node to make testing easier * Refactor new types out of the base comfy folder I've created a new `comfy_api` top-level module. The intention is that anything within this folder would be covered by semver-style versioning that would allow custom nodes to rely on them not introducing breaking changes. * Fix linting issue
1 parent 83d0471 commit 68f0d35

10 files changed

Lines changed: 532 additions & 8 deletions

File tree

comfy/comfy_types/node_typing.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class IO(StrEnum):
4848
FACE_ANALYSIS = "FACE_ANALYSIS"
4949
BBOX = "BBOX"
5050
SEGS = "SEGS"
51+
VIDEO = "VIDEO"
5152

5253
ANY = "*"
5354
"""Always matches any type, but at a price.
@@ -273,7 +274,7 @@ def INPUT_TYPES(s) -> InputTypeDict:
273274
274275
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
275276
"""
276-
OUTPUT_IS_LIST: tuple[bool]
277+
OUTPUT_IS_LIST: tuple[bool, ...]
277278
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
278279
279280
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
@@ -292,7 +293,7 @@ def INPUT_TYPES(s) -> InputTypeDict:
292293
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
293294
"""
294295

295-
RETURN_TYPES: tuple[IO]
296+
RETURN_TYPES: tuple[IO, ...]
296297
"""A tuple representing the outputs of this node.
297298
298299
Usage::
@@ -301,12 +302,12 @@ def INPUT_TYPES(s) -> InputTypeDict:
301302
302303
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
303304
"""
304-
RETURN_NAMES: tuple[str]
305+
RETURN_NAMES: tuple[str, ...]
305306
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
306307
307308
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
308309
"""
309-
OUTPUT_TOOLTIPS: tuple[str]
310+
OUTPUT_TOOLTIPS: tuple[str, ...]
310311
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
311312
FUNCTION: str
312313
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`

comfy_api/input/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .basic_types import ImageInput, AudioInput
2+
from .video_types import VideoInput
3+
4+
__all__ = [
5+
"ImageInput",
6+
"AudioInput",
7+
"VideoInput",
8+
]

comfy_api/input/basic_types.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import torch
from typing import TypedDict

ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, H is the height,
W is the width, and C is the number of channels.
"""

class AudioInput(TypedDict):
    """
    TypedDict representing audio input.
    """

    waveform: torch.Tensor
    """
    Tensor in the format [B, C, T] where B is the batch size, C is the number
    of channels, and T is the number of audio samples.
    """

    # Sampling rate of `waveform` in Hz (samples per second).
    sample_rate: int
20+

comfy_api/input/video_types.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import annotations
2+
from abc import ABC, abstractmethod
3+
from typing import Optional
4+
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
5+
6+
class VideoInput(ABC):
    """
    Abstract base class for video input types.

    Concrete implementations decide how the underlying video data is stored
    (e.g. a file on disk vs. in-memory tensors) and must provide decoded
    components and a way to persist the video to disk.
    """

    @abstractmethod
    def get_components(self) -> VideoComponents:
        """
        Abstract method to get the video components (images, audio, and frame rate).

        Returns:
            VideoComponents containing images, audio, and frame rate
        """

    @abstractmethod
    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Abstract method to save the video input to a file.
        """

    # Default implementation; subclasses may override with a cheaper version
    # that avoids fully decoding the video.
    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        images = self.get_components().images
        # Images are laid out [B, H, W, C] (see ImageInput), so width is
        # dimension 2 and height is dimension 1.
        return images.shape[2], images.shape[1]
45+

comfy_api/input_impl/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .video_types import VideoFromFile, VideoFromComponents
2+
3+
__all__ = [
4+
# Implementations
5+
"VideoFromFile",
6+
"VideoFromComponents",
7+
]
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
from __future__ import annotations
2+
from av.container import InputContainer
3+
from av.subtitles.stream import SubtitleStream
4+
from fractions import Fraction
5+
from typing import Optional
6+
from comfy_api.input import AudioInput
7+
import av
8+
import io
9+
import json
10+
import numpy as np
11+
import torch
12+
from comfy_api.input import VideoInput
13+
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
14+
15+
class VideoFromFile(VideoInput):
    """
    Class representing video input from a file.

    The source may be a path on disk or an in-memory ``io.BytesIO``; decoding
    is deferred until components are actually requested, and ``save_to`` will
    remux without re-encoding when the requested container/codec already match.
    """

    def __init__(self, file: str | io.BytesIO):
        """
        Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
        containing the file contents.
        """
        self.__file = file

    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Only stream headers are inspected — no frames are decoded.

        Returns:
            Tuple of (width, height)

        Raises:
            ValueError: if the file contains no video stream.
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            for stream in container.streams:
                if stream.type == 'video':
                    assert isinstance(stream, av.VideoStream)
                    return stream.width, stream.height
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def get_components_internal(self, container: InputContainer) -> VideoComponents:
        """
        Fully decode video frames, audio (if present), frame rate and metadata
        from an already-open input container.
        """
        # Get video frames
        frames = []
        for frame in container.decode(video=0):
            img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
            img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
            frames.append(img)

        # NOTE(review): the empty fallback is (0, 3, 0, 0) (channel-first)
        # while stacked frames are [T, H, W, C] — confirm intended shape.
        images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)

        # Get frame rate
        video_stream = next(s for s in container.streams if s.type == 'video')
        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)

        # Get audio if available
        audio = None
        try:
            # Rewind: the video decode above consumed packets from the start.
            container.seek(0)  # Reset the container to the beginning
            for stream in container.streams:
                if stream.type != 'audio':
                    continue
                assert isinstance(stream, av.AudioStream)
                audio_frames = []
                for packet in container.demux(stream):
                    for frame in packet.decode():
                        assert isinstance(frame, av.AudioFrame)
                        audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
                if len(audio_frames) > 0:
                    audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
                    audio = AudioInput({
                        "waveform": audio_tensor,
                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
                    })
        except StopIteration:
            pass  # No audio stream

        metadata = container.metadata
        return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)

    def get_components(self) -> VideoComponents:
        """
        Open the underlying file and decode all components (images, audio,
        frame rate, metadata).
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            return self.get_components_internal(container)
        # NOTE(review): unreachable — the `with` block above always returns,
        # and av.open failures propagate as av errors, not this ValueError.
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Save the video to `path`.

        If the requested container format and codec already match the source
        (or are AUTO), the streams are remuxed packet-by-packet without
        re-encoding; otherwise the video is decoded and re-encoded through
        VideoFromComponents.save_to.
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            container_format = container.format.name
            video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
            reuse_streams = True
            # Container format names can be comma-separated aliases (e.g. "mov,mp4,m4a").
            if format != VideoContainer.AUTO and format not in container_format.split(","):
                reuse_streams = False
            if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
                reuse_streams = False

            if not reuse_streams:
                # Fall back to full decode + re-encode.
                components = self.get_components_internal(container)
                video = VideoFromComponents(components)
                return video.save_to(
                    path,
                    format=format,
                    codec=codec,
                    metadata=metadata
                )

            streams = container.streams
            with av.open(path, mode='w', options={"movflags": "use_metadata_tags"}) as output_container:
                # Copy over the original metadata
                for key, value in container.metadata.items():
                    if metadata is None or key not in metadata:
                        output_container.metadata[key] = value

                # Add our new metadata
                if metadata is not None:
                    for key, value in metadata.items():
                        if isinstance(value, str):
                            output_container.metadata[key] = value
                        else:
                            # Non-string values are serialized as JSON strings.
                            output_container.metadata[key] = json.dumps(value)

                # Add streams to the new container
                stream_map = {}
                for stream in streams:
                    if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
                        out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
                        stream_map[stream] = out_stream

                # Write packets to the new container
                for packet in container.demux():
                    # Packets with no dts (e.g. flush packets) are skipped.
                    if packet.stream in stream_map and packet.dts is not None:
                        packet.stream = stream_map[packet.stream]
                        output_container.mux(packet)
145+
146+
class VideoFromComponents(VideoInput):
    """
    Class representing video input from tensors.

    Saving currently supports only MP4 containers with h264 video (and AAC
    mono audio when audio components are present).
    """

    def __init__(self, components: VideoComponents):
        # Decoded components: images [B, H, W, C], optional audio, frame rate.
        self.__components = components

    def get_components(self) -> VideoComponents:
        """
        Return a fresh VideoComponents view of the stored tensors.
        """
        return VideoComponents(
            images=self.__components.images,
            audio=self.__components.audio,
            frame_rate=self.__components.frame_rate
        )

    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Encode the components to an MP4 file at `path` (h264 video, AAC audio).

        Raises:
            ValueError: if a container other than MP4 or a codec other than
                H264 is requested.
        """
        if format != VideoContainer.AUTO and format != VideoContainer.MP4:
            raise ValueError("Only MP4 format is supported for now")
        if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
            raise ValueError("Only H264 codec is supported for now")
        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
            # Add metadata before writing any streams
            if metadata is not None:
                for key, value in metadata.items():
                    output.metadata[key] = json.dumps(value)

            # Quantize the frame rate to 1/1000 precision for the muxer.
            frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
            # Create a video stream
            video_stream = output.add_stream('h264', rate=frame_rate)
            video_stream.width = self.__components.images.shape[2]
            video_stream.height = self.__components.images.shape[1]
            video_stream.pix_fmt = 'yuv420p'

            # Create an audio stream
            audio_sample_rate = 1
            audio_stream: Optional[av.AudioStream] = None
            if self.__components.audio:
                audio_sample_rate = int(self.__components.audio['sample_rate'])
                audio_stream = output.add_stream('aac', rate=audio_sample_rate)
                audio_stream.sample_rate = audio_sample_rate
                audio_stream.format = 'fltp'

            # Encode video
            for i, frame in enumerate(self.__components.images):
                # Floats in [0, 1] -> uint8 RGB; `frame` is rebound to the av frame below.
                img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
                packet = video_stream.encode(frame)
                output.mux(packet)

            # Flush video
            packet = video_stream.encode(None)
            output.mux(packet)

            if audio_stream and self.__components.audio:
                # Encode audio
                # Chunk the waveform into one audio frame per video frame;
                # any trailing samples shorter than a full chunk are dropped.
                samples_per_frame = int(audio_sample_rate / frame_rate)
                num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
                for i in range(num_frames):
                    start = i * samples_per_frame
                    end = start + samples_per_frame
                    # TODO(Feature) - Add support for stereo audio
                    chunk = self.__components.audio['waveform'][0, 0, start:end].unsqueeze(0).numpy()
                    audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
                    audio_frame.sample_rate = audio_sample_rate
                    # pts is in sample units so frames are laid out contiguously.
                    audio_frame.pts = i * samples_per_frame
                    for packet in audio_stream.encode(audio_frame):
                        output.mux(packet)

                # Flush audio
                for packet in audio_stream.encode(None):
                    output.mux(packet)
224+

comfy_api/util/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .video_types import VideoContainer, VideoCodec, VideoComponents
2+
3+
__all__ = [
4+
# Utility Types
5+
"VideoContainer",
6+
"VideoCodec",
7+
"VideoComponents",
8+
]

0 commit comments

Comments
 (0)