Skip to content
6 changes: 2 additions & 4 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,9 @@ def get_device(name: str) -> DeviceModel:
return DeviceModel.from_device(device)


def submit_task(task_request: TaskRequest) -> str:
def submit_task(task_request: TaskRequest, metadata: dict[str, Any]) -> str:
"""Submit a task to be run on begin_task"""
metadata: dict[str, Any] = {
"instrument_session": task_request.instrument_session,
}
metadata["instrument_session"] = task_request.instrument_session
if context().tiled_conf is not None:
md = config().env.metadata
# We raise an InvalidConfigError on setting tiled_conf if this isn't set
Expand Down
22 changes: 16 additions & 6 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import urllib.parse
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Annotated
from typing import Annotated, Any

import jwt
from fastapi import (
Expand Down Expand Up @@ -114,7 +114,7 @@ def get_app(config: ApplicationConfig):
)
dependencies = []
if config.oidc:
dependencies.append(Depends(verify_access_token(config.oidc)))
dependencies.append(Depends(decode_access_token(config.oidc)))
app.swagger_ui_init_oauth = {
"clientId": "NOT_SUPPORTED",
}
Expand All @@ -136,24 +136,25 @@ def get_app(config: ApplicationConfig):
return app


def verify_access_token(config: OIDCConfig):
def decode_access_token(config: OIDCConfig):
jwkclient = jwt.PyJWKClient(config.jwks_uri)
oauth_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=config.authorization_endpoint,
tokenUrl=config.token_endpoint,
refreshUrl=config.token_endpoint,
)

def inner(access_token: str = Depends(oauth_scheme)):
def inner(request: Request, access_token: str = Depends(oauth_scheme)):
signing_key = jwkclient.get_signing_key_from_jwt(access_token)
jwt.decode(
decoded: dict[str, Any] = jwt.decode(
access_token,
signing_key.key,
algorithms=config.id_token_signing_alg_values_supported,
verify=True,
audience=config.client_audience,
issuer=config.issuer,
)
request.state.decoded_access_token = decoded

return inner

Expand Down Expand Up @@ -283,7 +284,16 @@ def submit_task(
) -> TaskResponse:
"""Submit a task to the worker."""
try:
task_id: str = runner.run(interface.submit_task, task_request)
# Extract user from jwt if using OIDC (if jwt exists)
access_token: dict[str, Any] | None = getattr(
request.state, "decoded_access_token", None
)
if access_token:
user: str = access_token.get("fedid", "Unknown")
else:
user = "Unknown"

task_id: str = runner.run(interface.submit_task, task_request, {"user": user})
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
except ValidationError as e:
Expand Down
1 change: 1 addition & 0 deletions tests/system_tests/test_blueapi_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def test_instrument_session_propagated(client: BlueapiClient):
response = client.create_task(_SIMPLE_TASK)
trackable_task = client.get_task(response.task_id)
assert trackable_task.task.metadata == {
"user": "alice",
"instrument_session": AUTHORIZED_INSTRUMENT_SESSION,
"tiled_access_tags": [
'{"proposal": 12345, "visit": 1, "beamline": "adsim"}',
Expand Down
10 changes: 5 additions & 5 deletions tests/unit_tests/service/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from pathlib import Path
from typing import Any
from unittest.mock import patch
from unittest.mock import Mock, patch

import jwt
import pytest
Expand Down Expand Up @@ -117,18 +117,18 @@ def test_poll_for_token_timeout(
def test_server_raises_exception_for_invalid_token(
oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock
):
inner = main.verify_access_token(oidc_config)
inner = main.decode_access_token(oidc_config)
with pytest.raises(jwt.PyJWTError):
inner(access_token="Invalid Token")
inner(Mock(), access_token="Invalid Token")


def test_processes_valid_token(
oidc_config: OIDCConfig,
mock_authn_server: responses.RequestsMock,
valid_token_with_jwt,
):
inner = main.verify_access_token(oidc_config)
inner(access_token=valid_token_with_jwt["access_token"])
inner = main.decode_access_token(oidc_config)
inner(Mock(), access_token=valid_token_with_jwt["access_token"])


def test_session_cache_manager_returns_writable_file_path(tmp_path):
Expand Down
37 changes: 34 additions & 3 deletions tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def test_submit_task(context_mock: MagicMock):
mock_uuid_value = "8dfbb9c2-7a15-47b6-bea8-b6b77c31d3d9"
with patch.object(uuid, "uuid4") as uuid_mock:
uuid_mock.return_value = uuid.UUID(mock_uuid_value)
task_uuid = interface.submit_task(task)
task_uuid = interface.submit_task(task, {})
assert task_uuid == mock_uuid_value


Expand All @@ -211,7 +211,7 @@ def test_clear_task(context_mock: MagicMock):
mock_uuid_value = "3d858a62-b40a-400f-82af-8d2603a4e59a"
with patch.object(uuid, "uuid4") as uuid_mock:
uuid_mock.return_value = uuid.UUID(mock_uuid_value)
interface.submit_task(task)
interface.submit_task(task, {})

clear_task_return = interface.clear_task(mock_uuid_value)
assert clear_task_return == mock_uuid_value
Expand Down Expand Up @@ -337,7 +337,8 @@ def test_get_task_by_id(
TaskRequest(
name="my_plan",
instrument_session=FAKE_INSTRUMENT_SESSION,
)
),
{},
)

expected_metadata: dict[str, Any] = {
Expand Down Expand Up @@ -366,6 +367,36 @@ def test_get_task_by_id(
)


@patch("blueapi.service.interface.context")
def test_submit_task_inserts_metadata(context_mock: MagicMock):
context = BlueskyContext()
context.register_plan(my_plan)
context_mock.return_value = context

metadata = {"foo": "bar"}

task_id = interface.submit_task(
TaskRequest(
name="my_plan",
instrument_session=FAKE_INSTRUMENT_SESSION,
),
metadata,
)

assert interface.get_task_by_id(task_id) == TrackableTask.model_construct(
task_id=task_id,
request_id=ANY,
task=Task(
name="my_plan",
params={},
metadata=metadata,
),
is_complete=False,
is_pending=True,
errors=[],
)


@patch("blueapi.service.interface.TiledWriter")
@patch("blueapi.service.interface.from_uri")
@patch("blueapi.service.interface.context")
Expand Down
27 changes: 25 additions & 2 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def client(mock_runner: Mock) -> Iterator[TestClient]:

@pytest.fixture
def client_with_auth(
mock_runner: Mock, oidc_config: OIDCConfig, valid_token_with_jwt: dict[str, Any]
mock_runner: Mock,
oidc_config: OIDCConfig,
valid_token_with_jwt: dict[str, Any],
mock_authn_server,
) -> Iterator[TestClient]:
with patch("blueapi.service.interface.worker"):
main.setup_runner(runner=mock_runner)
Expand Down Expand Up @@ -248,10 +251,30 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None:

response = client.post("/tasks", json=task.model_dump())

mock_runner.run.assert_called_with(submit_task, task)
mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"})
assert response.json() == {"task_id": task_id}


def test_create_task_inserts_auth_metadata(
mock_runner: Mock,
client_with_auth: TestClient,
) -> None:
task = TaskRequest(
name="count",
params={"detectors": ["x"]},
instrument_session=FAKE_INSTRUMENT_SESSION,
)
client_with_auth.follow_redirects = False
task_id = str(uuid.uuid4())

# mock_runner.run.side_effect = [task_id]
mock_runner.run.return_value = [task_id]

client_with_auth.post("/tasks", json=task.model_dump())

mock_runner.run.assert_called_with(submit_task, task, {"user": "jd1"})


def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None:
mock_runner.run.side_effect = [
ValidationError.from_exception_data(
Expand Down