Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 54 additions & 26 deletions src/labthings_fastapi/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from collections import deque
from functools import partial, wraps
import inspect
from threading import Thread, Lock
from threading import Thread, Lock, RLock
import uuid
from typing import (
TYPE_CHECKING,
Expand All @@ -40,8 +40,7 @@
from weakref import WeakSet
import weakref
from fastapi import APIRouter, FastAPI, HTTPException, Request, Body, BackgroundTasks
from pydantic import BaseModel, create_model

from pydantic import BaseModel, create_model, ValidationError

from .middleware.url_for import URLFor
from .base_descriptor import (
Expand All @@ -50,10 +49,15 @@
DescriptorInfoCollection,
)
from .logs import add_thing_log_destination
from .utilities import model_to_dict, wrap_plain_types_in_rootmodel
from .utilities import (
LabThingsRootModelWrapper,
model_to_dict,
wrap_plain_types_in_rootmodel,
)
from .invocations import InvocationModel, InvocationStatus
from .exceptions import (
GlobalLockBusyError,
InvalidReturnValue,
InvocationCancelledError,
InvocationError,
NotConnectedToServerError,
Expand Down Expand Up @@ -140,9 +144,9 @@
self.expiry_time: Optional[datetime.datetime] = None

# Private state properties
self._status_lock = Lock() # This Lock protects properties below
self._status_lock = RLock() # This Lock protects properties below
self._status: InvocationStatus = InvocationStatus.PENDING # Task status
self._return_value: Optional[Any] = None # Return value
self._output_model_instance: Optional[BaseModel] = None # Return value
self._request_time: datetime.datetime = datetime.datetime.now()
self._start_time: Optional[datetime.datetime] = None # Task start time
self._end_time: Optional[datetime.datetime] = None # Task end time
Expand All @@ -158,7 +162,18 @@
def output(self) -> Any:
"""Return value of the Action. If the Action is still running, returns None."""
with self._status_lock:
return self._return_value
if self._output_model_instance is None:
return None

Check warning on line 166 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

166 line is not covered with tests
if isinstance(self._output_model_instance, LabThingsRootModelWrapper):
return self._output_model_instance.model_dump()
else:
return self._output_model_instance

Check warning on line 170 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

170 line is not covered with tests

@property
def output_model_instance(self) -> BaseModel | None:
"""Return value of the Action, as a model, or None."""
with self._status_lock:
return self._output_model_instance

@property
def log(self) -> list[logging.LogRecord]:
Expand Down Expand Up @@ -233,19 +248,20 @@
]
# The line below confuses MyPy because self.action **evaluates to** a Descriptor
# object (i.e. we don't call __get__ on the descriptor).
return self.action.invocation_model( # type: ignore[attr-defined]
status=self.status,
id=self.id,
action=self.thing.path + self.action.name, # type: ignore[attr-defined]
href=URLFor("action_invocation", id=self.id),
timeStarted=self._start_time,
timeCompleted=self._end_time,
timeRequested=self._request_time,
input=self.input,
output=self.output,
links=links,
log=self.log,
)
with self._status_lock:
return self.action.invocation_model( # type: ignore[attr-defined]
status=self.status,
id=self.id,
action=self.thing.path + self.action.name, # type: ignore[attr-defined]
href=URLFor("action_invocation", id=self.id),
timeStarted=self._start_time,
timeCompleted=self._end_time,
timeRequested=self._request_time,
input=self.input,
output=self.output_model_instance,
links=links,
log=self.log,
)

def run(self) -> None:
"""Run the action and track progress.
Expand Down Expand Up @@ -273,6 +289,8 @@
See `.Invocation.status` for status values.

:raises RuntimeError: if there is no Thing associated with the invocation.
:raises InvalidReturnValue: if the action returns a value that can't
be validated by its output model.
"""
# self.action evaluates to an ActionDescriptor. This confuses mypy,
# which thinks we are calling ActionDescriptor.__get__.
Expand Down Expand Up @@ -303,19 +321,29 @@
# Actually run the action
ret = action.func(thing, **kwargs, **self.dependencies)

with self._status_lock:
self._return_value = ret
self._status = InvocationStatus.COMPLETED
action.emit_changed_event(self.thing, self._status.value)
try:
output_model_instance = action.output_model.model_validate(ret)
except ValidationError as e:
# Generate a helpful error message. This will be handled below,
# where it will cause the action to be marked as failed, and the
# error will end up in the log.
msg = f"The return value from '{self.thing.name}.{action.name}' "
msg += "failed to validate against its output model "
msg += f"'{action.output_model}'. The return value was '{ret}'."
raise InvalidReturnValue(msg) from e

with self._status_lock:
self._output_model_instance = output_model_instance
self._status = InvocationStatus.COMPLETED
action.emit_changed_event(self.thing, self._status.value)
except InvocationCancelledError:
logger.info(f"Invocation {self.id} was cancelled.")
with self._status_lock:
self._status = InvocationStatus.CANCELLED
action.emit_changed_event(self.thing, self._status.value)
except Exception as e: # skipcq: PYL-W0703
# First log
if isinstance(e, InvocationError):
if isinstance(e, (InvocationError, InvalidReturnValue)):
# Log without traceback for anticipated errors
logger.error(e)
elif (
Expand Down Expand Up @@ -521,8 +549,8 @@
with self._invocations_lock:
try:
invocation: Any = self._invocations[id]
except KeyError as e:
raise HTTPException(

Check warning on line 553 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

552-553 lines are not covered with tests
status_code=404,
detail="No action invocation found with ID {id}",
) from e
Expand All @@ -535,8 +563,8 @@
invocation.output.response
):
# TODO: honour "accept" header
return invocation.output.response()

Check warning on line 566 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

566 line is not covered with tests
return invocation.output
return invocation.output_model_instance

@router.delete(
ACTION_INVOCATIONS_PATH + "/{id}",
Expand All @@ -560,8 +588,8 @@
with self._invocations_lock:
try:
invocation: Any = self._invocations[id]
except KeyError as e:
raise HTTPException(

Check warning on line 592 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

591-592 lines are not covered with tests
status_code=404,
detail="No action invocation found with ID {id}",
) from e
Expand Down Expand Up @@ -748,7 +776,7 @@
"""
super().__set_name__(owner, name)
if self.name != self.func.__name__:
raise ValueError(

Check warning on line 779 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

779 line is not covered with tests
f"Action name '{self.name}' does not match function name "
f"'{self.func.__name__}'",
)
Expand Down Expand Up @@ -940,14 +968,14 @@
try:
responses[200]["model"] = self.output_model
pass
except AttributeError:
print(f"Failed to generate response model for action {self.name}")

Check warning on line 972 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

971-972 lines are not covered with tests
# Add an additional media type if we may return a file
if hasattr(self.output_model, "media_type"):
responses[200]["content"][self.output_model.media_type] = {}

Check warning on line 975 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

975 line is not covered with tests
# Now we can add the endpoint to the app.
if thing.path is None:
raise NotConnectedToServerError(

Check warning on line 978 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

978 line is not covered with tests
"Can't add the endpoint without thing.path!"
)
app.post(
Expand Down Expand Up @@ -995,7 +1023,7 @@
"""
path = path or thing.path
if path is None:
raise NotConnectedToServerError("Can't generate forms without a path!")

Check warning on line 1026 in src/labthings_fastapi/actions.py

View workflow job for this annotation

GitHub Actions / coverage

1026 line is not covered with tests
forms = [
Form[ActionOp](href=path + self.name, op=[ActionOp.invokeaction]),
]
Expand Down
17 changes: 13 additions & 4 deletions src/labthings_fastapi/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,24 @@ def poll_invocation(
:param first_interval: sets how long we wait before the first
polling request. Often, it makes sense for this to be a short
interval, in case the action fails (or returns) immediately.

:raises ServerActionError: if an HTTP error is found during polling.
:return: the completed invocation as a dictionary.
"""
first_time = True
while invocation["status"] in ACTION_RUNNING_KEYWORDS:
time.sleep(first_interval if first_time else interval)
r = client.get(invocation_href(invocation))
r.raise_for_status()
invocation = r.json()
response = client.get(invocation_href(invocation))
if response.is_error:
try:
message = response.json()["detail"]
except KeyError:
message = response.text
raise ServerActionError(
f"The server returned error {response.status_code} while polling "
f"action '{invocation['action']}' with id '{invocation['id']}'. "
f"The error message was:\n{message}."
)
invocation = response.json()
first_time = False
return invocation

Expand Down
17 changes: 17 additions & 0 deletions src/labthings_fastapi/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,23 @@ class NoInvocationContextError(RuntimeError):
"""


class InvalidReturnValue(RuntimeError):
r"""The return value from a method cannot be serialised by LabThings.

This error is raised when an action returns a value that can't be serialised.
This usually means that either it doesn't match the declared return type of
the function, or the declared return type permits un-serialisable values.

If an action's return type is missing or `Any`\ , it's possible to return a
value that can't be serialised, which will cause this error.

The solution is usually to ensure that the return type of your action is
either a simple type that can be serialised to JSON, or a Pydantic model.
You should also check that the function's return value matches the declared
type, ideally by regularly running a type checker like `mypy` on your code.
"""


class LogConfigurationError(RuntimeError):
"""There is a problem with logging configuration.

Expand Down
32 changes: 31 additions & 1 deletion src/labthings_fastapi/invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
from typing import Optional, Any, Sequence, TypeVar, Generic
import uuid

from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import (
BaseModel,
ConfigDict,
model_validator,
model_serializer,
SerializerFunctionWrapHandler,
)
from pydantic_core import PydanticSerializationError

from labthings_fastapi.middleware.url_for import URLFor

Expand Down Expand Up @@ -105,6 +112,29 @@ class GenericInvocationModel(BaseModel, Generic[InputT, OutputT]):
log: Sequence[LogRecordModel]
links: Links = None

@model_serializer(mode="wrap")
def serialize_model(
self, handler: SerializerFunctionWrapHandler
) -> dict[str, object]:
"""Give a more helpful error if the class fails to serialize.

:param handler: The Pydantic serializer.
:raises PydanticSerializationError: if the model fails to serialize. This
is wrapped to add the action and invocation ID.
:return: the serialized model, as a dictionary.
"""
try:
return handler(self)
except PydanticSerializationError as e:
extra = ""
if self.output is not None:
extra = "This is often caused by an invalid return value. "
raise PydanticSerializationError(
f"Could not serialise invocation '{self.id}' of '{self.action}' "
f"({self.status.value}). {extra}"
f"Error: '{e}'"
) from e


InvocationModel = GenericInvocationModel[Any, Any]
"""A model to serialise `.Invocation` objects when they are polled over HTTP."""
14 changes: 14 additions & 0 deletions src/labthings_fastapi/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from fastapi.testclient import TestClient
from pydantic import ValidationError
from pydantic_core import PydanticSerializationError
from typing import Any, AsyncGenerator, Optional, TypeVar, overload
from fastapi.responses import JSONResponse
from typing_extensions import Self
Expand Down Expand Up @@ -50,6 +51,9 @@
ThingSubclass = TypeVar("ThingSubclass", bound=Thing)


LOGGER = logging.getLogger(__name__)


class ThingServer:
"""Use FastAPI to serve `~lt.Thing` instances.

Expand Down Expand Up @@ -248,6 +252,16 @@ async def global_lock_exception_handler(
content={"detail": repr(exc)},
)

@self.app.exception_handler(PydanticSerializationError)
async def serialization_error_handler(
request: Request, exc: PydanticSerializationError
) -> JSONResponse:
LOGGER.error(
f"Couldn't serialize response to {request.url} because of error: \n"
f"{exc}"
)
return JSONResponse(status_code=500, content={"detail": str(exc)})

@property
def debug(self) -> bool:
"""Whether the server is in debug mode."""
Expand Down
39 changes: 39 additions & 0 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any
import uuid
from fastapi.testclient import TestClient
from labthings_fastapi.exceptions import FailedToInvokeActionError, ServerActionError
from pydantic import BaseModel
import pytest
import functools
Expand Down Expand Up @@ -333,3 +335,40 @@ def long_docstring(self) -> None:
assert actions["long_docstring"].description.startswith(
"It has multiple paragraphs."
)


def test_invalid_return_values():
"""Test the errors raised when an action's return value can't be serialised."""

class NaughtyThing(lt.Thing):
@lt.action
def make_random_int(self) -> int:
"""An action that should return an integer, but returns a float."""
return 4.2

@lt.action
def make_unjsonable_any(self) -> Any:
"""A vaguely-typed action that won't serialise."""
return object()

server = lt.ThingServer.from_things({"naughty": NaughtyThing})
with server.test_client() as client:
tc = lt.ThingClient.from_url("/naughty/", client=client)

# If a return type doesn't match the type hint, we get
with pytest.raises(
(ServerActionError, FailedToInvokeActionError),
match=(
r"The return value from 'naughty.make_random_int' failed to validate "
r"against its output model."
),
):
tc.make_random_int()

# If a return type is not JSONable
with pytest.raises(
(ServerActionError, FailedToInvokeActionError),
match="Could not serialise invocation",
) as excinfo:
tc.make_unjsonable_any()
assert "make_unjsonable_any" in str(excinfo)
Loading