Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 110 additions & 28 deletions tests/scenario_tests/test_events_assistant.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from time import sleep
import time

from slack_sdk.web import WebClient

Expand All @@ -10,6 +10,13 @@
from tests.utils import remove_os_env_temporarily, restore_os_env


def assert_target_called(called: dict, timeout: float = 0.5):
deadline = time.time() + timeout
while called["value"] is not True and time.time() < deadline:
time.sleep(0.1)
assert called["value"] is True


class TestEventsAssistant:
valid_token = "xoxb-valid"
mock_api_server_base_url = "http://localhost:8888"
Expand All @@ -26,81 +33,156 @@ def teardown_method(self):
cleanup_mock_web_api_server(self)
restore_os_env(self.old_os_env)

def test_assistant_threads(self):
def test_thread_started(self):
app = App(client=self.web_client)
assistant = Assistant()

state = {"called": False}

def assert_target_called():
count = 0
while state["called"] is False and count < 20:
sleep(0.1)
count += 1
assert state["called"] is True
state["called"] = False
called = {"value": False}

@assistant.thread_started
def start_thread(say: Say, set_suggested_prompts: SetSuggestedPrompts, context: BoltContext):
def start_thread(say: Say, set_suggested_prompts: SetSuggestedPrompts, set_status: SetStatus, context: BoltContext):
assert context.channel_id == "D111"
assert context.thread_ts == "1726133698.626339"
assert set_status.thread_ts == context.thread_ts
assert say.thread_ts == context.thread_ts
say("Hi, how can I help you today?")
set_suggested_prompts(prompts=[{"title": "What does SLACK stand for?", "message": "What does SLACK stand for?"}])
set_suggested_prompts(
prompts=[{"title": "What does SLACK stand for?", "message": "What does SLACK stand for?"}], title="foo"
)
state["called"] = True
called["value"] = True

app.assistant(assistant)

request = BoltRequest(body=thread_started_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert_target_called(called)

def test_thread_context_changed(self):
app = App(client=self.web_client)
assistant = Assistant()
called = {"value": False}

@assistant.thread_context_changed
def handle_thread_context_changed(context: BoltContext):
assert context.channel_id == "D111"
assert context.thread_ts == "1726133698.626339"
state["called"] = True
called["value"] = True

app.assistant(assistant)

request = BoltRequest(body=thread_context_changed_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert_target_called(called)

def test_user_message(self):
app = App(client=self.web_client)
assistant = Assistant()
called = {"value": False}

@assistant.user_message
def handle_user_message(say: Say, set_status: SetStatus, context: BoltContext):
assert context.channel_id == "D111"
assert context.thread_ts == "1726133698.626339"
assert say.thread_ts == context.thread_ts
try:
set_status("is typing...")
say("Here you are!")
state["called"] = True
called["value"] = True
except Exception as e:
say(f"Oops, something went wrong (error: {e}")
say(f"Oops, something went wrong (error: {e})")

app.assistant(assistant)

request = BoltRequest(body=thread_started_event_body, mode="socket_mode")
request = BoltRequest(body=user_message_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert_target_called()
assert_target_called(called)

request = BoltRequest(body=thread_context_changed_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert_target_called()
def test_user_message_with_assistant_thread(self):
app = App(client=self.web_client)
assistant = Assistant()
called = {"value": False}

request = BoltRequest(body=user_message_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert_target_called()
@assistant.user_message
def handle_user_message(say: Say, set_status: SetStatus, context: BoltContext):
assert context.channel_id == "D111"
assert context.thread_ts == "1726133698.626339"
assert say.thread_ts == context.thread_ts
try:
set_status("is typing...")
say("Here you are!")
called["value"] = True
except Exception as e:
say(f"Oops, something went wrong (error: {e})")

app.assistant(assistant)

request = BoltRequest(body=user_message_event_body_with_assistant_thread, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert_target_called()
assert_target_called(called)

def test_message_changed(self):
app = App(client=self.web_client)
assistant = Assistant()
called = {"value": False}

@assistant.user_message
def handle_user_message():
called["value"] = True

@assistant.bot_message
def handle_bot_message():
called["value"] = True

app.assistant(assistant)

request = BoltRequest(body=message_changed_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert called["value"] is False

def test_channel_user_message_ignored(self):
app = App(client=self.web_client)
assistant = Assistant()
called = {"value": False}

@assistant.user_message
def handle_user_message():
called["value"] = True

@assistant.bot_message
def handle_bot_message():
called["value"] = True

app.assistant(assistant)

request = BoltRequest(body=channel_user_message_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 404
assert called["value"] is False

def test_channel_message_changed_ignored(self):
app = App(client=self.web_client)
assistant = Assistant()
called = {"value": False}

@assistant.user_message
def handle_user_message():
called["value"] = True

@assistant.bot_message
def handle_bot_message():
called["value"] = True

app.assistant(assistant)

request = BoltRequest(body=channel_message_changed_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 404
assert called["value"] is False


def build_payload(event: dict) -> dict:
Expand Down
Loading
Loading