|
| 1 | +from __future__ import annotations |
| 2 | +from av.container import InputContainer |
| 3 | +from av.subtitles.stream import SubtitleStream |
| 4 | +from fractions import Fraction |
| 5 | +from typing import Optional |
| 6 | +from comfy_api.input import AudioInput |
| 7 | +import av |
| 8 | +import io |
| 9 | +import json |
| 10 | +import numpy as np |
| 11 | +import torch |
| 12 | +from comfy_api.input import VideoInput |
| 13 | +from comfy_api.util import VideoContainer, VideoCodec, VideoComponents |
| 14 | + |
class VideoFromFile(VideoInput):
    """
    Video input backed by a file on disk or an in-memory ``io.BytesIO`` buffer.
    """

    def __init__(self, file: str | io.BytesIO):
        """
        Initialize the VideoFromFile object based off of either a path on disk
        or a BytesIO object containing the full container file contents.
        """
        self.__file = file

    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height).

        Raises:
            ValueError: if the container holds no video stream.
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            for stream in container.streams:
                if stream.type == 'video':
                    assert isinstance(stream, av.VideoStream)
                    return stream.width, stream.height
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def get_components_internal(self, container: InputContainer) -> VideoComponents:
        """
        Decode every frame (and the audio, if any) from an already-open container.

        Args:
            container: An open PyAV input container positioned at the start.

        Returns:
            VideoComponents whose images are a (frames, height, width, 3) float
            tensor in [0, 1], plus optional audio, the stream frame rate, and
            the container-level metadata.
        """
        # Decode all video frames from the first video stream.
        frames = []
        for frame in container.decode(video=0):
            img = frame.to_ndarray(format='rgb24')  # shape: (H, W, 3), uint8
            img = torch.from_numpy(img) / 255.0     # shape: (H, W, 3), float in [0, 1]
            frames.append(img)

        # BUGFIX: the empty placeholder previously used shape (0, 3, 0, 0),
        # which does not match the (N, H, W, C) layout of the stacked frames.
        images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)

        # Get frame rate. BUGFIX: next() now has a default so a container with
        # no video stream no longer raises an unhandled StopIteration — the
        # existing `if video_stream` guard below already handles the None case.
        video_stream = next((s for s in container.streams if s.type == 'video'), None)
        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)

        # Get audio if available.
        audio = None
        try:
            container.seek(0)  # Rewind so the audio stream is demuxed from the start
            for stream in container.streams:
                if stream.type != 'audio':
                    continue
                assert isinstance(stream, av.AudioStream)
                audio_frames = []
                for packet in container.demux(stream):
                    for frame in packet.decode():
                        assert isinstance(frame, av.AudioFrame)
                        audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
                if len(audio_frames) > 0:
                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
                    audio = AudioInput({
                        "waveform": audio_tensor,
                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
                    })
        except StopIteration:
            pass  # No audio stream

        metadata = container.metadata
        return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)

    def get_components(self) -> VideoComponents:
        """
        Open the underlying file/buffer and decode it into VideoComponents.
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            return self.get_components_internal(container)
        # NOTE: the previous trailing `raise ValueError("No video stream…")`
        # was unreachable (the `with` block always returns) and has been removed.

    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Save the video to `path`, remuxing the original streams when the
        requested container format and codec match the source, and falling
        back to a full decode + re-encode via VideoFromComponents otherwise.

        Args:
            path: Destination file path.
            format: Target container format, or AUTO to keep the source's.
            codec: Target video codec, or AUTO to keep the source's.
            metadata: Extra metadata keys to write; these override same-named
                keys from the source container.
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            container_format = container.format.name
            video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
            # Streams can be copied verbatim only if both the container format
            # and the video codec already match what was requested.
            reuse_streams = True
            if format != VideoContainer.AUTO and format not in container_format.split(","):
                reuse_streams = False
            if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
                reuse_streams = False

            if not reuse_streams:
                # Requested format/codec differ from the source: re-encode.
                components = self.get_components_internal(container)
                video = VideoFromComponents(components)
                return video.save_to(
                    path,
                    format=format,
                    codec=codec,
                    metadata=metadata
                )

            streams = container.streams
            with av.open(path, mode='w', options={"movflags": "use_metadata_tags"}) as output_container:
                # Copy over the original metadata, letting explicit `metadata`
                # keys take precedence over the source's values.
                for key, value in container.metadata.items():
                    if metadata is None or key not in metadata:
                        output_container.metadata[key] = value

                # Add our new metadata (non-string values are JSON-encoded).
                if metadata is not None:
                    for key, value in metadata.items():
                        if isinstance(value, str):
                            output_container.metadata[key] = value
                        else:
                            output_container.metadata[key] = json.dumps(value)

                # Mirror each copyable stream into the output container.
                stream_map = {}
                for stream in streams:
                    if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
                        out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
                        stream_map[stream] = out_stream

                # Remux packets without re-encoding; packets with no DTS
                # (e.g. codec flush packets) are skipped.
                for packet in container.demux():
                    if packet.stream in stream_map and packet.dts is not None:
                        packet.stream = stream_map[packet.stream]
                        output_container.mux(packet)
class VideoFromComponents(VideoInput):
    """
    Video input backed by in-memory components (frames, optional audio,
    frame rate, metadata).
    """

    def __init__(self, components: VideoComponents):
        """
        Initialize from a pre-built VideoComponents bundle.
        """
        self.__components = components

    def get_components(self) -> VideoComponents:
        """
        Return a shallow copy of the stored components.
        """
        # BUGFIX: metadata is now propagated instead of being silently dropped.
        # (VideoComponents is constructed with a `metadata` kwarg elsewhere in
        # this file, so the attribute and its default are assumed to exist —
        # TODO confirm against the VideoComponents definition.)
        return VideoComponents(
            images=self.__components.images,
            audio=self.__components.audio,
            frame_rate=self.__components.frame_rate,
            metadata=self.__components.metadata
        )

    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Encode the components to an MP4 file with H.264 video and AAC audio.

        Args:
            path: Destination file path.
            format: Must be AUTO or MP4.
            codec: Must be AUTO or H264.
            metadata: Optional metadata dict; values are JSON-encoded.

        Raises:
            ValueError: if an unsupported container format or codec is requested.
        """
        if format != VideoContainer.AUTO and format != VideoContainer.MP4:
            raise ValueError("Only MP4 format is supported for now")
        if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
            raise ValueError("Only H264 codec is supported for now")
        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
            # Add metadata before writing any streams.
            if metadata is not None:
                for key, value in metadata.items():
                    output.metadata[key] = json.dumps(value)

            # Quantize the frame rate to millisecond precision for the muxer.
            frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)

            # Create a video stream sized from the (N, H, W, C) image tensor.
            video_stream = output.add_stream('h264', rate=frame_rate)
            video_stream.width = self.__components.images.shape[2]
            video_stream.height = self.__components.images.shape[1]
            video_stream.pix_fmt = 'yuv420p'

            # Create an audio stream, if audio is present.
            audio_sample_rate = 1
            audio_stream: Optional[av.AudioStream] = None
            if self.__components.audio:
                audio_sample_rate = int(self.__components.audio['sample_rate'])
                audio_stream = output.add_stream('aac', rate=audio_sample_rate)
                audio_stream.sample_rate = audio_sample_rate
                audio_stream.format = 'fltp'

            # Encode video. (CLEANUP: dropped the unused enumerate index and
            # stopped shadowing the tensor loop variable with the av frame.)
            for image in self.__components.images:
                img = (image * 255).clamp(0, 255).byte().cpu().numpy()  # shape: (H, W, 3)
                av_frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                av_frame = av_frame.reformat(format='yuv420p')  # Convert to YUV420P as required by h264
                packet = video_stream.encode(av_frame)
                output.mux(packet)

            # Flush any buffered video packets.
            packet = video_stream.encode(None)
            output.mux(packet)

            if audio_stream and self.__components.audio:
                # Encode audio in per-video-frame chunks. NOTE(review): any
                # trailing samples that do not fill a whole chunk are dropped
                # (pre-existing behavior, left unchanged).
                samples_per_frame = int(audio_sample_rate / frame_rate)
                num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
                for i in range(num_frames):
                    start = i * samples_per_frame
                    end = start + samples_per_frame
                    # TODO(Feature) - Add support for stereo audio
                    chunk = self.__components.audio['waveform'][0, 0, start:end].unsqueeze(0).numpy()
                    audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
                    audio_frame.sample_rate = audio_sample_rate
                    audio_frame.pts = i * samples_per_frame
                    for packet in audio_stream.encode(audio_frame):
                        output.mux(packet)

                # Flush any buffered audio packets.
                for packet in audio_stream.encode(None):
                    output.mux(packet)
0 commit comments