Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,20 @@ def begin_task(
task: WorkerTask, pass_through_headers: Mapping[str, str] | None = None
) -> WorkerTask:
"""Trigger a task. Will fail if the worker is busy"""
if nt := context().numtracker:

active_worker = worker()
active_context = context()
if nt := active_context.numtracker:
nt.set_headers(pass_through_headers or {})

if tiled_config := context().tiled_conf:
if tiled_config := active_context.tiled_conf:
# Tiled queries the root node, so must create an authorized client
tiled_client = from_uri(
str(tiled_config.url),
api_key=tiled_config.api_key,
headers=pass_through_headers,
)
tiled_writer_token = context().run_engine.subscribe(
tiled_writer_token = active_context.run_engine.subscribe(
TiledWriter(tiled_client, batch_size=1)
)

Expand All @@ -195,12 +198,15 @@ def remove_callback_when_task_finished(
and event.task_status.task_id == task.task_id
and event.task_status.task_complete
):
context().run_engine.unsubscribe(tiled_writer_token)
active_context.run_engine.unsubscribe(tiled_writer_token)
active_worker.worker_events.unsubscribe(remove_callback)

worker().worker_events.subscribe(remove_callback_when_task_finished)
remove_callback = active_worker.worker_events.subscribe(
remove_callback_when_task_finished
)

if task.task_id is not None:
worker().begin_task(task.task_id)
active_worker.begin_task(task.task_id)
return task


Expand Down
51 changes: 50 additions & 1 deletion tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
PlanSource,
ScratchConfig,
StompConfig,
TiledConfig,
)
from blueapi.core.context import BlueskyContext
from blueapi.service import interface
Expand All @@ -42,7 +43,7 @@
)
from blueapi.utils.invalid_config_error import InvalidConfigError
from blueapi.utils.path_provider import StartDocumentPathProvider
from blueapi.worker.event import TaskStatusEnum, WorkerState
from blueapi.worker.event import TaskStatus, TaskStatusEnum, WorkerEvent, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.task_worker import TrackableTask

Expand Down Expand Up @@ -365,6 +366,54 @@ def test_get_task_by_id(
)


@patch("blueapi.service.interface.TiledWriter")
@patch("blueapi.service.interface.from_uri")
@patch("blueapi.service.interface.context")
@patch("blueapi.service.interface.worker")
def test_remove_tiled_subscriber(worker, context, from_uri, writer):
task = WorkerTask(task_id="foo_bar")
context().numtracker = None
context().tiled_conf = TiledConfig()
context().run_engine.subscribe.return_value = 17
worker().worker_events.subscribe.return_value = 42

interface.begin_task(task)

writer.assert_called_once_with(from_uri(), batch_size=1)
context().run_engine.subscribe.assert_called_once_with(writer())
worker().worker_events.subscribe.assert_called_once()

inner_callback = worker().worker_events.subscribe.call_args.args[0]

inner_callback(
WorkerEvent(
state=WorkerState.RUNNING,
task_status=TaskStatus(
task_id="foo_bar",
task_complete=False,
task_failed=False,
),
),
"c_id",
)
context().run_engine.unsubscribe.assert_not_called()
worker().worker_events.unsubscribe.assert_not_called()

inner_callback(
WorkerEvent(
state=WorkerState.IDLE,
task_status=TaskStatus(
task_id="foo_bar",
task_complete=True,
task_failed=False,
),
),
"c_id",
)
context().run_engine.unsubscribe.assert_called_once_with(17)
worker().worker_events.unsubscribe.assert_called_once_with(42)


def test_get_oidc_config(oidc_config: OIDCConfig):
interface.set_config(ApplicationConfig(oidc=oidc_config))
assert interface.get_oidc_config() == oidc_config
Expand Down