Skip to content

Commit dafed6a

Browse files
fix(event_handler): support finding type annotated resolver when merging schemas (#8074)
* Support finding type annotated resolver * Fix tests --------- Co-authored-by: Leandro Damascena <lcdama@amazon.pt>
1 parent 3f6fc29 commit dafed6a

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

aws_lambda_powertools/event_handler/openapi/merge.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,18 @@ def _file_has_resolver(file_path: Path, resolver_name: str) -> bool:
6767
return False
6868

6969
for node in ast.walk(tree):
70+
targets: list[ast.expr] = []
71+
value: ast.expr | None = None
7072
if isinstance(node, ast.Assign):
71-
for target in node.targets:
72-
if isinstance(target, ast.Name) and target.id == resolver_name:
73-
if _is_resolver_call(node.value):
74-
return True
73+
targets = node.targets
74+
value = node.value
75+
elif isinstance(node, ast.AnnAssign):
76+
targets = [node.target]
77+
value = node.value
78+
for target in targets:
79+
if isinstance(target, ast.Name) and target.id == resolver_name:
80+
if value is not None and _is_resolver_call(value):
81+
return True
7582
return False
7683

7784

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from pydantic import BaseModel
4+
5+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
6+
7+
app: APIGatewayRestResolver = APIGatewayRestResolver(enable_validation=True)
8+
9+
10+
class Product(BaseModel):
11+
id: int
12+
name: str
13+
price: float
14+
15+
16+
@app.get("/products")
17+
def get_products() -> list[Product]:
18+
return [
19+
Product(id=1, name="Widget", price=9.99),
20+
]
21+
22+
23+
def handler(event, context):
24+
return app.resolve(event, context)

tests/functional/event_handler/_pydantic/test_openapi_merge.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,19 @@ def test_openapi_merge_schema_is_cached():
367367

368368
# AND paths should not be duplicated
369369
assert len([p for p in schema1["paths"] if p == "/users"]) == 1
370+
371+
372+
def test_openapi_merge_discover_type_annotated_resolver():
373+
# GIVEN an OpenAPIMerge instance
374+
merge = OpenAPIMerge(title="Typed API", version="1.0.0")
375+
376+
# WHEN discovering a handler with a type-annotated resolver (app: Resolver = Resolver())
377+
merge.discover(
378+
path=MERGE_HANDLERS_PATH,
379+
pattern="**/typed_handler.py",
380+
resolver_name="app",
381+
)
382+
383+
# THEN it should find the resolver and include its routes in the schema
384+
schema = merge.get_openapi_schema()
385+
assert "/products" in schema["paths"]

0 commit comments

Comments
 (0)