Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ import android.media.AudioTrack
* offers a final opportunity to configure these objects, which will remain valid and effective for
* the duration of the current audio session.
*
* @property goAwayHandler A callback that is invoked when the server initiates a disconnect via a
* [LiveServerGoAway] message. This allows the application to handle server-initiated session
* termination gracefully, such as displaying a message to the user or attempting to reconnect.
*
* @property enableInterruptions If enabled, allows the user to speak over or interrupt the model's
* ongoing reply.
*
Expand All @@ -47,6 +51,7 @@ private constructor(
internal val functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)?,
internal val initializationHandler: ((AudioRecord.Builder, AudioTrack.Builder) -> Unit)?,
internal val transcriptHandler: ((Transcription?, Transcription?) -> Unit)?,
internal val goAwayHandler: ((LiveServerGoAway) -> Unit)?,
internal val enableInterruptions: Boolean
) {

Expand All @@ -62,13 +67,16 @@ private constructor(
*
* @property transcriptHandler See [LiveAudioConversationConfig.transcriptHandler].
*
* @property goAwayHandler See [LiveAudioConversationConfig.goAwayHandler].
*
* @property enableInterruptions See [LiveAudioConversationConfig.enableInterruptions].
*/
public class Builder {
@JvmField public var functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? = null
@JvmField
public var initializationHandler: ((AudioRecord.Builder, AudioTrack.Builder) -> Unit)? = null
@JvmField public var transcriptHandler: ((Transcription?, Transcription?) -> Unit)? = null
@JvmField public var goAwayHandler: ((LiveServerGoAway) -> Unit)? = null
@JvmField public var enableInterruptions: Boolean = false

public fun setFunctionCallHandler(
Expand All @@ -83,6 +91,10 @@ private constructor(
transcriptHandler: ((Transcription?, Transcription?) -> Unit)?
): Builder = apply { this.transcriptHandler = transcriptHandler }

public fun setGoAwayHandler(goAwayHandler: ((LiveServerGoAway) -> Unit)?): Builder = apply {
this.goAwayHandler = goAwayHandler
}

public fun setEnableInterruptions(enableInterruptions: Boolean): Builder = apply {
this.enableInterruptions = enableInterruptions
}
Expand All @@ -93,6 +105,7 @@ private constructor(
functionCallHandler = functionCallHandler,
initializationHandler = initializationHandler,
transcriptHandler = transcriptHandler,
goAwayHandler = goAwayHandler,
enableInterruptions = enableInterruptions
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.firebase.ai.type

import kotlin.time.Duration.Companion.seconds
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
Expand All @@ -32,6 +33,7 @@ import kotlinx.serialization.json.jsonObject
* @see LiveServerToolCall
* @see LiveServerToolCallCancellation
* @see LiveServerSetupComplete
* @see LiveServerGoAway
*/
@PublicPreviewAPI public interface LiveServerMessage

Expand Down Expand Up @@ -182,6 +184,70 @@ public class LiveServerToolCallCancellation(public val functionIds: List<String>
}
}

/**
* Notification that the server is initiating a disconnect of the session.
*
* This message is sent by the server when it needs to close the connection, typically due to
* session timeout, resource constraints, or other server-side reasons.
*
* When this message is received, the client should gracefully close the [LiveSession] by calling
* [LiveSession.close].
*
* @property timeLeft The time remaining before the connection terminates as a duration string
* (e.g., "57s", "1.5s"). If null, the connection will terminate immediately. Use [parseTimeLeft] to
* convert this to a [kotlin.time.Duration].
*/
@PublicPreviewAPI
public class LiveServerGoAway(public val timeLeft: String?) : LiveServerMessage {
/**
* Parses the [timeLeft] string into a [kotlin.time.Duration].
*
* Supports protobuf Duration format: "57s", "1.5s", "0.001s", etc. Nanoseconds are expressed as
* fractional seconds (e.g., "1.000000001s").
*
* @return The parsed duration, or null if [timeLeft] is null or cannot be parsed.
*/
public fun parseTimeLeft(): kotlin.time.Duration? {
return timeLeft?.let { parseDurationString(it) }
}

@Serializable internal data class Internal(val timeLeft: String? = null)
@Serializable
internal data class InternalWrapper(val goAway: Internal) : InternalLiveServerMessage {
override fun toPublic() = LiveServerGoAway(goAway.timeLeft)
}
}

/**
* Parses a protobuf Duration string (e.g., "57s", "1.5s") into a [kotlin.time.Duration].
*
* According to the protobuf specification, the JSON representation for Duration is a String that
* ends in 's' to indicate seconds, with nanoseconds expressed as fractional seconds.
*
* @param durationString The duration string to parse (must end with 's').
* @return The parsed duration, or null if the string cannot be parsed.
* @see <a href="https://protobuf.dev/reference/protobuf/google.protobuf/#duration">Protobuf
* Duration</a>
*/
private fun parseDurationString(durationString: String): kotlin.time.Duration? {
return try {
val trimmed = durationString.trim()

// Protobuf Duration format: always ends with 's' (seconds)
if (!trimmed.endsWith("s")) {
return null
}

// Remove 's' suffix and parse as double
val secondsStr = trimmed.dropLast(1)
val seconds = secondsStr.toDoubleOrNull() ?: return null

seconds.seconds
} catch (e: Exception) {
null
}
}

@PublicPreviewAPI
@Serializable(LiveServerMessageSerializer::class)
internal sealed interface InternalLiveServerMessage {
Expand All @@ -202,6 +268,7 @@ internal object LiveServerMessageSerializer :
"toolCall" in jsonObject -> LiveServerToolCall.InternalWrapper.serializer()
"toolCallCancellation" in jsonObject ->
LiveServerToolCallCancellation.InternalWrapper.serializer()
"goAway" in jsonObject -> LiveServerGoAway.InternalWrapper.serializer()
else ->
throw SerializationException(
"Unknown LiveServerMessage response type. Keys found: ${jsonObject.keys}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
Expand Down Expand Up @@ -92,6 +93,21 @@ internal constructor(
*/
private var audioScope = CancelledCoroutineScope

/**
* Exception handler for unhandled exceptions in background coroutines.
*
* Logs the exception and attempts to clean up resources to prevent app crashes.
*/
private val exceptionHandler = CoroutineExceptionHandler { _, throwable ->
Log.e(TAG, "Unhandled exception in LiveSession", throwable)
// Clean up resources to prevent resource leaks
try {
stopAudioConversation()
} catch (e: Exception) {
Log.e(TAG, "Error during cleanup in exception handler", e)
}
}

/**
* Playback audio data sent from the model.
*
Expand All @@ -118,7 +134,12 @@ internal constructor(
public suspend fun startAudioConversation(
functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? = null
) {
startAudioConversation(functionCallHandler, false)
startAudioConversation(
functionCallHandler = functionCallHandler,
transcriptHandler = null,
goAwayHandler = null,
enableInterruptions = false
)
}

/**
Expand All @@ -143,6 +164,7 @@ internal constructor(
startAudioConversation(
functionCallHandler = functionCallHandler,
transcriptHandler = null,
goAwayHandler = null,
enableInterruptions = enableInterruptions
)
}
Expand All @@ -159,6 +181,10 @@ internal constructor(
* transcript. The first [Transcription] object is the input transcription, and the second is the
* output transcription.
*
* @param goAwayHandler A callback function that is invoked when the server initiates a disconnect
* via a [LiveServerGoAway] message. This allows the application to handle server-initiated
* session termination gracefully.
*
* @param enableInterruptions If enabled, allows the user to speak over or interrupt the model's
* ongoing reply.
*
Expand All @@ -169,12 +195,14 @@ internal constructor(
public suspend fun startAudioConversation(
functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? = null,
transcriptHandler: ((Transcription?, Transcription?) -> Unit)? = null,
goAwayHandler: ((LiveServerGoAway) -> Unit)? = null,
enableInterruptions: Boolean = false,
) {
startAudioConversation(
liveAudioConversationConfig {
this.functionCallHandler = functionCallHandler
this.transcriptHandler = transcriptHandler
this.goAwayHandler = goAwayHandler
this.enableInterruptions = enableInterruptions
}
)
Expand Down Expand Up @@ -209,14 +237,20 @@ internal constructor(
return@catchAsync
}
networkScope =
CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network"))
audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio"))
CoroutineScope(
blockingDispatcher + childJob() + CoroutineName("LiveSession Network") + exceptionHandler
)
audioScope =
CoroutineScope(
audioDispatcher + childJob() + CoroutineName("LiveSession Audio") + exceptionHandler
)
audioHelper = AudioHelper.build(liveAudioConversationConfig.initializationHandler)

recordUserAudio()
processModelResponses(
liveAudioConversationConfig.functionCallHandler,
liveAudioConversationConfig.transcriptHandler
liveAudioConversationConfig.transcriptHandler,
liveAudioConversationConfig.goAwayHandler
)
listenForModelPlayback(liveAudioConversationConfig.enableInterruptions)
}
Expand Down Expand Up @@ -272,9 +306,14 @@ internal constructor(
response
.getOrNull()
?.let {
JSON.decodeFromString<InternalLiveServerMessage>(
it.readBytes().toString(Charsets.UTF_8)
)
try {
JSON.decodeFromString<InternalLiveServerMessage>(
it.readBytes().toString(Charsets.UTF_8)
)
} catch (e: SerializationException) {
Log.w(TAG, "Failed to deserialize server message: ${e.message}")
null // Skip unknown messages instead of crashing
}
}
?.let { emit(it.toPublic()) }
// delay uses a different scheduler in the backend, so it's "stickier" in its
Expand Down Expand Up @@ -481,10 +520,15 @@ internal constructor(
*
* @param functionCallHandler A callback function that is invoked whenever the server receives a
* function call.
* @param transcriptHandler A callback function that is invoked whenever the server receives a
* transcript.
* @param goAwayHandler A callback function that is invoked when the server initiates a
* disconnect.
*/
private fun processModelResponses(
functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)?,
transcriptHandler: ((Transcription?, Transcription?) -> Unit)?
transcriptHandler: ((Transcription?, Transcription?) -> Unit)?,
goAwayHandler: ((LiveServerGoAway) -> Unit)?
) {
receive()
.onEach {
Expand Down Expand Up @@ -532,6 +576,16 @@ internal constructor(
"The model sent LiveServerSetupComplete after the connection was established."
)
}
is LiveServerGoAway -> {
val timeLeftMsg = it.timeLeft?.let { duration -> " (time left: $duration)" } ?: ""
Log.i(TAG, "Server initiated disconnect$timeLeftMsg")

// Notify the application
goAwayHandler?.invoke(it)

// Close the session gracefully
close()
}
}
}
.launchIn(networkScope)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.google.firebase.ai.type.GroundingChunk
import com.google.firebase.ai.type.GroundingMetadata
import com.google.firebase.ai.type.GroundingSupport
import com.google.firebase.ai.type.ImagenReferenceImage
import com.google.firebase.ai.type.LiveServerGoAway
import com.google.firebase.ai.type.ModalityTokenCount
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.Schema
Expand Down Expand Up @@ -593,4 +594,23 @@ internal class SerializationTests {
val actualJson = descriptorToJson(UrlContext.Internal.serializer().descriptor)
expectedJsonAsString shouldEqualJson actualJson.toString()
}

@Test
fun `test LiveServerGoAway serialization as Json`() {
val expectedJsonAsString =
"""
{
"id": "LiveServerGoAway",
"type": "object",
"properties": {
"timeLeft": {
"type": "string"
}
}
}
"""
.trimIndent()
val actualJson = descriptorToJson(LiveServerGoAway.Internal.serializer().descriptor)
expectedJsonAsString shouldEqualJson actualJson.toString()
}
}
Loading