Skip to content

Commit 4ccbd27

Browse files
authored
bugfix: make streaming spans last for the entire duration of the stream (#120)
Why === The streaming spans were immediately finishing because we were only tracing the construction of the `AsyncIterator` and not the full iteration. What changed ============ - loop and yield each message of the streaming procedures so the span doesn't finish until the async iterator is disposed. Test plan ========= - Streaming procedure spans should now last the full duration of the stream procedure
1 parent dea76d6 commit 4ccbd27

File tree

3 files changed

+20
-17
lines changed

3 files changed

+20
-17
lines changed

replit_river/client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,15 @@ async def send_subscription(
106106
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
107107
with _trace_procedure("subscription", service_name, procedure_name):
108108
session = await self._transport.get_or_create_session()
109-
return session.send_subscription(
109+
async for msg in session.send_subscription(
110110
service_name,
111111
procedure_name,
112112
request,
113113
request_serializer,
114114
response_deserializer,
115115
error_deserializer,
116-
)
116+
):
117+
yield msg
117118

118119
async def send_stream(
119120
self,
@@ -128,7 +129,7 @@ async def send_stream(
128129
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
129130
with _trace_procedure("stream", service_name, procedure_name):
130131
session = await self._transport.get_or_create_session()
131-
return session.send_stream(
132+
async for msg in session.send_stream(
132133
service_name,
133134
procedure_name,
134135
init,
@@ -137,7 +138,8 @@ async def send_stream(
137138
request_serializer,
138139
response_deserializer,
139140
error_deserializer,
140-
)
141+
):
142+
yield msg
141143

142144

143145
@contextmanager

replit_river/codegen/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ async def {name}(
953953
self,
954954
input: {render_type_expr(input_type)},
955955
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
956-
return await self.client.send_subscription(
956+
return self.client.send_subscription(
957957
{repr(schema_name)},
958958
{repr(name)},
959959
input,
@@ -1029,7 +1029,7 @@ async def {name}(
10291029
init: {render_type_expr(init_type)},
10301030
inputStream: AsyncIterable[{render_type_expr(input_type)}],
10311031
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
1032-
return await self.client.send_stream(
1032+
return self.client.send_stream(
10331033
{repr(schema_name)},
10341034
{repr(name)},
10351035
init,
@@ -1053,7 +1053,7 @@ async def {name}(
10531053
self,
10541054
inputStream: AsyncIterable[{render_type_expr(input_type)}],
10551055
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
1056-
return await self.client.send_stream(
1056+
return self.client.send_stream(
10571057
{repr(schema_name)},
10581058
{repr(name)},
10591059
None,

tests/test_communication.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ async def upload_data() -> AsyncGenerator[str, None]:
3737
serialize_request,
3838
serialize_request,
3939
deserialize_response,
40-
deserialize_response,
41-
) # type: ignore
40+
deserialize_error,
41+
)
4242
assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3"
4343

4444

@@ -58,8 +58,8 @@ async def upload_data() -> AsyncGenerator[str, None]:
5858
serialize_request,
5959
serialize_request,
6060
deserialize_response,
61-
deserialize_response,
62-
) # type: ignore
61+
deserialize_error,
62+
)
6363
assert response == "Uploaded: Initial Data" + (", Data" * iterations)
6464

6565

@@ -77,21 +77,22 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
7777
None,
7878
serialize_request,
7979
deserialize_response,
80-
deserialize_response,
81-
) # type: ignore
80+
deserialize_error,
81+
)
8282
assert response == "Uploaded: "
8383

8484

8585
@pytest.mark.asyncio
8686
async def test_subscription_method(client: Client) -> None:
87-
async for response in await client.send_subscription(
87+
async for response in client.send_subscription(
8888
"test_service",
8989
"subscription_method",
9090
"Bob",
9191
serialize_request,
9292
deserialize_response,
9393
deserialize_error,
9494
):
95+
assert isinstance(response, str)
9596
assert "Subscription message" in response
9697

9798

@@ -103,7 +104,7 @@ async def stream_data() -> AsyncGenerator[str, None]:
103104
yield "Stream 3"
104105

105106
responses = []
106-
async for response in await client.send_stream(
107+
async for response in client.send_stream(
107108
"test_service",
108109
"stream_method",
109110
"Initial Stream Data",
@@ -130,7 +131,7 @@ async def stream_data(enabled: bool = False) -> AsyncGenerator[str, None]:
130131
yield "unreachable"
131132

132133
responses = []
133-
async for response in await client.send_stream(
134+
async for response in client.send_stream(
134135
"test_service",
135136
"stream_method",
136137
None,
@@ -167,7 +168,7 @@ async def stream_data() -> AsyncGenerator[str, None]:
167168
deserialize_error,
168169
)
169170
)
170-
stream_task = await client.send_stream(
171+
stream_task = client.send_stream(
171172
"test_service",
172173
"stream_method",
173174
"Initial Stream Data",

0 commit comments

Comments
 (0)