Make _serializing_socket_cls used by internal WebsocketRPCEndpoint configurable from PubSubEndpoint #96

@Taiwo-Sh

Hello everyone!

First, I'd like to say that this project has been a great help. I was looking to implement a websocket endpoint that works across multiple instances of an application when I found this library. It took a while to get the hang of things, but I understand it now.

For my use case, a client connects to the websocket and sends a request to fetch specific data. This request (with a unique channel id attached) is published to a queue consumed by another service, which processes the request and responds with the data. The data is then broadcast to all clients subscribed to that channel on the pub/sub endpoint. The client that initially made the request is also notified if it is still connected.
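To make the flow concrete, here is a minimal in-memory sketch of it. It is not my actual setup: an `asyncio.Queue` stands in for the message queue, a plain dict of lists stands in for the pub/sub endpoint, and the channel id and job URL are made-up values.

```python
import asyncio
from collections import defaultdict


async def main() -> list:
    queue: asyncio.Queue = asyncio.Queue()  # stands in for the message queue
    subscribers = defaultdict(list)         # channel id -> list of client inboxes
    channel_id = "job_fetch:ws:user-1"      # hypothetical channel id

    # Two clients subscribed to the same channel.
    inbox_a, inbox_b = [], []
    subscribers[channel_id] += [inbox_a, inbox_b]

    async def worker() -> None:
        # The separate service: consume one request, then publish the result
        # to every subscriber of the channel attached to the request.
        request = await queue.get()
        result = {"job_url": request["job_url"], "status": "done"}
        for inbox in subscribers[request["channel_id"]]:
            inbox.append(result)
        queue.task_done()

    task = asyncio.create_task(worker())
    # A client request, tagged with its channel id, goes onto the queue.
    await queue.put({"channel_id": channel_id, "job_url": "https://example.com/job/1"})
    await queue.join()
    await task
    return [inbox_a, inbox_b]


inboxes = asyncio.run(main())
```

Both inboxes end up with the same result message, which is the behaviour I rely on the pub/sub endpoint for.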

The Problem
All my endpoints respond with a specific response schema. After looking at the source code, I noticed that the underlying WebsocketRPCEndpoint used by the PubSubEndpoint instance uses a _serializing_socket_cls to serialize and deserialize RPC messages, so that is most likely the best place to intercept a message and format the response. I wrote a custom serializing socket class, but there was no way to pass it to the PubSubEndpoint, so I had to override pubsub_endpoint.endpoint._serializing_socket_cls manually.

Here's a sample of the implementation I had:
The custom socket serializer

class JobFetchWebSocketProxy(SimpleWebSocket):
    data_schema = FetchExternalJobSchema

    def __init__(self, websocket: SimpleWebSocket):
        self.socket = websocket
        self.rabbitmq_connection = getattr(
            self.root.app.state, "rabbitmq_connection", None
        )
        if self.rabbitmq_connection is None:
            raise ValueError("RabbitMQ connection not found in app state")
        self.pubsub_topics = getattr(self.root.state, "pubsub_topics", None)
        self.db_session = getattr(self.root.state, "db_session", None)

    @property
    def root(self) -> WebSocket:
        root = self.socket
        while socket := (
            getattr(root, "socket", None) or getattr(root, "websocket", None)
        ):
            root = socket

        if not isinstance(root, WebSocket):
            raise ValueError("Could not find root `starlette.WebSocket` instance")
        return root

    async def connect(self, uri: str, **connect_kwargs: typing.Any):  # type: ignore
        await self.socket.connect(uri, **connect_kwargs)  # type: ignore

    def serialize(self, msg: pydantic.BaseModel) -> str:
        return pydantic_serialize(msg)

    def deserialize(self, buffer: str) -> typing.Dict[str, typing.Any]:
        data = orjson.loads(buffer)
        if not isinstance(data, dict):
            raise ValueError("Invalid data format, expected a JSON object")
        return data

    async def send(self, msg: pydantic.BaseModel):  # type: ignore
        # Convert the `RpcMessage` to a `response.Schema`
        if isinstance(msg, RpcMessage) and msg.response is not None:
            # Custom response schema applied here
            response_msg = response.Schema.model_validate(msg.response.result)
        else:
            response_msg = msg
        await self.socket.send(self.serialize(response_msg))  # type: ignore

    async def recv(self) -> typing.Dict[str, typing.Any]:  # type: ignore
        msg = await self.socket.recv()  # type: ignore
        if msg is None:
            return {"request": None}

        try:
            msg = self.deserialize(msg)
            # Data schema validated here
            data = self.data_schema.model_validate(msg)
        except pydantic.ValidationError as exc:
            logger.error(f"Failed to validate job fetch data: {exc}")
            await self.send(
                response.Schema(
                    status=response.Status.ERROR,
                    message="Invalid data",
                    detail="Failed to validate job fetch data",
                    errors=[e["msg"] for e in exc.errors()],
                )
            )
            return {"request": None}

        # Convert the data schema to a `RpcMessage`
        rpc_msg = RpcMessage(
            request=RpcRequest(
                method="fetch_job",
                arguments={
                    "db": self.db_session,
                    "job_url": str(data.job_url),
                    "metadata": data.metadata,
                    "rabbitmq_connection": self.rabbitmq_connection,
                    "pubsub_topics": self.pubsub_topics,
                },
            )
        )
        return rpc_msg.model_dump()

The endpoint

job_fetch_pubsub = PubSubEndpoint(
    methods_class=JobFetchRPCMethods,
    broadcaster=settings.REDIS_URL,
    on_connect=[on_job_fetch_socket_connect],  # type: ignore
    on_disconnect=[on_job_fetch_socket_disconnect],  # type: ignore
    ignore_broadcaster_disconnected=False,
)
# Manually override `_serializing_socket_cls` on the internal RPC endpoint
job_fetch_pubsub.endpoint._serializing_socket_cls = JobFetchWebSocketProxy

@router.websocket("/ws/job-fetch", dependencies=[authe.authentication_required])
async def job_fetch_socket(ws: WebSocket, user: authe.VerifiedUser):
    user_id = str(user.id)
    ws_channel_id = f"milo:core:job_fetch:ws:{user_id}"
    logger.info(
        f"User {user_id!r} is connecting to job fetch WebSocket with client ID: {ws_channel_id!r}",
        extra={"user_id": user_id, "ws_channel_id": ws_channel_id},
    )
    # Let the topics be the same as the channel ID for simplicity
    pubsub_topics = [ws_channel_id]
    ws.state.pubsub_topics = pubsub_topics
    await job_fetch_pubsub.main_loop(
        ws, channel_id=ws_channel_id, pubsub_topics=pubsub_topics
    )

Proposed Improvement/Solution
It would be a good improvement to accept serializing_socket_cls as an argument when instantiating the PubSubEndpoint and pass it on to the WebsocketRPCEndpoint created internally. We would then have something like this instead:

job_fetch_pubsub = PubSubEndpoint(
    methods_class=JobFetchRPCMethods,
    broadcaster=settings.REDIS_URL,
    on_connect=[on_job_fetch_socket_connect],  # type: ignore
    on_disconnect=[on_job_fetch_socket_disconnect],  # type: ignore
    ignore_broadcaster_disconnected=False,
    serializing_socket_cls=JobFetchWebSocketProxy,
)
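For illustration, the forwarding could look roughly like the sketch below. It uses stand-in classes rather than the library's real ones, so every name here (`SketchRPCEndpoint`, `SketchPubSubEndpoint`, `DefaultSerializingSocket`, `CustomSerializingSocket`) is hypothetical, and the actual constructor signatures may differ.

```python
from typing import Optional, Type


class DefaultSerializingSocket:
    """Stand-in for the library's default serializing socket class."""


class CustomSerializingSocket(DefaultSerializingSocket):
    """Stand-in for a user-defined class such as JobFetchWebSocketProxy."""


class SketchRPCEndpoint:
    """Stand-in for WebsocketRPCEndpoint; only the relevant attribute is modeled."""

    def __init__(
        self,
        serializing_socket_cls: Type[DefaultSerializingSocket] = DefaultSerializingSocket,
    ) -> None:
        self._serializing_socket_cls = serializing_socket_cls


class SketchPubSubEndpoint:
    """Stand-in for PubSubEndpoint that forwards the new optional argument."""

    def __init__(
        self,
        serializing_socket_cls: Optional[Type[DefaultSerializingSocket]] = None,
    ) -> None:
        # Forward the class only when the caller provided one, so the inner
        # endpoint keeps its own default otherwise.
        rpc_kwargs = {}
        if serializing_socket_cls is not None:
            rpc_kwargs["serializing_socket_cls"] = serializing_socket_cls
        self.endpoint = SketchRPCEndpoint(**rpc_kwargs)


default_pubsub = SketchPubSubEndpoint()
custom_pubsub = SketchPubSubEndpoint(serializing_socket_cls=CustomSerializingSocket)
```

Defaulting the new argument to `None` keeps existing callers unaffected, since the inner endpoint falls back to whatever default it already uses.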

I'd be happy to make a PR for this change. Thanks!
