Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions custom_components/pyscript/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ async def _call(self, data: DispatchData) -> None:
for result_handler_dec in result_handlers:
await result_handler_dec.handle_call_result(data, result)
except Exception as e:
for result_handler_dec in result_handlers:
await result_handler_dec.handle_call_result(data, None)
await self.handle_exception(e)

async def dispatch(self, data: DispatchData) -> None:
Expand All @@ -290,6 +292,8 @@ async def dispatch(self, data: DispatchData) -> None:
for dec in decorators:
if await dec.handle_dispatch(data) is False:
self.logger.debug("Trigger not active due to %s", dec)
for result_handler_dec in self.get_decorators(CallResultHandlerDecorator):
await result_handler_dec.handle_call_result(data, None)
return

action_ast_ctx = AstEval(
Expand Down
69 changes: 63 additions & 6 deletions custom_components/pyscript/decorators/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@

from __future__ import annotations

import asyncio
import logging
from typing import ClassVar
from typing import Any, ClassVar

from aiohttp import hdrs
from aiohttp import hdrs, web
import voluptuous as vol

from homeassistant.components import webhook
from homeassistant.components.webhook import SUPPORTED_METHODS
from homeassistant.helpers import config_validation as cv

from ..decorator_abc import DispatchData, TriggerDecorator
from ..decorator_abc import CallResultHandlerDecorator, DispatchData, TriggerDecorator
from .base import AutoKwargsDecorator, ExpressionDecorator

_LOGGER = logging.getLogger(__name__)


class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator):
class WebhookTriggerDecorator(
TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator, CallResultHandlerDecorator
):
"""Implementation for @webhook_trigger."""

name = "webhook_trigger"
Expand All @@ -32,12 +35,15 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD
{
vol.Optional("local_only", default=True): cv.boolean,
vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]),
vol.Optional("sets_http_response_code", default=False): cv.boolean,
}
)

webhook_id: str
local_only: bool
methods: set[str]
sets_http_response_code: bool
future: asyncio.Future[Any] | None = None

webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {}

Expand All @@ -50,7 +56,7 @@ async def validate(self):
self.create_expression(self.args[1])

@staticmethod
async def _handler(_hass, webhook_id, request):
async def _handler(hass, webhook_id, request):
func_args = {
"trigger_type": "webhook",
"webhook_id": webhook_id,
Expand All @@ -64,17 +70,68 @@ async def _handler(_hass, webhook_id, request):
payload_multidict = await request.post()
func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()}

response_future: asyncio.Future[Any] | None = None
futures: list[asyncio.Future[Any]] = []
for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy():
trigger_args = func_args.copy()
if trigger.has_expression():
if not await trigger.check_expression_vars(trigger_args):
continue
future: asyncio.Future[Any] = hass.loop.create_future()
trigger.future = future
if trigger.sets_http_response_code:
response_future = future
futures.append(future)
await trigger.dispatch(DispatchData(trigger_args))

if not futures:
return None

await asyncio.gather(*futures, return_exceptions=True)

if response_future is None:
return None
return WebhookTriggerDecorator.coerce_response(response_future.result())

async def handle_call_result(self, data: DispatchData, result: Any) -> None:
"""Resolve the response future with the decorated function's return value."""
if data.trigger is not self:
return
response_future = self.future
if response_future is not None and not response_future.done():
response_future.set_result(result)

@staticmethod
def coerce_response(value: Any) -> web.Response | None:
"""Convert a webhook function return value to an aiohttp Response."""
if value is None:
return None
if isinstance(value, web.Response):
return value
# bool is a subclass of int; reject it so True/False don't become 1/0 status codes.
if isinstance(value, int) and not isinstance(value, bool):
return web.Response(status=value)
_LOGGER.warning(
"webhook function returned unsupported type %s; expected int status code or aiohttp.web.Response",
type(value).__name__,
)
return None

@staticmethod
def _add_trigger(trigger: WebhookTriggerDecorator) -> None:
webhook_id = trigger.webhook_id
if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers:
existing = WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id)
if (
trigger.sets_http_response_code
and existing is not None
and any(t.sets_http_response_code for t in existing)
):
raise ValueError(
f"webhook_id '{webhook_id}' already has a @webhook_trigger with "
f"sets_http_response_code=True; only one is allowed"
)

if existing is None:
webhook.async_register(
trigger.dm.hass,
"pyscript", # DOMAIN
Expand Down
2 changes: 2 additions & 0 deletions custom_components/pyscript/stubs/pyscript_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def webhook_trigger(
str_expr: str | None = None,
local_only: bool = True,
methods: set[SUPPORTED_METHODS] | list[SUPPORTED_METHODS] = {"POST", "PUT"},
sets_http_response_code: bool = False,
kwargs: dict | None = None,
) -> Callable[..., Any]:
"""Trigger when a request is made to a webhook endpoint.
Expand All @@ -136,6 +137,7 @@ def webhook_trigger(
str_expr: Optional expression evaluated against ``trigger_type``, ``webhook_id``, ``request``, and ``payload``.
local_only: If False, allow requests from anywhere on the internet.
methods: HTTP methods to allow.
sets_http_response_code: If True, the function's return value drives the HTTP response (``int`` status code or ``aiohttp.web.Response``); at most one trigger per ``webhook_id`` may set this.
kwargs: Extra keyword arguments merged into each invocation.

Trigger kwargs include ``trigger_type="webhook"``, ``webhook_id``, the parsed payload fields, and ``request`` (the underlying ``aiohttp.web.Request``).
Expand Down
12 changes: 12 additions & 0 deletions docs/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,18 @@ To validate an HMAC signature on incoming requests, declare ``request`` in the f
return
log.info(f"verified webhook: {payload}")

To control the HTTP response sent back to the webhook caller, opt in by passing ``sets_http_response_code=True``. The flagged function's return value then drives the response: ``None`` produces a ``200 OK``, an ``int`` sends back a response with that status code, and an ``aiohttp.web.Response`` allows full control over the body and headers. Return values from triggers without the flag are ignored. For example:

.. code:: python

@webhook_trigger("myid", sets_http_response_code=True)
def webhook_check(payload):
if "token" not in payload:
return 401
return 204

At most one ``@webhook_trigger`` per ``webhook_id`` may set ``sets_http_response_code=True``; declaring more than one is an error at setup time. The webhook handler waits for all decorated function(s) for the ``webhook_id`` to finish before responding, so use ``task.create()`` to fire-and-forget any long-running work.

NOTE: A webhook_id can only be used by either a built-in Home Assistant automation or pyscript, but not both. Trying to use the same webhook_id in both will result in an error.

@state_active
Expand Down
30 changes: 29 additions & 1 deletion tests/test_decorator_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,11 @@ def make_dispatch_data(
hass_context: Context | None = None,
) -> DispatchData:
"""Build DispatchData from test doubles."""
return DispatchData(func_args, call_ast_ctx=call_ast_ctx, hass_context=hass_context)
return DispatchData(
func_args,
call_ast_ctx=call_ast_ctx,
hass_context=hass_context,
)


def setup_global_context_function_hass(hass: HomeAssistant, config_data: dict | None = None) -> None:
Expand Down Expand Up @@ -599,6 +603,30 @@ async def test_function_decorator_manager_logs_call_exception(hass):
assert str(ast_ctx.logged_exceptions[0]) == "decorated call failed"


@pytest.mark.asyncio
async def test_function_decorator_manager_exception_calls_result_handlers(hass):
"""When the decorated function raises, result handlers should be notified with None."""
DecoratorManager.hass = hass
ast_ctx = DummyAstCtx()
manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar())
result_handler = make_recording_result_handler()
manager.add(result_handler)
call_ast_ctx = DummyCallAstCtx(exc=RuntimeError("boom"))

with patch.object(Function, "store_hass_context"):
await call_function_manager(
manager,
make_dispatch_data(
{"arg1": 1},
call_ast_ctx=call_ast_ctx,
hass_context=Context(id="call-parent"),
),
)

assert result_handler.results == [None]
assert len(ast_ctx.logged_exceptions) == 1


def test_decorator_registry_register_requires_name():
"""Registry should reject decorators without a declared name."""

Expand Down
Loading
Loading