Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from dataclasses import InitVar, dataclass, field
from importlib import import_module
from inspect import Parameter, signature
from types import ModuleType, NoneType, UnionType
from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints

from bluesky.protocols import HasName
from bluesky.run_engine import RunEngine
from dodal.common.beamlines.beamline_utils import get_path_provider, set_path_provider
from dodal.utils import make_all_devices
from ophyd_async.core import NotConnected
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, create_model
Expand All @@ -16,12 +17,15 @@
from pydantic_core import CoreSchema, core_schema

from blueapi import utils
from blueapi.config import EnvironmentConfig, SourceKind
from blueapi.client.numtracker import NumtrackerClient
from blueapi.config import ApplicationConfig, EnvironmentConfig, SourceKind
from blueapi.utils import (
BlueapiPlanModelConfig,
is_function_sourced_from_module,
load_module_all,
)
from blueapi.utils.invalid_config_error import InvalidConfigError
from blueapi.utils.path_provider import StartDocumentPathProvider

from .bluesky_types import (
BLUESKY_PROTOCOLS,
Expand Down Expand Up @@ -86,15 +90,57 @@ class BlueskyContext:
The context holds the RunEngine and any plans/devices that you may want to use.
"""

configuration: InitVar[ApplicationConfig | None] = None

run_engine: RunEngine = field(
default_factory=lambda: RunEngine(context_managers=[])
)
numtracker: NumtrackerClient | None = field(default=None, init=False, repr=False)
plans: dict[str, Plan] = field(default_factory=dict)
devices: dict[str, Device] = field(default_factory=dict)
plan_functions: dict[str, PlanGenerator] = field(default_factory=dict)

_reference_cache: dict[type, type] = field(default_factory=dict)

def __post_init__(self, configuration: ApplicationConfig | None):
if not configuration:
return

if configuration.numtracker is not None:
if configuration.env.metadata is not None:
self.numtracker = NumtrackerClient(url=configuration.numtracker.url)
else:
raise InvalidConfigError(
"Numtracker url has been configured, but there is no instrument or"
" instrument_session in the environment metadata"
)

if self.numtracker is not None:
numtracker = self.numtracker

path_provider = StartDocumentPathProvider()
set_path_provider(path_provider)
self.run_engine.subscribe(path_provider.update_run, "start")

def _update_scan_num(md: dict[str, Any]) -> int:
scan = numtracker.create_scan(
md["instrument_session"], md["instrument"]
)
md["data_session_directory"] = str(scan.scan.directory.path)
md["scan_file"] = scan.scan.scan_file
return scan.scan.scan_number

self.run_engine.scan_id_source = _update_scan_num

self.with_config(configuration.env)
if self.numtracker and not isinstance(
get_path_provider(), StartDocumentPathProvider
):
raise InvalidConfigError(
"Numtracker has been configured but a path provider was imported with "
"the devices. Remove this path provider to use numtracker."
)

def find_device(self, addr: str | list[str]) -> Device | None:
"""
Find a device in this context, allows for recursive search.
Expand Down
76 changes: 3 additions & 73 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,8 @@

from bluesky_stomp.messaging import StompClient
from bluesky_stomp.models import Broker, DestinationBase, MessageTopic
from dodal.common.beamlines.beamline_utils import (
get_path_provider,
set_path_provider,
)

from blueapi.cli.scratch import get_python_environment
from blueapi.client.numtracker import NumtrackerClient
from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig
from blueapi.core.context import BlueskyContext
from blueapi.core.event import EventStream
Expand All @@ -23,8 +18,6 @@
TaskRequest,
WorkerTask,
)
from blueapi.utils.invalid_config_error import InvalidConfigError
from blueapi.utils.path_provider import StartDocumentPathProvider
from blueapi.worker.event import TaskStatusEnum, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.task_worker import TaskWorker, TrackableTask
Expand All @@ -48,14 +41,10 @@ def set_config(new_config: ApplicationConfig):

@cache
def context() -> BlueskyContext:
ctx = BlueskyContext()
ctx = BlueskyContext(config())
return ctx


def configure_context() -> None:
context().with_config(config().env)


@cache
def worker() -> TaskWorker:
worker = TaskWorker(
Expand Down Expand Up @@ -96,76 +85,23 @@ def stomp_client() -> StompClient | None:
return None


@cache
def numtracker_client() -> NumtrackerClient | None:
conf = config()
if conf.numtracker is not None:
if conf.env.metadata is not None:
return NumtrackerClient(url=conf.numtracker.url)
else:
raise InvalidConfigError(
"Numtracker url has been configured, but there is no instrument or"
" instrument_session in the environment metadata"
)
else:
return None


def _update_scan_num(md: dict[str, Any]) -> int:
numtracker = numtracker_client()
if numtracker is not None:
scan = numtracker.create_scan(md["instrument_session"], md["instrument"])
md["data_session_directory"] = str(scan.scan.directory.path)
md["scan_file"] = scan.scan.scan_file
return scan.scan.scan_number
else:
raise InvalidConfigError(
"Blueapi was configured to talk to numtracker but numtracker is not"
"configured, this should not happen, please contact the DAQ team"
)


def setup(config: ApplicationConfig) -> None:
"""Creates and starts a worker with supplied config"""
set_config(config)
set_up_logging(config.logging)

# Eagerly initialize worker and messaging connection
worker()

# if numtracker is configured, use a StartDocumentPathProvider
if numtracker_client() is not None:
context().run_engine.scan_id_source = _update_scan_num
_hook_run_engine_and_path_provider()

configure_context()

if numtracker_client() is not None and not isinstance(
get_path_provider(), StartDocumentPathProvider
):
raise InvalidConfigError(
"Numtracker has been configured but a path provider was imported"
" with the devices. Remove this path provider to use numtracker."
)

stomp_client()


def _hook_run_engine_and_path_provider() -> None:
path_provider = StartDocumentPathProvider()
set_path_provider(path_provider)
run_engine = context().run_engine
run_engine.subscribe(path_provider.update_run, "start")


def teardown() -> None:
worker().stop()
if (stomp_client_ref := stomp_client()) is not None:
stomp_client_ref.disconnect()
context.cache_clear()
worker.cache_clear()
stomp_client.cache_clear()
numtracker_client.cache_clear()


def _publish_event_streams(
Expand Down Expand Up @@ -224,19 +160,13 @@ def begin_task(
task: WorkerTask, pass_through_headers: Mapping[str, str] | None = None
) -> WorkerTask:
"""Trigger a task. Will fail if the worker is busy"""
_try_configure_numtracker(pass_through_headers or {})

if nt := context().numtracker:
nt.set_headers(pass_through_headers or {})
if task.task_id is not None:
worker().begin_task(task.task_id)
return task


def _try_configure_numtracker(pass_through_headers: Mapping[str, str]) -> None:
numtracker = numtracker_client()
if numtracker is not None:
numtracker.set_headers(pass_through_headers)


def get_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]:
"""Retrieve a list of tasks based on their status."""
return worker().get_tasks_by_status(status)
Expand Down
65 changes: 35 additions & 30 deletions tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,15 @@ def mock_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]:
assert interface.get_tasks_by_status(TaskStatusEnum.COMPLETE) == []


@patch("blueapi.service.interface._try_configure_numtracker")
@patch("blueapi.service.interface.BlueskyContext.numtracker")
@patch("blueapi.service.interface.TaskWorker.begin_task")
def test_begin_task_with_headers(worker_mock: MagicMock, mock_configure: MagicMock):
def test_begin_task_with_headers(worker_mock: MagicMock, mock_numtracker: MagicMock):
uuid_value = "350043fd-597e-41a7-9a92-5d5478232cf7"
task = WorkerTask(task_id=uuid_value)
headers = {"a": "b"}

returned_task = interface.begin_task(task, headers)
mock_configure.assert_called_once_with(headers)
mock_numtracker.set_headers.assert_called_once_with(headers)

assert task == returned_task
worker_mock.assert_called_once_with(uuid_value)
Expand Down Expand Up @@ -406,10 +406,10 @@ def test_configure_numtracker():
)
interface.set_config(conf)
headers = {"a": "b"}
interface._try_configure_numtracker(headers)
nt = interface.numtracker_client()
nt = interface.context().numtracker

assert isinstance(nt, NumtrackerClient)
nt.set_headers(headers)
assert nt._headers == {"a": "b"}
assert nt._url.unicode_string() == "https://numtracker-example.com/graphql"

Expand Down Expand Up @@ -443,37 +443,36 @@ def test_headers_are_cleared(mock_post):
headers = {"foo": "bar"}

interface.begin_task(task=WorkerTask(task_id=None), pass_through_headers=headers)
interface._update_scan_num({"instrument_session": "cm12345-1", "instrument": "p46"})
ctx = interface.context()
assert ctx.run_engine.scan_id_source is not None
ctx.run_engine.scan_id_source(
{"instrument_session": "cm12345-1", "instrument": "p46"}
)
mock_post.assert_called_once()
assert mock_post.call_args.kwargs["headers"] == headers

interface.begin_task(task=WorkerTask(task_id=None))
interface._update_scan_num({"instrument_session": "cm12345-1", "instrument": "p46"})
ctx.run_engine.scan_id_source(
{"instrument_session": "cm12345-1", "instrument": "p46"}
)
assert mock_post.call_count == 2
assert mock_post.call_args.kwargs["headers"] == {}


def test_configure_numtracker_with_no_numtracker_config_fails():
def test_numtracker_requires_instrument_metadata():
conf = ApplicationConfig(
env=EnvironmentConfig(metadata=MetadataConfig(instrument="p46")),
numtracker=NumtrackerConfig(
url=HttpUrl("https://numtracker-example.com/graphql"),
)
)
interface.set_config(conf)
headers = {"a": "b"}
interface._try_configure_numtracker(headers)
nt = interface.numtracker_client()

assert nt is None


def test_configure_numtracker_with_no_metadata_fails():
conf = ApplicationConfig(numtracker=NumtrackerConfig())
interface.set_config(conf)
headers = {"a": "b"}

assert conf.env.metadata is None

print("Post config")
with pytest.raises(InvalidConfigError):
interface._try_configure_numtracker(headers)
interface.context()

# Clearing the config here prevents the same exception as above being
# raised in the ensure_worker_stopped fixture
interface.set_config(ApplicationConfig())


def test_setup_without_numtracker_with_existing_provider_does_not_overwrite_provider():
Expand Down Expand Up @@ -506,7 +505,6 @@ def test_setup_with_numtracker_makes_start_document_provider():
path_provider = get_path_provider()

assert isinstance(path_provider, StartDocumentPathProvider)
assert interface.context().run_engine.scan_id_source == interface._update_scan_num

clear_path_provider()

Expand Down Expand Up @@ -545,12 +543,15 @@ def test_numtracker_create_scan_called_with_arguments_from_metadata(mock_create_
)
interface.set_config(conf)
ctx = interface.context()
interface.configure_context()

headers = {"a": "b"}
interface._try_configure_numtracker(headers)

assert ctx.numtracker is not None
assert ctx.run_engine.scan_id_source is not None

ctx.numtracker.set_headers(headers)
ctx.run_engine.md["instrument_session"] = "ab123"
interface._update_scan_num(ctx.run_engine.md)
ctx.run_engine.scan_id_source(ctx.run_engine.md)

mock_create_scan.assert_called_once_with("ab123", "p46")

Expand All @@ -567,8 +568,10 @@ def test_update_scan_num_side_effect_sets_data_session_directory_in_re_md(
interface.setup(conf)
ctx = interface.context()

assert ctx.run_engine.scan_id_source is not None

ctx.run_engine.md["instrument_session"] = "ab123"
interface._update_scan_num(ctx.run_engine.md)
ctx.run_engine.scan_id_source(ctx.run_engine.md)

assert (
ctx.run_engine.md["data_session_directory"] == "/exports/mybeamline/data/2025"
Expand All @@ -587,7 +590,9 @@ def test_update_scan_num_side_effect_sets_scan_file_in_re_md(
interface.setup(conf)
ctx = interface.context()

assert ctx.run_engine.scan_id_source is not None

ctx.run_engine.md["instrument_session"] = "ab123"
interface._update_scan_num(ctx.run_engine.md)
ctx.run_engine.scan_id_source(ctx.run_engine.md)

assert ctx.run_engine.md["scan_file"] == "p46-11"