Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 149 additions & 23 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,143 @@
RTCConfiguration,
RTCIceServer,
MediaStreamTrack,
RTCDataChannel,
)
import threading
import av
from aiortc.rtcrtpsender import RTCRtpSender
from aiortc.codecs import h264
from pipeline import Pipeline
from utils import patch_loop_datagram
from utils import patch_loop_datagram, StreamStats, add_prefix_to_app_routes
import time

logger = logging.getLogger(__name__)
logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING)
logging.getLogger('aiortc.rtcrtpreceiver').setLevel(logging.WARNING)
logging.getLogger("aiortc").setLevel(logging.DEBUG)
logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING)
logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING)


MAX_BITRATE = 2000000
MIN_BITRATE = 2000000


class VideoStreamTrack(MediaStreamTrack):
"""video stream track that processes video frames using a pipeline.

Attributes:
kind (str): The kind of media, which is "video" for this class.
track (MediaStreamTrack): The underlying media stream track.
pipeline (Pipeline): The processing pipeline to apply to each video frame.
"""

kind = "video"

def __init__(self, track: MediaStreamTrack, pipeline):
def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
"""Initialize the VideoStreamTrack.

Args:
track: The underlying media stream track.
pipeline: The processing pipeline to apply to each video frame.
"""
super().__init__()
self.track = track
self.pipeline = pipeline

async def recv(self):
frame = await self.track.recv()
return await self.pipeline(frame)
self._fps_interval_frame_count = 0
self._last_fps_calculation_time = time.monotonic()
self._lock = threading.Lock()
self._fps = 0.0
self._frame_delay = 0.0
self._running = True
self._fps_interval_frame_count = 0
self._last_fps_calculation_time = time.monotonic()
self._stream_start_time = None
self._last_frame_presentation_time = None
self._last_frame_processed_time = None
self._start_fps_thread()
self._start_frame_delay_thread()

def _start_fps_thread(self):
"""Start a separate thread to calculate FPS periodically."""
self._fps_thread = threading.Thread(
target=self._calculate_fps_loop, daemon=True
)
self._fps_thread.start()

def _calculate_fps_loop(self):
"""Loop to calculate FPS periodically."""
while self._running:
time.sleep(1) # Calculate FPS every second.
with self._lock:
current_time = time.monotonic()
time_diff = current_time - self._last_fps_calculation_time
if time_diff > 0:
self._fps = self._fps_interval_frame_count / time_diff

# Reset start_time and frame_count for the next interval.
self._last_fps_calculation_time = current_time
self._fps_interval_frame_count = 0

def _start_frame_delay_thread(self):
"""Start a separate thread to calculate frame delay periodically."""
self._frame_delay_thread = threading.Thread(
target=self._calculate_frame_delay_loop, daemon=True
)
self._frame_delay_thread.start()

def _calculate_frame_delay_loop(self):
"""Loop to calculate frame delay periodically."""
while self._running:
time.sleep(1) # Calculate frame delay every second.
with self._lock:
if self._last_frame_presentation_time is not None:
current_time = time.monotonic()
self._frame_delay = (current_time - self._stream_start_time ) - float(self._last_frame_presentation_time)

def stop(self):
"""Stop the FPS calculation thread."""
self._running = False
self._fps_thread.join()
self._frame_delay_thread.join()

@property
def fps(self) -> float:
"""Get the current output frames per second (FPS).

Returns:
The current output FPS.
"""
with self._lock:
return self._fps

@property
def frame_delay(self) -> float:
"""Get the current frame delay.

Returns:
The current frame delay.
"""
with self._lock:
return self._frame_delay

async def recv(self) -> av.VideoFrame:
"""Receive and process a video frame. Called by the WebRTC library when a frame
is received.

Returns:
The processed video frame.
"""
if self._stream_start_time is None:
self._stream_start_time = time.monotonic()

input_frame = await self.track.recv()
processed_frame = await self.pipeline(input_frame)

# Store frame info for stats.
with self._lock:
self._fps_interval_frame_count += 1
self._last_frame_presentation_time = input_frame.time
self._last_frame_processed_time = time.monotonic()

return processed_frame


def force_codec(pc, sender, forced_codec):
Expand Down Expand Up @@ -119,30 +229,29 @@ async def offer(request):
@pc.on("datachannel")
def on_datachannel(channel):
if channel.label == "control":

@channel.on("message")
async def on_message(message):
try:
params = json.loads(message)

if params.get("type") == "get_nodes":
nodes_info = await pipeline.get_nodes_info()
response = {
"type": "nodes_info",
"nodes": nodes_info
}
response = {"type": "nodes_info", "nodes": nodes_info}
channel.send(json.dumps(response))
elif params.get("type") == "update_prompt":
if "prompt" not in params:
logger.warning("[Control] Missing prompt in update_prompt message")
logger.warning(
"[Control] Missing prompt in update_prompt message"
)
return
pipeline.set_prompt(params["prompt"])
response = {
"type": "prompt_updated",
"success": True
}
response = {"type": "prompt_updated", "success": True}
channel.send(json.dumps(response))
else:
logger.warning("[Server] Invalid message format - missing required fields")
logger.warning(
"[Server] Invalid message format - missing required fields"
)
except json.JSONDecodeError:
logger.error("[Server] Invalid JSON received")
except Exception as e:
Expand All @@ -156,12 +265,17 @@ def on_track(track):
tracks["video"] = videoTrack
sender = pc.addTrack(videoTrack)

# Store video track in app for stats.
stream_id = track.id
request.app["video_tracks"][stream_id] = videoTrack

codec = "video/H264"
force_codec(pc, sender, codec)

@track.on("ended")
async def on_ended():
logger.info(f"{track.kind} track ended")
request.app["video_tracks"].pop(track.id, None)

@pc.on("connectionstatechange")
async def on_connectionstatechange():
Expand Down Expand Up @@ -207,6 +321,7 @@ async def on_startup(app: web.Application):
cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
)
app["pcs"] = set()
app["video_tracks"] = {}


async def on_shutdown(app: web.Application):
Expand Down Expand Up @@ -236,8 +351,8 @@ async def on_shutdown(app: web.Application):

logging.basicConfig(
level=args.log_level.upper(),
format='%(asctime)s [%(levelname)s] %(message)s',
datefmt='%H:%M:%S'
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%H:%M:%S",
)

app = web.Application()
Expand All @@ -247,8 +362,19 @@ async def on_shutdown(app: web.Application):
app.on_startup.append(on_startup)
app.on_shutdown.append(on_shutdown)

app.router.add_get("/", health)

# WebRTC signalling and control routes.
app.router.add_post("/offer", offer)
app.router.add_post("/prompt", set_prompt)
app.router.add_get("/", health)

# Add routes for getting stream statistics.
stream_stats = StreamStats(app)
app.router.add_get("/streams/stats", stream_stats.get_stats)
app.router.add_get("/stream/{stream_id}/stats", stream_stats.get_stats_by_id)

# Add hosted platform route prefix.
# NOTE: This ensures that the local and hosted experiences have consistent routes.
add_prefix_to_app_routes(app, "/live")

web.run_app(app, host=args.host, port=int(args.port))
85 changes: 83 additions & 2 deletions server/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Utility functions for the server."""

import asyncio
import random
import types
import logging

from typing import List, Tuple
import json
from aiohttp import web
from aiortc import MediaStreamTrack
from typing import List, Tuple, Any, Dict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -48,3 +52,80 @@ async def create_datagram_endpoint(

loop.create_datagram_endpoint = types.MethodType(create_datagram_endpoint, loop)
loop._patch_done = True


def add_prefix_to_app_routes(app: web.Application, prefix: str):
"""Add a prefix to all routes in the given application.

Args:
app: The web application whose routes will be prefixed.
prefix: The prefix to add to all routes.
"""
prefix = prefix.rstrip("/")
for route in list(app.router.routes()):
new_path = prefix + route.resource.canonical
app.router.add_route(route.method, new_path, route.handler)


class StreamStats:
"""Class to get stream statistics."""

def __init__(self, app: web.Application):
"""Initialize the StreamStats class."""
self._app = app

def get_video_track_stats(self, video_track: MediaStreamTrack) -> Dict[str, Any]:
"""Get statistics for a video track.

Args:
video_track: The VideoStreamTrack instance.

Returns:
A dictionary containing the statistics.
"""
return {
"fps": video_track.fps,
"frame_delay": video_track.frame_delay,
}

async def get_stats(self, _) -> web.Response:
"""Get the current stream statistics for all streams.

Args:
request: The HTTP GET request.

Returns:
The HTTP response containing the statistics.
"""
video_tracks = self._app.get("video_tracks", {})
all_stats = {
stream_id: self.get_video_track_stats(track)
for stream_id, track in video_tracks.items()
}

return web.Response(
content_type="application/json",
text=json.dumps(all_stats),
)

async def get_stats_by_id(self, request: web.Request) -> web.Response:
"""Get the statistics for a specific stream by ID.

Args:
request: The HTTP GET request.

Returns:
The HTTP response containing the statistics.
"""
stream_id = request.match_info.get("stream_id")
video_track = self._app["video_tracks"].get(stream_id)

if video_track:
stats = self.get_video_track_stats(video_track)
else:
stats = {"error": "Stream not found"}

return web.Response(
content_type="application/json",
text=json.dumps(stats),
)