Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ dependencies = [
"protobuf>=5.28.3",
"pydantic-core>=2.20.1",
"websockets>=12.0",
"opentelemetry-sdk>=1.28.2",
"opentelemetry-api>=1.28.2",
]

[tool.uv]
Expand Down
109 changes: 68 additions & 41 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import logging
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from typing import Any, Generic, Optional, Union
from contextlib import contextmanager
from typing import Any, Generator, Generic, Literal, Optional, Union

from opentelemetry import trace

from replit_river.client_transport import ClientTransport
from replit_river.error_schema import RiverException
from replit_river.transport_options import (
HandshakeMetadataType,
TransportOptions,
Expand All @@ -17,6 +21,7 @@
)

logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)


class Client(Generic[HandshakeMetadataType]):
Expand Down Expand Up @@ -55,15 +60,16 @@ async def send_rpc(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
session = await self._transport.get_or_create_session()
return await session.send_rpc(
service_name,
procedure_name,
request,
request_serializer,
response_deserializer,
error_deserializer,
)
with _trace_procedure("rpc", service_name, procedure_name):
session = await self._transport.get_or_create_session()
return await session.send_rpc(
service_name,
procedure_name,
request,
request_serializer,
response_deserializer,
error_deserializer,
)

async def send_upload(
self,
Expand All @@ -76,17 +82,18 @@ async def send_upload(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
session = await self._transport.get_or_create_session()
return await session.send_upload(
service_name,
procedure_name,
init,
request,
init_serializer,
request_serializer,
response_deserializer,
error_deserializer,
)
with _trace_procedure("upload", service_name, procedure_name):
session = await self._transport.get_or_create_session()
return await session.send_upload(
service_name,
procedure_name,
init,
request,
init_serializer,
request_serializer,
response_deserializer,
error_deserializer,
)

async def send_subscription(
self,
Expand All @@ -97,15 +104,16 @@ async def send_subscription(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
session = await self._transport.get_or_create_session()
return session.send_subscription(
service_name,
procedure_name,
request,
request_serializer,
response_deserializer,
error_deserializer,
)
with _trace_procedure("subscription", service_name, procedure_name):
session = await self._transport.get_or_create_session()
return session.send_subscription(
service_name,
procedure_name,
request,
request_serializer,
response_deserializer,
error_deserializer,
)

async def send_stream(
self,
Expand All @@ -118,14 +126,33 @@ async def send_stream(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
session = await self._transport.get_or_create_session()
return session.send_stream(
service_name,
procedure_name,
init,
request,
init_serializer,
request_serializer,
response_deserializer,
error_deserializer,
)
with _trace_procedure("stream", service_name, procedure_name):
session = await self._transport.get_or_create_session()
return session.send_stream(
service_name,
procedure_name,
init,
request,
init_serializer,
request_serializer,
response_deserializer,
error_deserializer,
)


@contextmanager
def _trace_procedure(
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
service_name: str,
procedure_name: str,
) -> Generator[None, None, None]:
with tracer.start_as_current_span(
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
kind=trace.SpanKind.CLIENT,
) as span:
try:
yield
except RiverException as e:
span.set_attribute("river.error_code", e.code)
span.set_attribute("river.error_message", e.message)
raise e
24 changes: 24 additions & 0 deletions replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import grpc
from aiochannel import Channel, ChannelClosed
from opentelemetry.propagators.textmap import Setter
from pydantic import BaseModel, ConfigDict, Field

from replit_river.error_schema import (
Expand Down Expand Up @@ -86,6 +87,11 @@ class ControlMessageHandshakeResponse(BaseModel):
status: HandShakeStatus


class PropagationContext(BaseModel):
traceparent: Optional[str] = None
tracestate: Optional[str] = None


class TransportMessage(BaseModel):
id: str
# from_ is used instead of from because from is a reserved keyword in Python
Expand All @@ -97,12 +103,30 @@ class TransportMessage(BaseModel):
procedureName: Optional[str] = None
streamId: str
controlFlags: int
tracing: Optional[PropagationContext] = None
payload: Any
model_config = ConfigDict(populate_by_name=True)
# need this because we create TransportMessage objects with destructuring
# where the key is "from"


class TransportMessageTracingSetter(Setter[TransportMessage]):
"""
Handles propagating tracing context to the recipient of the message.
"""

def set(self, carrier: TransportMessage, key: str, value: str) -> None:
if not carrier.tracing:
carrier.tracing = PropagationContext()
match key:
case "traceparent":
carrier.tracing.traceparent = value
case "tracestate":
carrier.tracing.tracestate = value
case _:
logger.warning("unknown trace propagation key", extra={"key": key})


class GrpcContext(grpc.aio.ServicerContext):
"""Represents a gRPC-compatible ServicerContext for River interop."""

Expand Down
5 changes: 5 additions & 0 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import nanoid # type: ignore
import websockets
from aiochannel import Channel, ChannelClosed
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from websockets.exceptions import ConnectionClosed

from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
Expand All @@ -31,6 +32,7 @@
STREAM_OPEN_BIT,
GenericRpcHandler,
TransportMessage,
TransportMessageTracingSetter,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -380,6 +382,9 @@ async def send_message(
serviceName=service_name,
procedureName=procedure_name,
)
TraceContextTextMapPropagator().inject(
msg, None, TransportMessageTracingSetter()
)
try:
# We need this lock to ensure the buffer order and message sending order
# are the same.
Expand Down
2 changes: 1 addition & 1 deletion scripts/lint.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash

set -ex

Expand Down
Loading
Loading