Skip to content

Commit 786aaed

Browse files
xuanyang15copybara-github
authored andcommitted
feat: Support streaming function call arguments in progressive SSE streaming feature
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 837172244
1 parent 73e5687 commit 786aaed

File tree

5 files changed

+492
-3
lines changed

5 files changed

+492
-3
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.adk import Agent
16+
from google.genai import types
17+
18+
19+
def concat_number_and_string(num: int, s: str) -> str:
20+
"""Concatenate a number and a string.
21+
22+
Args:
23+
num: The number to concatenate.
24+
s: The string to concatenate.
25+
26+
Returns:
27+
The concatenated string.
28+
"""
29+
return str(num) + ': ' + s
30+
31+
32+
root_agent = Agent(
33+
model='gemini-3-pro-preview',
34+
name='hello_world_stream_fc_args',
35+
description='Demo agent showcasing streaming function call arguments.',
36+
instruction="""
37+
You are a helpful assistant.
38+
You can use the `concat_number_and_string` tool to concatenate a number and a string.
39+
You should always call the concat_number_and_string tool to concatenate a number and a string.
40+
You should never concatenate on your own.
41+
""",
42+
tools=[
43+
concat_number_and_string,
44+
],
45+
generate_content_config=types.GenerateContentConfig(
46+
automatic_function_calling=types.AutomaticFunctionCallingConfig(
47+
disable=True,
48+
),
49+
tool_config=types.ToolConfig(
50+
function_calling_config=types.FunctionCallingConfig(
51+
stream_function_call_arguments=True,
52+
),
53+
),
54+
),
55+
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ dependencies = [
4141
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
4242
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
4343
"google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service
44-
"google-genai>=1.45.0, <2.0.0", # Google GenAI SDK
44+
"google-genai>=1.51.0, <2.0.0", # Google GenAI SDK
4545
"graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering
4646
"jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation
4747
"mcp>=1.10.0, <2.0.0", # For MCP Toolset

src/google/adk/utils/streaming_utils.py

Lines changed: 179 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Any
1718
from typing import AsyncGenerator
1819
from typing import Optional
1920

@@ -43,6 +44,12 @@ def __init__(self):
4344
self._current_text_is_thought: Optional[bool] = None
4445
self._finish_reason: Optional[types.FinishReason] = None
4546

47+
# For streaming function call arguments
48+
self._current_fc_name: Optional[str] = None
49+
self._current_fc_args: dict[str, Any] = {}
50+
self._current_fc_id: Optional[str] = None
51+
self._current_thought_signature: Optional[str] = None
52+
4653
def _flush_text_buffer_to_sequence(self):
4754
"""Flush current text buffer to parts sequence.
4855
@@ -61,6 +68,171 @@ def _flush_text_buffer_to_sequence(self):
6168
self._current_text_buffer = ''
6269
self._current_text_is_thought = None
6370

71+
def _get_value_from_partial_arg(
72+
self, partial_arg: types.PartialArg, json_path: str
73+
):
74+
"""Extract value from a partial argument.
75+
76+
Args:
77+
partial_arg: The partial argument object
78+
json_path: JSONPath for this argument
79+
80+
Returns:
81+
Tuple of (value, has_value) where has_value indicates if a value exists
82+
"""
83+
value = None
84+
has_value = False
85+
86+
if partial_arg.string_value is not None:
87+
# For streaming strings, append chunks to existing value
88+
string_chunk = partial_arg.string_value
89+
has_value = True
90+
91+
# Get current value for this path (if any)
92+
path_without_prefix = (
93+
json_path[2:] if json_path.startswith('$.') else json_path
94+
)
95+
path_parts = path_without_prefix.split('.')
96+
97+
# Try to get existing value
98+
existing_value = self._current_fc_args
99+
for part in path_parts:
100+
if isinstance(existing_value, dict) and part in existing_value:
101+
existing_value = existing_value[part]
102+
else:
103+
existing_value = None
104+
break
105+
106+
# Append to existing string or set new value
107+
if isinstance(existing_value, str):
108+
value = existing_value + string_chunk
109+
else:
110+
value = string_chunk
111+
112+
elif partial_arg.number_value is not None:
113+
value = partial_arg.number_value
114+
has_value = True
115+
elif partial_arg.bool_value is not None:
116+
value = partial_arg.bool_value
117+
has_value = True
118+
elif partial_arg.null_value is not None:
119+
value = None
120+
has_value = True
121+
122+
return value, has_value
123+
124+
def _set_value_by_json_path(self, json_path: str, value: Any):
125+
"""Set a value in _current_fc_args using JSONPath notation.
126+
127+
Args:
128+
json_path: JSONPath string like "$.location" or "$.location.latitude"
129+
value: The value to set
130+
"""
131+
# Remove leading "$." from jsonPath
132+
if json_path.startswith('$.'):
133+
path = json_path[2:]
134+
else:
135+
path = json_path
136+
137+
# Split path into components
138+
path_parts = path.split('.')
139+
140+
# Navigate to the correct location and set the value
141+
current = self._current_fc_args
142+
for part in path_parts[:-1]:
143+
if part not in current:
144+
current[part] = {}
145+
current = current[part]
146+
147+
# Set the final value
148+
current[path_parts[-1]] = value
149+
150+
def _flush_function_call_to_sequence(self):
151+
"""Flush current function call to parts sequence.
152+
153+
This creates a complete FunctionCall part from accumulated partial args.
154+
"""
155+
if self._current_fc_name:
156+
# Create function call part with accumulated args
157+
fc_part = types.Part.from_function_call(
158+
name=self._current_fc_name,
159+
args=self._current_fc_args.copy(),
160+
)
161+
162+
# Set the ID if provided (directly on the function_call object)
163+
if self._current_fc_id and fc_part.function_call:
164+
fc_part.function_call.id = self._current_fc_id
165+
166+
# Set thought_signature if provided (on the Part, not FunctionCall)
167+
if self._current_thought_signature:
168+
fc_part.thought_signature = self._current_thought_signature
169+
170+
self._parts_sequence.append(fc_part)
171+
172+
# Reset FC state
173+
self._current_fc_name = None
174+
self._current_fc_args = {}
175+
self._current_fc_id = None
176+
self._current_thought_signature = None
177+
178+
def _process_streaming_function_call(self, fc: types.FunctionCall):
179+
"""Process a streaming function call with partialArgs.
180+
181+
Args:
182+
fc: The function call object with partial_args
183+
"""
184+
# Save function name if present (first chunk)
185+
if fc.name:
186+
self._current_fc_name = fc.name
187+
if fc.id:
188+
self._current_fc_id = fc.id
189+
190+
# Process each partial argument
191+
for partial_arg in getattr(fc, 'partial_args', []):
192+
json_path = partial_arg.json_path
193+
if not json_path:
194+
continue
195+
196+
# Extract value from partial arg
197+
value, has_value = self._get_value_from_partial_arg(
198+
partial_arg, json_path
199+
)
200+
201+
# Set the value using JSONPath (only if a value was provided)
202+
if has_value:
203+
self._set_value_by_json_path(json_path, value)
204+
205+
# Check if function call is complete
206+
fc_will_continue = getattr(fc, 'will_continue', False)
207+
if not fc_will_continue:
208+
# Function call complete, flush it
209+
self._flush_text_buffer_to_sequence()
210+
self._flush_function_call_to_sequence()
211+
212+
def _process_function_call_part(self, part: types.Part):
213+
"""Process a function call part (streaming or non-streaming).
214+
215+
Args:
216+
part: The part containing a function call
217+
"""
218+
fc = part.function_call
219+
220+
# Check if this is a streaming FC (has partialArgs)
221+
if hasattr(fc, 'partial_args') and fc.partial_args:
222+
# Streaming function call arguments
223+
224+
# Save thought_signature from the part (first chunk should have it)
225+
if part.thought_signature and not self._current_thought_signature:
226+
self._current_thought_signature = part.thought_signature
227+
self._process_streaming_function_call(fc)
228+
else:
229+
# Non-streaming function call (standard format with args)
230+
# Skip empty function calls (used as streaming end markers)
231+
if fc.name:
232+
# Flush any buffered text first, then add the FC part
233+
self._flush_text_buffer_to_sequence()
234+
self._parts_sequence.append(part)
235+
64236
async def process_response(
65237
self, response: types.GenerateContentResponse
66238
) -> AsyncGenerator[LlmResponse, None]:
@@ -101,8 +273,12 @@ async def process_response(
101273
if not self._current_text_buffer:
102274
self._current_text_is_thought = part.thought
103275
self._current_text_buffer += part.text
276+
elif part.function_call:
277+
# Process function call (handles both streaming Args and
278+
# non-streaming Args)
279+
self._process_function_call_part(part)
104280
else:
105-
# Non-text part (function_call, bytes, etc.)
281+
# Other non-text parts (bytes, etc.)
106282
# Flush any buffered text first, then add the non-text part
107283
self._flush_text_buffer_to_sequence()
108284
self._parts_sequence.append(part)
@@ -155,8 +331,9 @@ def close(self) -> Optional[LlmResponse]:
155331
if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING):
156332
# Always generate final aggregated response in progressive mode
157333
if self._response and self._response.candidates:
158-
# Flush any remaining text buffer to complete the sequence
334+
# Flush any remaining buffers to complete the sequence
159335
self._flush_text_buffer_to_sequence()
336+
self._flush_function_call_to_sequence()
160337

161338
# Use the parts sequence which preserves original ordering
162339
final_parts = self._parts_sequence

0 commit comments

Comments
 (0)