73 changes: 49 additions & 24 deletions sdk/voice/speechmatics/voice/_client.py
@@ -335,6 +335,7 @@ def __init__(
self._session_speakers: dict[str, SessionSpeaker] = {}
self._is_speaking: bool = False
self._current_speaker: Optional[str] = None
self._last_valid_partial_word_count: int = 0
self._dz_enabled: bool = self._config.enable_diarization
self._dz_config = self._config.speaker_config
self._last_speak_start_time: Optional[float] = None
@@ -454,7 +455,7 @@ def _prepare_config(
)

# Punctuation overrides
if config.punctuation_overrides:
if config.punctuation_overrides is not None:
transcription_config.punctuation_overrides = config.punctuation_overrides

# Configure the audio
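
As an aside on the `is not None` change above, here is a minimal sketch (the config value is illustrative, not the SDK's own classes) of why a truthiness check can drop an explicitly provided override:

# Hypothetical illustration: a falsy but explicitly provided value (such as an
# empty dict) is skipped by a truthiness check but forwarded by `is not None`.
overrides: dict = {}

if overrides:              # old check: {} is silently dropped
    print("forwarded by truthiness check")

if overrides is not None:  # new check: {} is still forwarded
    print("forwarded by `is not None` check")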
@@ -580,7 +581,7 @@ async def disconnect(self) -> None:
self._closing_session = True

# Emit final segments
await self._emit_segments(finalize=True)
await self._emit_segments(finalize=True, is_eou=True)

# Emit final metrics
self._emit_speaker_metrics()
@@ -747,7 +748,7 @@ async def emit() -> None:
return

# Emit the segments
self._stt_message_queue.put_nowait(lambda: self._emit_segments(finalize=True))
self._stt_message_queue.put_nowait(lambda: self._emit_segments(finalize=True, is_eou=True))

# Call async task (only if not already waiting for forced EOU)
if not (self._config.end_of_turn_config.use_forced_eou and self._forced_eou_active):
@@ -1122,8 +1123,7 @@ async def _add_speech_fragments(self, message: dict[str, Any], is_final: bool =
self._last_fragment_end_time = max(self._last_fragment_end_time, fragment.end_time)

# Evaluate for VAD (only done on partials)
if not is_final:
await self._vad_evaluation(fragments)
await self._vad_evaluation(fragments, is_final=is_final)

# Fragments to retain
retained_fragments = [
@@ -1234,7 +1234,7 @@ async def fn() -> None:
# Emit the segments
await self._emit_segments()

async def _emit_segments(self, finalize: bool = False) -> None:
async def _emit_segments(self, finalize: bool = False, is_eou: bool = False) -> None:
"""Emit segments to listeners.

This function will emit segments in the view without any further checks
@@ -1243,6 +1243,7 @@ async def _emit_segments(self, finalize: bool = False) -> None:

Args:
finalize: Whether to finalize all segments.
is_eou: Whether the segments are being emitted after an end of utterance.
"""

# Only process if we have segments in the buffer
@@ -1313,6 +1314,10 @@ async def _emit_segments(self, finalize: bool = False) -> None:
segment=last_segment,
)

# Mark the last segment as the end of the utterance
if is_eou:
final_segments[-1].is_eou = True

# Emit segments
self._emit_message(
SegmentMessage(
@@ -1325,6 +1330,7 @@ async def _emit_segments(self, finalize: bool = False) -> None:
language=s.language,
text=s.text,
annotation=s.annotation,
is_eou=s.is_eou,
fragments=(
[SegmentMessageSegmentFragment(**f.__dict__) for f in s.fragments]
if self._config.include_results
@@ -1698,52 +1704,71 @@ async def _await_forced_eou(self, timeout: float = 1.0) -> None:
# VAD (VOICE ACTIVITY DETECTION) / SPEAKER DETECTION
# ============================================================================

async def _vad_evaluation(self, fragments: list[SpeechFragment]) -> None:
async def _vad_evaluation(self, fragments: list[SpeechFragment], is_final: bool) -> None:
"""Emit a VAD event.

This will emit `SPEAKER_STARTED` and `SPEAKER_ENDED` events to the client and is
based on valid transcription for active speakers. Ignored speakers, or speakers not in
focus, will not be considered active participants.

This runs on both partial and final words; finals are used to catch forced end of
utterance cases where partials may be skipped.

Args:
fragments: The list of fragments to use for evaluation.
is_final: Whether the fragments are final.
"""

# Find the valid list of partial words
# Filter fragments for valid speakers, if required
if self._dz_enabled and self._dz_config.focus_speakers:
new_partials = [
frag
for frag in fragments
if frag.speaker in self._dz_config.focus_speakers and frag.type_ == "word" and not frag.is_final
]
else:
new_partials = [frag for frag in fragments if frag.type_ == "word" and not frag.is_final]
fragments = [f for f in fragments if f.speaker in self._dz_config.focus_speakers]

# Find partial and final words
words = [f for f in fragments if f.type_ == "word"]

# Check if we have any new words
has_words = len(words) > 0

# Handle finals
if is_final:
"""Check for finals without partials.

When a forced end of utterance is used, the transcription may skip partials
and go straight to finals. In this case, we check whether any partials were seen
previously; if not, we assume a new speaker has started.
"""

# Check if transcript went straight to finals (typical with forced end of utterance)
if not self._is_speaking and has_words and self._last_valid_partial_word_count == 0:
# Track the current speaker
self._current_speaker = words[0].speaker
self._is_speaking = True

# Emit speaker started event
await self._handle_speaker_started(self._current_speaker, words[0].start_time)

# No further processing needed
return

# Check if we have new partials
has_valid_partial = len(new_partials) > 0
# Track partial count
self._last_valid_partial_word_count = len(words)

# Current states
current_is_speaking = self._is_speaking
current_speaker = self._current_speaker

# Establish the speaker from latest partials
latest_speaker = new_partials[-1].speaker if has_valid_partial else current_speaker
latest_speaker = words[-1].speaker if has_words else current_speaker

# Determine if the speaker has changed (and we have a speaker)
speaker_changed = latest_speaker != current_speaker and current_speaker is not None

# Start / end times (earliest and latest)
speaker_start_time = new_partials[0].start_time if has_valid_partial else None
speaker_start_time = words[0].start_time if has_words else None
speaker_end_time = self._last_fragment_end_time

# If diarization is enabled, indicate speaker switching
if self._dz_enabled and latest_speaker is not None:
"""When enabled, we send a speech events if the speaker has changed.

This
will emit a SPEAKER_ENDED for the previous speaker and a SPEAKER_STARTED
This will emit a SPEAKER_ENDED for the previous speaker and a SPEAKER_STARTED
for the new speaker.

For any client that wishes to show _which_ speaker is speaking, this will
@@ -1774,7 +1799,7 @@ async def _vad_evaluation(self, fragments: list[SpeechFragment]) -> None:
self._current_speaker = latest_speaker

# No further processing if the speaking state is unchanged (words while already speaking, or no words while silent)
if has_valid_partial == current_is_speaking:
if has_words == current_is_speaking:
return

# Update speaking state
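To make the new `is_final` branch easier to follow, below is a minimal, self-contained sketch of the evaluation flow; the `Word` stand-in, the `VadSketch` class, and the string events are assumptions for illustration, not the SDK's actual API:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Word:
    # Minimal stand-in for a SpeechFragment word (assumed fields only)
    speaker: str
    start_time: float


class VadSketch:
    # Simplified restatement of the flow in _vad_evaluation

    def __init__(self) -> None:
        self.is_speaking = False
        self.current_speaker: Optional[str] = None
        self.last_valid_partial_word_count = 0
        self.events: list[str] = []

    def evaluate(self, words: list[Word], is_final: bool) -> None:
        if is_final:
            # Forced end of utterance may skip partials and go straight to
            # finals; if no partials were seen, assume a new speaker started.
            if not self.is_speaking and words and self.last_valid_partial_word_count == 0:
                self.current_speaker = words[0].speaker
                self.is_speaking = True
                self.events.append(f"SPEAKER_STARTED:{self.current_speaker}")
            return

        # Partials: remember how many valid words were seen for the next final
        self.last_valid_partial_word_count = len(words)

        latest = words[-1].speaker if words else self.current_speaker
        if latest != self.current_speaker and self.current_speaker is not None:
            # Speaker switch: close the previous speaker, open the new one
            self.events.append(f"SPEAKER_ENDED:{self.current_speaker}")
            self.events.append(f"SPEAKER_STARTED:{latest}")
        self.current_speaker = latest

        # No state change: words while already speaking, or silence while idle
        if bool(words) == self.is_speaking:
            return
        self.is_speaking = bool(words)
        event = "SPEAKER_STARTED" if self.is_speaking else "SPEAKER_ENDED"
        self.events.append(f"{event}:{latest}")


vad = VadSketch()
vad.evaluate([Word("S1", 0.2)], is_final=True)  # finals arrive with no prior partials
print(vad.events)  # ['SPEAKER_STARTED:S1']
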
5 changes: 5 additions & 0 deletions sdk/voice/speechmatics/voice/_models.py
@@ -940,6 +940,7 @@ class SpeakerSegment(BaseModel):
fragments: The list of SpeechFragment items.
text: The text of the segment.
annotation: The annotation associated with the segment.
is_eou: Whether the segment is the end of an utterance. Defaults to `False`.
"""

speaker_id: Optional[str] = None
@@ -949,6 +950,7 @@ class SpeakerSegment(BaseModel):
fragments: list[SpeechFragment] = Field(default_factory=list)
text: Optional[str] = None
annotation: AnnotationResult = Field(default_factory=AnnotationResult)
is_eou: bool = False

model_config = ConfigDict(use_enum_values=True, arbitrary_types_allowed=True)

@@ -1313,6 +1315,8 @@ class SegmentMessageSegment(BaseModel):
language: The language of the frame.
text: The text of the segment.
fragments: The fragments associated with the segment.
annotation: The annotation associated with the segment (optional).
is_eou: Whether the segment is an end of utterance.
metadata: The metadata associated with the segment.
"""

Expand All @@ -1323,6 +1327,7 @@ class SegmentMessageSegment(BaseModel):
text: Optional[str] = None
fragments: Optional[list[SegmentMessageSegmentFragment]] = None
annotation: list[AnnotationFlags] = Field(default_factory=list, exclude=False)
is_eou: bool = False
metadata: MessageTimeMetadata

model_config = ConfigDict(extra="ignore")
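
Finally, a short hedged sketch of how a downstream listener might consume the new `is_eou` flag on emitted segments; the `Seg` stand-in and the handler are assumptions, and only the `speaker_id`, `text`, and `is_eou` fields mirror the models in this diff:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Seg:
    # Stand-in for SegmentMessageSegment; the real model carries more fields
    speaker_id: Optional[str]
    text: Optional[str]
    is_eou: bool = False


def on_segments(segments: list[Seg]) -> None:
    for segment in segments:
        print(f"[{segment.speaker_id}] {segment.text}")
        if segment.is_eou:
            # Only the last segment emitted at an end of utterance carries
            # is_eou=True, so turn-taking logic can trigger here.
            print("-- end of utterance --")


on_segments([Seg("S1", "Hello there."), Seg("S1", "How are you?", is_eou=True)])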