Skip to content
This repository was archived by the owner on Sep 3, 2025. It is now read-only.
Merged
16 changes: 15 additions & 1 deletion src/dispatch/conversation/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
from .models import Conversation, ConversationCreate, ConversationUpdate

log = logging.getLogger(__name__)


def get(*, db_session, conversation_id: int) -> Conversation | None:
"""Gets a conversation by its id."""
Expand All @@ -26,7 +29,18 @@ def get_by_channel_id_ignoring_channel_type(
conversation = conversations.filter(Conversation.thread_id == thread_id).one_or_none()

if not conversation:
conversation = conversations.one_or_none()
# No conversations with that thread_id, check all conversations without thread filter
conversation_count = conversations.count()
if conversation_count > 1:
# this happens when a user posts in the main thread of a triage channel since
# there are multiple cases in the channel with that channel_id
# so we log a warning and return None
log.warning(
f"Multiple conversations found for channel_id: {channel_id}, thread_id: {thread_id}"
)
conversation = None
else:
conversation = conversations.one_or_none()

if conversation:
if channel_id[0] != conversation.channel_id[0]:
Expand Down
158 changes: 158 additions & 0 deletions tests/conversation/test_conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,161 @@ def test_delete(session, conversation):

delete(db_session=session, conversation_id=conversation.id)
assert not get(db_session=session, conversation_id=conversation.id)


def test_get_by_channel_id_ignoring_channel_type_no_conversations(session):
"""Test when no conversations match the channel_id."""
from dispatch.conversation.service import get_by_channel_id_ignoring_channel_type

result = get_by_channel_id_ignoring_channel_type(
db_session=session, channel_id="nonexistent_channel"
)
assert result is None


def test_get_by_channel_id_ignoring_channel_type_single_conversation(session):
"""Test when exactly one conversation matches the channel_id."""
from dispatch.conversation.service import create, get_by_channel_id_ignoring_channel_type
from dispatch.conversation.models import ConversationCreate

# Create a single conversation
conversation_in = ConversationCreate(
channel_id="test_channel",
resource_id="test_resource",
resource_type="test_type",
weblink="https://example.com/",
)
created_conversation = create(db_session=session, conversation_in=conversation_in)

# Test retrieval
result = get_by_channel_id_ignoring_channel_type(db_session=session, channel_id="test_channel")

assert result is not None
assert result.id == created_conversation.id
assert result.channel_id == "test_channel"


def test_get_by_channel_id_ignoring_channel_type_multiple_conversations_warning(session, caplog):
"""Test when multiple conversations match the channel_id with thread_id=None - should return first conversation."""
from dispatch.conversation.service import create, get_by_channel_id_ignoring_channel_type
from dispatch.conversation.models import ConversationCreate

# Create multiple conversations with the same channel_id
conversation_in_1 = ConversationCreate(
channel_id="duplicate_channel",
resource_id="resource_1",
resource_type="test_type",
weblink="https://example1.com/",
)
conversation_in_2 = ConversationCreate(
channel_id="duplicate_channel",
resource_id="resource_2",
resource_type="test_type",
weblink="https://example2.com/",
)

created_1 = create(db_session=session, conversation_in=conversation_in_1)
create(db_session=session, conversation_in=conversation_in_2)

# Test retrieval - should return first conversation when thread_id is None
result = get_by_channel_id_ignoring_channel_type(
db_session=session, channel_id="duplicate_channel"
)

# When thread_id is None, it uses .first() so should return the first conversation
assert result is not None
assert result.id == created_1.id


def test_get_by_channel_id_ignoring_channel_type_multiple_conversations_fallback_warning(
session, caplog
):
"""Test when multiple conversations match in fallback logic - should log warning and return None."""
from dispatch.conversation.service import create, get_by_channel_id_ignoring_channel_type
from dispatch.conversation.models import ConversationCreate

# Create multiple conversations with the same channel_id but no thread_id
conversation_in_1 = ConversationCreate(
channel_id="fallback_channel",
resource_id="resource_1",
resource_type="test_type",
weblink="https://example1.com/",
)
conversation_in_2 = ConversationCreate(
channel_id="fallback_channel",
resource_id="resource_2",
resource_type="test_type",
weblink="https://example2.com/",
)

create(db_session=session, conversation_in=conversation_in_1)
create(db_session=session, conversation_in=conversation_in_2)

# Test retrieval with a thread_id that doesn't exist - should trigger fallback logic
result = get_by_channel_id_ignoring_channel_type(
db_session=session, channel_id="fallback_channel", thread_id="nonexistent_thread"
)

# Should return None and log warning in fallback logic
assert result is None
assert (
"Multiple conversations found for channel_id: fallback_channel, thread_id: nonexistent_thread"
in caplog.text
)


def test_get_by_channel_id_ignoring_channel_type_with_thread_id(session):
"""Test the thread_id matching logic."""
from dispatch.conversation.service import create, get_by_channel_id_ignoring_channel_type
from dispatch.conversation.models import ConversationCreate

# Create conversations with different thread_ids
conversation_in_1 = ConversationCreate(
channel_id="thread_channel",
thread_id="thread_123",
resource_id="resource_1",
resource_type="test_type",
weblink="https://example1.com/",
)
conversation_in_2 = ConversationCreate(
channel_id="thread_channel",
thread_id="thread_456",
resource_id="resource_2",
resource_type="test_type",
weblink="https://example2.com/",
)

created_1 = create(db_session=session, conversation_in=conversation_in_1)
create(db_session=session, conversation_in=conversation_in_2)

# Test retrieval with specific thread_id
result = get_by_channel_id_ignoring_channel_type(
db_session=session, channel_id="thread_channel", thread_id="thread_123"
)

assert result is not None
assert result.id == created_1.id
assert result.thread_id == "thread_123"


def test_get_by_channel_id_ignoring_channel_type_incident_message_fallback(session):
"""Test the incident message fallback logic (no thread_id provided)."""
from dispatch.conversation.service import create, get_by_channel_id_ignoring_channel_type
from dispatch.conversation.models import ConversationCreate

# Create a conversation without thread_id (incident message)
conversation_in = ConversationCreate(
channel_id="incident_channel",
resource_id="incident_resource",
resource_type="test_type",
weblink="https://example.com/",
)
created_conversation = create(db_session=session, conversation_in=conversation_in)

# Test retrieval without thread_id - should use .first()
result = get_by_channel_id_ignoring_channel_type(
db_session=session, channel_id="incident_channel"
)

assert result is not None
assert result.id == created_conversation.id
Loading