Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 34 additions & 37 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import TypedDict
from typing import Union

from google.genai import types
Expand All @@ -35,15 +36,9 @@
from litellm import ChatCompletionAssistantMessage
from litellm import ChatCompletionAssistantToolCall
from litellm import ChatCompletionDeveloperMessage
from litellm import ChatCompletionFileObject
from litellm import ChatCompletionImageObject
from litellm import ChatCompletionImageUrlObject
from litellm import ChatCompletionMessageToolCall
from litellm import ChatCompletionTextObject
from litellm import ChatCompletionToolMessage
from litellm import ChatCompletionUserMessage
from litellm import ChatCompletionVideoObject
from litellm import ChatCompletionVideoUrlObject
from litellm import completion
from litellm import CustomStreamWrapper
from litellm import Function
Expand All @@ -67,6 +62,11 @@
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}


class ChatCompletionFileUrlObject(TypedDict):
file_data: str
format: str


class FunctionChunk(BaseModel):
id: Optional[str]
name: Optional[str]
Expand Down Expand Up @@ -237,12 +237,10 @@ def _get_content(
if part.text:
if len(parts) == 1:
return part.text
content_objects.append(
ChatCompletionTextObject(
type="text",
text=part.text,
)
)
content_objects.append({
"type": "text",
"text": part.text,
})
elif (
part.inline_data
and part.inline_data.data
Expand All @@ -252,33 +250,32 @@ def _get_content(
data_uri = f"data:{part.inline_data.mime_type};base64,{base64_string}"

if part.inline_data.mime_type.startswith("image"):
# Extract format from mime type (e.g., "image/png" -> "png")
format_type = part.inline_data.mime_type.split("/")[-1]
content_objects.append(
ChatCompletionImageObject(
type="image_url",
image_url=ChatCompletionImageUrlObject(
url=data_uri, format=format_type
),
)
)
# Use full MIME type (e.g., "image/png") for providers that validate it
format_type = part.inline_data.mime_type
content_objects.append({
"type": "image_url",
"image_url": {"url": data_uri, "format": format_type},
})
elif part.inline_data.mime_type.startswith("video"):
# Extract format from mime type (e.g., "video/mp4" -> "mp4")
format_type = part.inline_data.mime_type.split("/")[-1]
content_objects.append(
ChatCompletionVideoObject(
type="video_url",
video_url=ChatCompletionVideoUrlObject(
url=data_uri, format=format_type
),
)
)
# Use full MIME type (e.g., "video/mp4") for providers that validate it
format_type = part.inline_data.mime_type
content_objects.append({
"type": "video_url",
"video_url": {"url": data_uri, "format": format_type},
})
elif part.inline_data.mime_type.startswith("audio"):
# Use full MIME type (e.g., "audio/mpeg") for providers that validate it
format_type = part.inline_data.mime_type
content_objects.append({
"type": "audio_url",
"audio_url": {"url": data_uri, "format": format_type},
})
elif part.inline_data.mime_type == "application/pdf":
content_objects.append(
ChatCompletionFileObject(
type="file", file={"file_data": data_uri, "format": "pdf"}
)
)
format_type = part.inline_data.mime_type
content_objects.append({
"type": "file",
"file": {"file_data": data_uri, "format": format_type},
})
else:
raise ValueError("LiteLlm(BaseLlm) does not support this content part.")

Expand Down
30 changes: 28 additions & 2 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,7 @@ def test_get_content_image():
content[0]["image_url"]["url"]
== "data:image/png;base64,dGVzdF9pbWFnZV9kYXRh"
)
assert content[0]["image_url"]["format"] == "png"
assert content[0]["image_url"]["format"] == "image/png"


def test_get_content_video():
Expand All @@ -1049,7 +1049,33 @@ def test_get_content_video():
content[0]["video_url"]["url"]
== "data:video/mp4;base64,dGVzdF92aWRlb19kYXRh"
)
assert content[0]["video_url"]["format"] == "mp4"
assert content[0]["video_url"]["format"] == "video/mp4"


def test_get_content_pdf():
parts = [
types.Part.from_bytes(data=b"test_pdf_data", mime_type="application/pdf")
]
content = _get_content(parts)
assert content[0]["type"] == "file"
assert (
content[0]["file"]["file_data"]
== "data:application/pdf;base64,dGVzdF9wZGZfZGF0YQ=="
)
assert content[0]["file"]["format"] == "application/pdf"


def test_get_content_audio():
parts = [
types.Part.from_bytes(data=b"test_audio_data", mime_type="audio/mpeg")
]
content = _get_content(parts)
assert content[0]["type"] == "audio_url"
assert (
content[0]["audio_url"]["url"]
== "data:audio/mpeg;base64,dGVzdF9hdWRpb19kYXRh"
)
assert content[0]["audio_url"]["format"] == "audio/mpeg"


def test_to_litellm_role():
Expand Down