Skip to content

Commit 31bb49f

Browse files
levilentz authored and declan-scale committed
(feat): address PR feedback
1 parent c9fcf8c commit 31bb49f

4 files changed

Lines changed: 36 additions & 145 deletions

File tree

src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -145,9 +145,6 @@ async def on_span_start(self, span: Span) -> None:
145145
items=[sgp_span.to_request_params()]
146146
)
147147

148-
# Input has been serialized and sent; clear it on the retained span to
149-
# release memory. on_span_end only needs output/metadata/end_time.
150-
sgp_span.input = None # type: ignore[assignment]
151148
self._spans[span.id] = sgp_span
152149

153150
@override
@@ -158,6 +155,7 @@ async def on_span_end(self, span: Span) -> None:
158155
return
159156

160157
self._add_source_to_span(span)
158+
sgp_span.input = span.input # type: ignore[assignment]
161159
sgp_span.output = span.output # type: ignore[assignment]
162160
sgp_span.metadata = span.data # type: ignore[assignment]
163161
sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr]

src/agentex/lib/core/tracing/trace.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@ def end_span(
109109
if span.end_time is None:
110110
span.end_time = datetime.now(UTC)
111111

112-
# input was already serialized at start_span; skip redundant re-serialization
112+
span.input = recursive_model_dump(span.input) if span.input else None
113113
span.output = recursive_model_dump(span.output) if span.output else None
114114
span.data = recursive_model_dump(span.data) if span.data else None
115115

@@ -252,17 +252,12 @@ async def end_span(
252252
if span.end_time is None:
253253
span.end_time = datetime.now(UTC)
254254

255-
# input was already serialized at start_span; skip redundant re-serialization
255+
span.input = recursive_model_dump(span.input) if span.input else None
256256
span.output = recursive_model_dump(span.output) if span.output else None
257257
span.data = recursive_model_dump(span.data) if span.data else None
258258

259259
if self.processors:
260-
end_copy = span.model_copy(deep=True)
261-
# input was already sent with the START event; drop it from the END
262-
# copy to avoid retaining large payloads (system prompts, full
263-
# conversation histories) in the async queue.
264-
end_copy.input = None
265-
self._span_queue.enqueue(SpanEventType.END, end_copy, self.processors)
260+
self._span_queue.enqueue(SpanEventType.END, span.model_copy(deep=True), self.processors)
266261

267262
return span
268263

tests/lib/core/tracing/processors/test_sgp_tracing_processor.py

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -163,17 +163,28 @@ async def test_span_end_for_unknown_span_is_noop(self):
163163

164164
assert len(processor._spans) == 0
165165

166-
async def test_sgp_span_input_cleared_after_start(self):
167-
"""After on_span_start sends the data, sgp_span.input should be None to release memory."""
166+
async def test_sgp_span_input_updated_on_end(self):
167+
"""on_span_end should update sgp_span.input from the incoming span."""
168168
processor, _ = self._make_processor()
169169

170170
with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()):
171171
span = _make_span()
172-
span.input = {"system_prompt": "x" * 10_000}
172+
span.input = {"messages": [{"role": "user", "content": "hello"}]}
173173
await processor.on_span_start(span)
174174

175175
assert len(processor._spans) == 1
176-
sgp_span = next(iter(processor._spans.values()))
177-
assert sgp_span.input is None, (
178-
"SGP span input should be cleared after upsert to release memory"
179-
)
176+
177+
# Simulate modified input at end time
178+
updated_input = {"messages": [
179+
{"role": "user", "content": "hello"},
180+
{"role": "assistant", "content": "hi"},
181+
]}
182+
span.input = updated_input
183+
span.output = {"response": "hi"}
184+
span.end_time = datetime.now(UTC)
185+
await processor.on_span_end(span)
186+
187+
# Span should be removed after end
188+
assert len(processor._spans) == 0
189+
# The end upsert should have been called
190+
assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end

tests/lib/core/tracing/test_span_queue.py

Lines changed: 14 additions & 127 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from unittest.mock import AsyncMock, MagicMock, patch
88

99
from agentex.types.span import Span
10-
from agentex.lib.core.tracing.span_queue import SpanEventType, AsyncSpanQueue, _SpanQueueItem
10+
from agentex.lib.core.tracing.span_queue import SpanEventType, AsyncSpanQueue
1111

1212

1313
def _make_span(span_id: str | None = None) -> Span:
@@ -260,8 +260,8 @@ async def record_end(span: Span) -> None:
260260
# Same span ID for both events
261261
assert call_log[0][1] == call_log[1][1]
262262

263-
async def test_end_event_drops_input(self):
264-
"""END event should NOT carry span.input — it was already sent at START."""
263+
async def test_end_event_preserves_modified_input(self):
264+
"""END event should carry span.input so modifications after start are preserved."""
265265
start_spans: list[Span] = []
266266
end_spans: list[Span] = []
267267

@@ -287,139 +287,26 @@ async def capture_end(span: Span) -> None:
287287
span_queue=queue,
288288
)
289289

290-
large_input = {"system_prompt": "x" * 10_000, "messages": [{"role": "user", "content": "hi"}]}
291-
async with trace.span("llm-call", input=large_input) as span:
292-
span.output = {"response": "hello"}
290+
initial_input = {"messages": [{"role": "user", "content": "hello"}]}
291+
async with trace.span("llm-call", input=initial_input) as span:
292+
# Simulate modifying input after start (e.g. chatbot appending messages)
293+
span.input["messages"].append({"role": "assistant", "content": "hi there"})
294+
span.input["messages"].append({"role": "user", "content": "how are you?"})
295+
span.output = {"response": "I'm good!"}
293296

294297
await queue.shutdown()
295298

296299
assert len(start_spans) == 1
297300
assert len(end_spans) == 1
298301

299-
# START should carry the full input
302+
# START should carry the original input (serialized at start time)
300303
assert start_spans[0].input is not None
301-
assert start_spans[0].input["system_prompt"] == "x" * 10_000
304+
assert len(start_spans[0].input["messages"]) == 1 # only the original message
302305

303-
# END should have input=None (already sent at START)
304-
assert end_spans[0].input is None
306+
# END should carry the modified input (re-serialized at end time)
307+
assert end_spans[0].input is not None
308+
assert len(end_spans[0].input["messages"]) == 3 # all three messages
305309

306310
# END should still carry output and end_time
307311
assert end_spans[0].output is not None
308312
assert end_spans[0].end_time is not None
309-
310-
311-
class TestMemoryUsage:
312-
"""Quantify that the fix actually reduces memory held by the tracing pipeline."""
313-
314-
async def test_end_events_use_less_memory_than_start_events(self):
315-
"""
316-
Simulate N concurrent single-shot requests with large system prompts.
317-
Collect what processors receive and measure serialized sizes.
318-
319-
Before the fix, START and END events were the same size (both carried
320-
full input). After the fix, END events should be dramatically smaller.
321-
"""
322-
start_spans: list[Span] = []
323-
end_spans: list[Span] = []
324-
325-
async def collect_start(span: Span) -> None:
326-
start_spans.append(span)
327-
328-
async def collect_end(span: Span) -> None:
329-
end_spans.append(span)
330-
331-
proc = _make_processor(
332-
on_span_start=AsyncMock(side_effect=collect_start),
333-
on_span_end=AsyncMock(side_effect=collect_end),
334-
)
335-
queue = AsyncSpanQueue()
336-
337-
from agentex.lib.core.tracing.trace import AsyncTrace
338-
339-
trace = AsyncTrace(
340-
processors=[proc],
341-
client=MagicMock(),
342-
trace_id="test-trace",
343-
span_queue=queue,
344-
)
345-
346-
n_spans = 50
347-
prompt_size = 100_000 # 100 KB system prompt per span
348-
large_input = {"system_prompt": "x" * prompt_size}
349-
350-
for _ in range(n_spans):
351-
span = await trace.start_span("llm-call", input=large_input)
352-
span.output = {"response": "hello"}
353-
await trace.end_span(span)
354-
355-
await queue.shutdown()
356-
357-
assert len(start_spans) == n_spans
358-
assert len(end_spans) == n_spans
359-
360-
start_bytes = sum(len(s.model_dump_json()) for s in start_spans)
361-
end_bytes = sum(len(s.model_dump_json()) for s in end_spans)
362-
363-
ratio = end_bytes / start_bytes
364-
assert ratio < 0.05, (
365-
f"END events used {ratio:.1%} of START event memory "
366-
f"(start={start_bytes:,}B, end={end_bytes:,}B). "
367-
f"Expected <5% because the ~{prompt_size:,}B input is dropped."
368-
)
369-
370-
async def test_queue_payload_reduction_old_vs_new(self):
371-
"""
372-
Directly compare data volume in the queue under old vs new behavior.
373-
374-
Simulates a backed-up queue (drain can't keep up with request rate)
375-
holding N span lifecycles. Old behavior: both START and END carry
376-
full input. New behavior: END events have input=None.
377-
378-
This mirrors what happens in K8s under concurrent load — items pile up
379-
in the queue, and each one holds a serialized copy of the system prompt.
380-
"""
381-
n_spans = 30
382-
prompt_size = 200_000 # 200 KB system prompt
383-
384-
def _queue_payload_bytes(q: AsyncSpanQueue) -> int:
385-
"""Total serialized bytes of all spans sitting in the queue."""
386-
return sum(len(item.span.model_dump_json()) for item in list(q._queue._queue))
387-
388-
large_input = {"system_prompt": "x" * prompt_size}
389-
390-
# --- OLD behavior: both START and END carry full input ---
391-
old_queue = AsyncSpanQueue()
392-
for _ in range(n_spans):
393-
span = _make_span()
394-
span.input = large_input
395-
span.output = {"response": "ok"}
396-
old_queue._queue.put_nowait(
397-
_SpanQueueItem(SpanEventType.START, span.model_copy(deep=True), [])
398-
)
399-
old_queue._queue.put_nowait(
400-
_SpanQueueItem(SpanEventType.END, span.model_copy(deep=True), [])
401-
)
402-
403-
# --- NEW behavior: END events drop input ---
404-
new_queue = AsyncSpanQueue()
405-
for _ in range(n_spans):
406-
span = _make_span()
407-
span.input = large_input
408-
span.output = {"response": "ok"}
409-
new_queue._queue.put_nowait(
410-
_SpanQueueItem(SpanEventType.START, span.model_copy(deep=True), [])
411-
)
412-
end_copy = span.model_copy(deep=True)
413-
end_copy.input = None
414-
new_queue._queue.put_nowait(
415-
_SpanQueueItem(SpanEventType.END, end_copy, [])
416-
)
417-
418-
old_bytes = _queue_payload_bytes(old_queue)
419-
new_bytes = _queue_payload_bytes(new_queue)
420-
421-
savings_pct = 1.0 - (new_bytes / old_bytes)
422-
assert savings_pct > 0.40, (
423-
f"Expected >40% queue payload reduction, got {savings_pct:.0%} "
424-
f"(old={old_bytes:,}B, new={new_bytes:,}B)"
425-
)

0 commit comments

Comments
 (0)