Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
PlanModel,
PythonEnvironmentResponse,
SourceInfo,
TaskParamsValidationRequest,
TaskRequest,
WorkerTask,
)
Expand Down Expand Up @@ -172,6 +173,22 @@ def submit_task(
return worker().submit_task(task)


def validate_task_params(
task_request: TaskParamsValidationRequest, metadata: dict[str, Any] | None = None
) -> bool:
"""Validate the params for a task"""
# Can't default arg to mutable data structure:
if metadata is None:
metadata = {}

task = Task(
name=task_request.name,
params=task_request.params,
metadata=metadata,
)
return worker().validate_task_params(task)


def clear_task(task_id: str) -> str:
"""Remove a task from the worker"""
return worker().clear_task(task_id)
Expand Down
60 changes: 60 additions & 0 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@
PythonEnvironmentResponse,
SourceInfo,
StateChangeRequest,
TaskParamsValidationRequest,
TaskRequest,
TaskResponse,
TasksListResponse,
TasksParamValidationResponse,
WorkerTask,
)
from .runner import WorkerDispatcher
Expand Down Expand Up @@ -327,6 +329,64 @@ def submit_task(
) from e


example_task_validate_params_request = TaskParamsValidationRequest(
name="count",
params={"detectors": ["x"]},
)


@secure_router_v1.post(
"/validateTaskParams", status_code=status.HTTP_200_OK, tags=[Tag.TASK]
)
@start_as_current_span(
TRACER,
"request",
"task_request.name",
"task_request.params",
)
def validate_task_params(
request: Request,
response: Response,
task_request: Annotated[
TaskParamsValidationRequest, Body(..., examples=[example_task_request])
],
runner: Annotated[WorkerDispatcher, Depends(_runner)],
) -> TasksParamValidationResponse:
"""Validate the tasks parameters."""
try:
# Extract user from jwt if using OIDC (if jwt exists)
access_token: dict[str, Any] | None = getattr(
request.state, "decoded_access_token", None
)
if access_token:
user: str = access_token.get("fedid", "Unknown")
else:
user = "Unknown"

validated: bool = runner.run(
interface.validate_task_params, task_request, {"user": user}
)
return TasksParamValidationResponse(validated=validated)
except ValidationError as e:
# Add body/params context to location and ensure that all required
# fields defined in the generated schema are present
errors = [
{
"loc": ["body", "params", *err.get("loc", [])],
"msg": err.get("msg", None),
"type": err.get("type", None),
# Input is not listed as required but is useful to have if available
"input": err.get("input", None),
}
for err in e.errors()
]

raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=errors,
) from e


@secure_router_v1.delete(
"/tasks/{task_id}", status_code=status.HTTP_200_OK, tags=[Tag.TASK]
)
Expand Down
21 changes: 21 additions & 0 deletions src/blueapi/service/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ class TasksListResponse(BlueapiBaseModel):
tasks: list[TrackableTask] = Field(description="List of tasks")


class TasksParamValidationResponse(BlueapiBaseModel):
"""
Diagnostic information on the tasks
"""

validated: bool = Field(
description="Whether the task params were sucessfully validated"
)


class TaskRequest(BlueapiBaseModel):
"""
Request to run a task with related info
Expand All @@ -72,6 +82,17 @@ class TaskRequest(BlueapiBaseModel):
)


class TaskParamsValidationRequest(BlueapiBaseModel):
"""
Request to validate the parameters of a task
"""

name: str = Field(description="Name of plan to run")
params: Mapping[str, Any] = Field(
description="Values for parameters to plan, if any", default_factory=dict
)


class DeviceRequest(BlueapiBaseModel):
"""
A query for devices
Expand Down
12 changes: 12 additions & 0 deletions src/blueapi/worker/task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,18 @@ def submit_task(self, task: Task) -> str:
self._pending_tasks[task_id] = trackable_task
return task_id

@start_as_current_span(TRACER, "task.name", "task.params")
def validate_task_params(self, task: Task) -> bool:
"""
Validates the params for a task
Args:
task: A description of the task
Returns:
bool: True of the params are validated
"""
task.prepare_params(self._ctx) # Will raise if parameters are invalid
return True

@start_as_current_span(
TRACER,
"trackable_task.task_id",
Expand Down
Loading