Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import Any, Generator, Generic, Literal, Optional, Union

from opentelemetry import trace
from opentelemetry.trace import Span, SpanKind, StatusCode

from replit_river.client_transport import ClientTransport
from replit_river.error_schema import RiverException
from replit_river.error_schema import RiverError, RiverException
from replit_river.transport_options import (
HandshakeMetadataType,
TransportOptions,
Expand Down Expand Up @@ -60,7 +61,7 @@ async def send_rpc(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
with _trace_procedure("rpc", service_name, procedure_name):
with _trace_procedure("rpc", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
return await session.send_rpc(
service_name,
Expand All @@ -69,6 +70,7 @@ async def send_rpc(
request_serializer,
response_deserializer,
error_deserializer,
span,
)

async def send_upload(
Expand All @@ -82,7 +84,7 @@ async def send_upload(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
with _trace_procedure("upload", service_name, procedure_name):
with _trace_procedure("upload", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
return await session.send_upload(
service_name,
Expand All @@ -93,6 +95,7 @@ async def send_upload(
request_serializer,
response_deserializer,
error_deserializer,
span,
)

async def send_subscription(
Expand All @@ -104,7 +107,7 @@ async def send_subscription(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
with _trace_procedure("subscription", service_name, procedure_name):
with _trace_procedure("subscription", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
async for msg in session.send_subscription(
service_name,
Expand All @@ -113,8 +116,11 @@ async def send_subscription(
request_serializer,
response_deserializer,
error_deserializer,
span,
):
yield msg
if isinstance(msg, RiverError):
_record_river_error(span, msg)
yield msg # type: ignore # https://github.com/python/mypy/issues/10817

async def send_stream(
self,
Expand All @@ -127,7 +133,7 @@ async def send_stream(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
with _trace_procedure("stream", service_name, procedure_name):
with _trace_procedure("stream", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
async for msg in session.send_stream(
service_name,
Expand All @@ -138,23 +144,32 @@ async def send_stream(
request_serializer,
response_deserializer,
error_deserializer,
span,
):
yield msg
if isinstance(msg, RiverError):
_record_river_error(span, msg)
yield msg # type: ignore # https://github.com/python/mypy/issues/10817


@contextmanager
def _trace_procedure(
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
service_name: str,
procedure_name: str,
) -> Generator[None, None, None]:
with tracer.start_as_current_span(
) -> Generator[Span, None, None]:
with tracer.start_span(
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
kind=trace.SpanKind.CLIENT,
kind=SpanKind.CLIENT,
) as span:
try:
yield
yield span
except RiverException as e:
span.set_attribute("river.error_code", e.code)
span.set_attribute("river.error_message", e.message)
_record_river_error(span, RiverError(code=e.code, message=e.message))
raise e


def _record_river_error(span: Span, error: RiverError) -> None:
span.set_status(StatusCode.ERROR, error.message)
span.record_exception(RiverException(error.code, error.message))
span.set_attribute("river.error_code", error.code)
span.set_attribute("river.error_message", error.message)
11 changes: 11 additions & 0 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import nanoid # type: ignore
from aiochannel import Channel
from aiochannel.errors import ChannelClosed
from opentelemetry.trace import Span

from replit_river.error_schema import (
ERROR_CODE_STREAM_CLOSED,
Expand Down Expand Up @@ -37,6 +38,7 @@ async def send_rpc(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> ResponseType:
"""Sends a single RPC request to the server.

Expand All @@ -51,6 +53,7 @@ async def send_rpc(
payload=request_serializer(request),
service_name=service_name,
procedure_name=procedure_name,
span=span,
)
# Handle potential errors during communication
try:
Expand Down Expand Up @@ -89,6 +92,7 @@ async def send_upload(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> ResponseType:
"""Sends an upload request to the server.

Expand All @@ -107,6 +111,7 @@ async def send_upload(
service_name=service_name,
procedure_name=procedure_name,
payload=init_serializer(init),
span=span,
)
first_message = False
# If this request is not closed and the session is killed, we should
Expand All @@ -122,6 +127,7 @@ async def send_upload(
procedure_name=procedure_name,
control_flags=control_flags,
payload=request_serializer(item),
span=span,
)
except Exception as e:
raise RiverServiceException(
Expand Down Expand Up @@ -171,6 +177,7 @@ async def send_subscription(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
"""Sends a subscription request to the server.

Expand All @@ -185,6 +192,7 @@ async def send_subscription(
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT,
payload=request_serializer(request),
span=span,
)

# Handle potential errors during communication
Expand Down Expand Up @@ -221,6 +229,7 @@ async def send_stream(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
"""Sends a subscription request to the server.

Expand All @@ -239,6 +248,7 @@ async def send_stream(
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT,
payload=init_serializer(init),
span=span,
)
else:
# Get the very first message to open the stream
Expand All @@ -250,6 +260,7 @@ async def send_stream(
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT,
payload=request_serializer(first),
span=span,
)

except StopAsyncIteration:
Expand Down
20 changes: 10 additions & 10 deletions replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,9 @@ async def _convert_outputs() -> None:

convert_inputs_task = task_manager.create_task(_convert_inputs())
convert_outputs_task = task_manager.create_task(_convert_outputs())
await asyncio.wait((convert_inputs_task, convert_outputs_task))

done, _ = await asyncio.wait((convert_inputs_task, convert_outputs_task))
for task in done:
await task
except Exception as e:
logger.exception("Uncaught exception in upload")
await output.put(
Expand Down Expand Up @@ -440,17 +441,16 @@ async def _convert_inputs() -> None:
response = method(request, context)

async def _convert_outputs() -> None:
try:
async for item in response:
await output.put(
get_response_or_error_payload(item, response_serializer)
)
finally:
output.close()
async for item in response:
await output.put(
get_response_or_error_payload(item, response_serializer)
)

convert_inputs_task = task_manager.create_task(_convert_inputs())
convert_outputs_task = task_manager.create_task(_convert_outputs())
await asyncio.wait((convert_inputs_task, convert_outputs_task))
done, _ = await asyncio.wait((convert_inputs_task, convert_outputs_task))
for task in done:
await task
except grpc.RpcError:
logger.exception("RPC exception in stream")
code = grpc.StatusCode(context._abort_code).name if context else "UNKNOWN"
Expand Down
11 changes: 8 additions & 3 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import nanoid # type: ignore
import websockets
from aiochannel import Channel, ChannelClosed
from opentelemetry.trace import Span, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from websockets.exceptions import ConnectionClosed

Expand Down Expand Up @@ -37,6 +38,9 @@

logger = logging.getLogger(__name__)

trace_propagator = TraceContextTextMapPropagator()
trace_setter = TransportMessageTracingSetter()


class SessionState(enum.Enum):
"""The state a session can be in.
Expand Down Expand Up @@ -365,6 +369,7 @@ async def send_message(
control_flags: int = 0,
service_name: str | None = None,
procedure_name: str | None = None,
span: Span | None = None,
) -> None:
"""Send serialized messages to the websockets."""
# if the session is not active, we should not do anything
Expand All @@ -382,9 +387,9 @@ async def send_message(
serviceName=service_name,
procedureName=procedure_name,
)
TraceContextTextMapPropagator().inject(
msg, None, TransportMessageTracingSetter()
)
if span:
with use_span(span):
trace_propagator.inject(msg, None, trace_setter)
try:
# We need this lock to ensure the buffer order and message sending order
# are the same.
Expand Down
44 changes: 39 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from collections.abc import AsyncIterator
from typing import Any, AsyncGenerator, Iterator, Literal

import grpc.aio
import nanoid # type: ignore
import pytest
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from websockets.server import serve

from replit_river.client import Client
from replit_river.client_transport import UriAndMetadata
from replit_river.error_schema import RiverError
from replit_river.error_schema import RiverError, RiverException
from replit_river.rpc import (
GrpcContext,
TransportMessage,
rpc_method_handler,
stream_method_handler,
Expand Down Expand Up @@ -68,12 +72,12 @@ def deserialize_error(response: dict) -> RiverError:


# RPC method handlers for testing
async def rpc_handler(request: str, context: GrpcContext) -> str:
async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
return f"Hello, {request}!"


async def subscription_handler(
request: str, context: GrpcContext
request: str, context: grpc.aio.ServicerContext
) -> AsyncGenerator[str, None]:
for i in range(5):
yield f"Subscription message {i} for {request}"
Expand All @@ -93,7 +97,8 @@ async def upload_handler(


async def stream_handler(
request: Iterator[str] | AsyncIterator[str], context: GrpcContext
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
if isinstance(request, AsyncIterator):
async for data in request:
Expand All @@ -103,6 +108,14 @@ async def stream_handler(
yield f"Stream response for {data}"


async def stream_error_handler(
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
raise RiverException("INJECTED_ERROR", "test error")
yield "test" # appease the type checker


@pytest.fixture
def transport_options() -> TransportOptions:
return TransportOptions()
Expand Down Expand Up @@ -137,6 +150,12 @@ def server(transport_options: TransportOptions) -> Server:
stream_handler, deserialize_request, serialize_response
),
),
("test_service", "stream_method_error"): (
"stream",
stream_method_handler(
stream_error_handler, deserialize_request, serialize_response
),
),
}
)
return server
Expand Down Expand Up @@ -173,3 +192,18 @@ async def websocket_uri_factory() -> UriAndMetadata[None]:
await server.close()
# Server should close normally
no_logging_error()


@pytest.fixture(scope="session")
def span_exporter() -> InMemorySpanExporter:
exporter = InMemorySpanExporter()
processor = SimpleSpanProcessor(exporter)
provider = TracerProvider()
provider.add_span_processor(processor)
trace.set_tracer_provider(provider)
return exporter


@pytest.fixture(autouse=True)
def reset_span_exporter(span_exporter: InMemorySpanExporter) -> None:
span_exporter.clear()
7 changes: 7 additions & 0 deletions tests/river_fixtures/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ class NoErrors:

def __init__(self, caplog: LogCaptureFixture):
self.caplog = caplog
self._allow_errors = False

def allow_errors(self) -> None:
self._allow_errors = True

def __call__(self) -> None:
if self._allow_errors:
return

assert len(self.caplog.get_records("setup")) == 0
assert len(self.caplog.get_records("call")) == 0
assert len(self.caplog.get_records("teardown")) == 0
Expand Down
2 changes: 1 addition & 1 deletion tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def test_rpc_method(client: Client) -> None:
serialize_request,
deserialize_response,
deserialize_error,
) # type: ignore
)
assert response == "Hello, Alice!"


Expand Down
Loading
Loading