9 changes: 8 additions & 1 deletion broadcaster/_base.py
@@ -5,12 +5,14 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, cast
from urllib.parse import urlparse

from pydantic import BaseModel

if TYPE_CHECKING: # pragma: no cover
from broadcaster.backends.base import BroadcastBackend


class Event:
def __init__(self, channel: str, message: str) -> None:
def __init__(self, channel: str, message: str | BaseModel) -> None:
self.channel = channel
self.message = message

@@ -43,6 +45,11 @@ def _create_backend(self, url: str) -> BroadcastBackend:

return RedisStreamBackend(url)

elif parsed_url.scheme == "redis-pydantic-stream":
from broadcaster.backends.redis import RedisPydanticStreamBackend

return RedisPydanticStreamBackend(url)

elif parsed_url.scheme in ("postgres", "postgresql"):
from broadcaster.backends.postgres import PostgresBackend

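With this dispatch in place, the new backend is selected purely by URL scheme. A minimal usage sketch, not part of this diff, assuming a Redis instance on localhost and a hypothetical `OrderPlaced` model:

```python
import asyncio

from pydantic import BaseModel

from broadcaster import Broadcast


class OrderPlaced(BaseModel):  # hypothetical application model
    order_id: int
    total: float


async def main() -> None:
    # The "redis-pydantic-stream://" scheme routes _create_backend to RedisPydanticStreamBackend.
    async with Broadcast("redis-pydantic-stream://localhost:6379") as broadcast:
        async with broadcast.subscribe("orders") as subscriber:
            await broadcast.publish("orders", OrderPlaced(order_id=42, total=9.99))
            event = await subscriber.get()
            assert isinstance(event.message, OrderPlaced)  # typed model, not a raw string


asyncio.run(main())
```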
61 changes: 61 additions & 0 deletions broadcaster/backends/redis.py
@@ -1,8 +1,11 @@
from __future__ import annotations

import asyncio
import inspect
import sys
import typing

from pydantic import BaseModel
from redis import asyncio as redis

from .._base import Event
@@ -108,3 +111,61 @@ async def next_published(self) -> Event:
channel=stream.decode("utf-8"),
message=message.get(b"message", b"").decode("utf-8"),
)


class RedisPydanticStreamBackend(RedisStreamBackend):
"""Redis Stream backend for broadcasting messages using Pydantic models."""

def __init__(self, url: str) -> None:
"""Create a new Redis Stream backend."""
url = url.replace("redis-pydantic-stream", "redis", 1)
self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {}
self._ready = asyncio.Event()
self._producer = redis.Redis.from_url(url)
self._consumer = redis.Redis.from_url(url)
self._module_cache: dict[str, type[BaseModel]] = {}

def _build_module_cache(self) -> None:
"""Build a cache of Pydantic models."""
modules = list(sys.modules.keys())
for module_name in modules:
for _, obj in inspect.getmembers(sys.modules[module_name]):
if inspect.isclass(obj) and issubclass(obj, BaseModel):
self._module_cache[obj.__name__] = obj

async def publish(self, channel: str, message: BaseModel) -> None:
"""Publish a message to a channel."""
msg_type: str = message.__class__.__name__

if msg_type not in self._module_cache:
self._module_cache[msg_type] = message.__class__

message_json: str = message.model_dump_json()
await self._producer.xadd(channel, {"msg_type": msg_type, "message": message_json})

async def wait_for_messages(self) -> list[StreamMessageType]:
"""Wait for messages to be published."""
await self._ready.wait()
self._build_module_cache()
messages = None
while not messages:
messages = await self._consumer.xread(self.streams, count=1, block=100)
return messages

async def next_published(self) -> Event:
"""Get the next published message."""
messages = await self.wait_for_messages()
stream, events = messages[0]
_msg_id, message = events[0]
self.streams[stream.decode("utf-8")] = _msg_id.decode("utf-8")
msg_type = message.get(b"msg_type", b"").decode("utf-8")
message_data = message.get(b"message", b"").decode("utf-8")
message_obj: BaseModel | None = None
if msg_type in self._module_cache:
message_obj = self._module_cache[msg_type].model_validate_json(message_data)
if not message_obj:
return Event(stream.decode("utf-8"), message_data)
return Event(
channel=stream.decode("utf-8"),
message=message_obj,
)
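The round trip stores the model's class name next to its JSON payload and resolves that name back to a class on the consumer side (via the `sys.modules` scan in `_build_module_cache`), falling back to the raw string when the name is unknown. A stripped-down sketch of that mechanism, independent of Redis, assuming pydantic v2; the `ChatMessage` model is hypothetical:

```python
from pydantic import BaseModel


class ChatMessage(BaseModel):  # hypothetical model, for illustration only
    user: str
    text: str


# What publish() writes into the stream entry: the class name plus the JSON body.
entry = {
    "msg_type": ChatMessage.__name__,
    "message": ChatMessage(user="ada", text="hello").model_dump_json(),
}

# What next_published() does on the consumer side, given a name -> class cache.
module_cache: dict[str, type[BaseModel]] = {ChatMessage.__name__: ChatMessage}
model_cls = module_cache.get(entry["msg_type"])
restored = (
    model_cls.model_validate_json(entry["message"]) if model_cls else entry["message"]
)
assert isinstance(restored, ChatMessage) and restored.text == "hello"
```

One implication of keying the cache on `__name__` alone is that two models sharing a class name in different modules would collide, and the model class must be importable in the consuming process for the lookup to succeed.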
1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ dependencies = [
redis = ["redis"]
postgres = ["asyncpg"]
kafka = ["aiokafka"]
pydantic = ["pydantic", "redis"]
test = ["pytest", "pytest-asyncio"]

[project.urls]
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-e .[redis,postgres,kafka]
-e .[redis,postgres,kafka,pydantic]

# Documentation
mkdocs==1.5.3
26 changes: 26 additions & 0 deletions tests/test_broadcast.py
@@ -4,11 +4,17 @@
import typing

import pytest
from pydantic import BaseModel

from broadcaster import Broadcast, BroadcastBackend, Event
from broadcaster.backends.kafka import KafkaBackend


class PydanticEvent(BaseModel):
event: str
data: str


class CustomBackend(BroadcastBackend):
def __init__(self, url: str):
self._subscribed: set[str] = set()
@@ -71,6 +77,26 @@ async def test_redis_stream():
assert event.message == "hello"


@pytest.mark.asyncio
async def test_redis_pydantic_stream():
async with Broadcast("redis-pydantic-stream://localhost:6379") as broadcast:
async with broadcast.subscribe("chatroom") as subscriber:
message = PydanticEvent(event="on_message", data="hello")
await broadcast.publish("chatroom", message)
event = await subscriber.get()
assert event.channel == "chatroom"
assert isinstance(event.message, PydanticEvent)
assert event.message.event == message.event
assert event.message.data == message.data
async with broadcast.subscribe("chatroom1") as subscriber:
await broadcast.publish("chatroom1", message)
event = await subscriber.get()
assert event.channel == "chatroom1"
assert isinstance(event.message, PydanticEvent)
assert event.message.event == message.event
assert event.message.data == message.data


@pytest.mark.asyncio
async def test_postgres():
async with Broadcast("postgres://postgres:postgres@localhost:5432/broadcaster") as broadcast: