Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions livekit-rtc/livekit/rtc/audio_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
capacity (int, optional): The maximum number of mixed frames to store in the output queue.
Defaults to 100.
"""
self._streams: set[_Stream] = set()
self._streams: dict[_Stream, asyncio.Lock] = {}
self._buffers: dict[_Stream, np.ndarray] = {}
self._sample_rate: int = sample_rate
self._num_channels: int = num_channels
Expand All @@ -62,7 +62,7 @@ def __init__(
self._ending: bool = False
self._mixer_task: asyncio.Task = asyncio.create_task(self._mixer())

def add_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
def add_stream(self, stream: AsyncIterator[AudioFrame]) -> asyncio.Lock:
"""
Add an audio stream to the mixer.

Expand All @@ -71,13 +71,17 @@ def add_stream(self, stream: AsyncIterator[AudioFrame]) -> None:

Args:
stream (AsyncIterator[AudioFrame]): An async iterator that produces AudioFrame objects.

Returns:
asyncio.Lock: A lock that can be used to synchronize access to the stream.
"""
if self._ending:
raise RuntimeError("Cannot add stream after mixer has been closed")

self._streams.add(stream)
self._streams[stream] = asyncio.Lock()
if stream not in self._buffers:
self._buffers[stream] = np.empty((0, self._num_channels), dtype=np.int16)
return self._streams[stream]

def remove_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
"""
Expand All @@ -88,7 +92,7 @@ def remove_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
Args:
stream (AsyncIterator[AudioFrame]): The audio stream to remove.
"""
self._streams.discard(stream)
self._streams.pop(stream, None)
self._buffers.pop(stream, None)

def __aiter__(self) -> "AudioMixer":
Expand Down Expand Up @@ -133,9 +137,10 @@ async def _mixer(self) -> None:
tasks = [
self._get_contribution(
stream,
lock,
self._buffers.get(stream, np.empty((0, self._num_channels), dtype=np.int16)),
)
for stream in list(self._streams)
for stream, lock in self._streams.items()
]
results = await asyncio.gather(*tasks, return_exceptions=True)
contributions = []
Expand Down Expand Up @@ -169,15 +174,18 @@ async def _mixer(self) -> None:
await self._queue.put(None)

async def _get_contribution(
self, stream: AsyncIterator[AudioFrame], buf: np.ndarray
self, stream: AsyncIterator[AudioFrame], lock: asyncio.Lock, buf: np.ndarray
) -> _Contribution:
had_data = buf.shape[0] > 0
exhausted = False

async def _get_frame() -> AudioFrame:
async with lock:
return await stream.__anext__()

while buf.shape[0] < self._chunk_size and not exhausted:
try:
frame = await asyncio.wait_for(
stream.__anext__(), timeout=self._stream_timeout_ms / 1000
)
frame = await asyncio.wait_for(_get_frame(), timeout=self._stream_timeout_ms / 1000)
except asyncio.TimeoutError:
logger.warning(f"AudioMixer: stream {stream} timeout, ignoring")
break
Expand Down
Loading