Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/google/adk/flows/llm_flows/contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,29 @@ async def _add_instructions_to_user_content(
llm_request: The LLM request to modify
instruction_contents: List of instruction-related contents to insert
"""

def is_valid_instruction_position(
    llm_request: 'LlmRequest', index: int
) -> bool:
  """Checks if instructions can be inserted after a given index.

  A valid insertion point is after a model response that is not a tool call.
  This prevents injecting instructions in the middle of a user's turn or a
  tool-use sequence.

  Args:
    llm_request: The LLM request containing the conversation contents.
    index: The index of the content to check.

  Returns:
    True if the position after this index is a valid insertion point.
  """
  content_at_index = llm_request.contents[index]
  is_user_message = content_at_index.role == 'user'
  # content.parts is Optional in genai types; a content with no parts
  # cannot carry a function call, so treat it as "no tool request"
  # instead of raising TypeError when iterating None.
  is_tool_request = any(
      part.function_call for part in content_at_index.parts or []
  )

  return not is_user_message and not is_tool_request

if not instruction_contents:
return

Expand All @@ -695,7 +718,7 @@ async def _add_instructions_to_user_content(

if llm_request.contents:
for i in range(len(llm_request.contents) - 1, -1, -1):
if llm_request.contents[i].role != 'user':
if is_valid_instruction_position(llm_request, i):
insert_index = i + 1
break
elif i == 0:
Expand Down
86 changes: 86 additions & 0 deletions tests/unittests/flows/llm_flows/test_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,89 @@ async def test_events_with_empty_content_are_skipped():
types.UserContent("Hello"),
types.UserContent("How are you?"),
]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "initial_contents, expected_insertion_index",
    [
        pytest.param(
            [
                types.UserContent("First user message"),
                types.ModelContent("Model response"),
                types.ModelContent([
                    types.Part(
                        function_call=types.FunctionCall(
                            name="test_tool", args={}
                        )
                    )
                ]),
                types.Content(
                    parts=[
                        types.Part(
                            function_response=types.FunctionResponse(
                                name="test_tool", response={}
                            )
                        )
                    ],
                    role="user",
                ),
                types.UserContent("Final user message"),
            ],
            2,
            id="skips_function_call_and_user_content",
        ),
        pytest.param(
            [
                types.UserContent("First user message"),
                types.UserContent("Second user message"),
                types.ModelContent("Model response"),
                types.UserContent("Third user message"),
                types.UserContent("Fourth user message"),
            ],
            3,
            id="skips_trailing_user_content",
        ),
        pytest.param(
            [
                types.UserContent("First user message"),
                types.UserContent("Second user message"),
            ],
            0,
            id="inserts_at_start_when_all_user_content",
        ),
        pytest.param([], 0, id="inserts_at_start_for_empty_content"),
        pytest.param(
            [
                types.UserContent("User message"),
                types.ModelContent("Model response"),
            ],
            2,
            id="inserts_at_end_when_last_is_model_content",
        ),
    ],
)
async def test_add_instructions_to_user_content(
    initial_contents, expected_insertion_index
):
  """Verifies instruction content is spliced in at the expected position.

  For each scenario the instruction must land at
  ``expected_insertion_index`` and grow the content list by exactly one.
  """
  agent = Agent(model="gemini-2.5-flash", name="test_agent")
  ctx = await testing_utils.create_invocation_context(agent=agent)
  instruction = types.Content(
      parts=[types.Part(text="System instruction")], role="user"
  )
  request = LlmRequest(model="gemini-2.5-flash", contents=initial_contents)

  await contents._add_instructions_to_user_content(ctx, request, [instruction])

  # Exactly one content was inserted, at the computed position.
  assert len(request.contents) == len(initial_contents) + 1
  assert request.contents[expected_insertion_index] == instruction