Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from functools import cached_property
from typing import Any
from typing import Literal
from typing import TypedDict
Expand Down Expand Up @@ -216,21 +217,37 @@ def __init__(
lifespan: bool = True,
) -> None:
self._app = app
self._region_name = region_name
self._endpoint_url = endpoint_url
self._aws_access_key_id = aws_access_key_id
self._aws_secret_access_key = aws_secret_access_key
self._loop = asyncio.get_event_loop()
self._client = boto3.client(
"sqs",
region_name=region_name,
endpoint_url=endpoint_url,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
self._queue_url_cache = _QueueUrlCache(self._client)
self._send_batcher = _SendBatcher(self._client)
self._lifespan = lifespan
self._lifespan_context: Lifespan | None = None
self._state: dict[str, Any] = {}
self._client_instantiated = False
self._loop.add_signal_handler(signal.SIGTERM, self._sigterm_handler)

@cached_property
def _client(self) -> Any:
client = boto3.client(
"sqs",
region_name=self._region_name,
endpoint_url=self._endpoint_url,
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
)
self._client_instantiated = True
return client

@cached_property
def _queue_url_cache(self) -> _QueueUrlCache:
return _QueueUrlCache(self._client)

@cached_property
def _send_batcher(self) -> _SendBatcher:
return _SendBatcher(self._client)

def __call__(
self, event: _SqsEventSourceMapping, context: Any
) -> _BatchItemFailures:
Expand Down Expand Up @@ -286,5 +303,7 @@ def _sigterm_handler(self) -> None:
self._loop.run_until_complete(self._shutdown())

async def _shutdown(self) -> None:
if self._client_instantiated:
self._client.close()
if self._lifespan_context:
await self._lifespan_context.__aexit__(None, None, None)
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,29 @@
import base64
from collections.abc import AsyncGenerator
from collections.abc import Generator
from queue import Queue
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
from uuid import uuid4

import boto3
import pytest
from amgi_sqs_event_source_mapping import SqsHandler
from amgi_types import AMGIReceiveCallable
from amgi_types import AMGISendCallable
from amgi_types import Scope
from test_utils import MockApp


@pytest.fixture
def mock_sqs_client() -> Generator[Mock, None, None]:
with patch.object(boto3, "client") as mock_sqs_client:
yield mock_sqs_client
@pytest.fixture(autouse=True)
def mock_sqs_client() -> Generator[None, None, None]:
with patch.object(boto3, "client"):
yield


@pytest.fixture
async def app_sqs_handler(
mock_sqs_client: Mock,
) -> AsyncGenerator[tuple[MockApp, SqsHandler], None]:
async def app_sqs_handler() -> AsyncGenerator[tuple[MockApp, SqsHandler], None]:
app = MockApp()
sqs_handler = SqsHandler(app)

Expand Down Expand Up @@ -360,9 +363,7 @@ async def test_sqs_handler_record_corrupted(
}


async def test_lifespan(
mock_sqs_client: Mock,
) -> None:
async def test_lifespan() -> None:
app = MockApp()
sqs_handler = SqsHandler(app)

Expand Down Expand Up @@ -420,3 +421,77 @@ async def test_lifespan(
shutdown_task = loop.create_task(sqs_handler._shutdown())

await shutdown_task


def test_lifespan_and_shutdown() -> None:
queue = Queue[Exception | None]()

async def _app(
scope: Scope, receive: AMGIReceiveCallable, send: AMGISendCallable
) -> None:
try:
assert scope["type"] == "lifespan"
lifespan_startup = await receive()
assert lifespan_startup == {"type": "lifespan.startup"}
await send(
{
"type": "lifespan.startup.complete",
}
)
lifespan_shutdown = await receive()
assert lifespan_shutdown == {"type": "lifespan.shutdown"}
await send(
{
"type": "lifespan.shutdown.complete",
}
)
queue.put(None)
except Exception as e: # pragma: no cover
queue.put(e)
raise

sqs_handler = SqsHandler(_app)

sqs_handler({"Records": []}, Mock())

sqs_handler._sigterm_handler()

exception = queue.get()
assert exception is None


def test_sqs_handler_app_not_called_if_invalid_arn() -> None:
mock_app = AsyncMock()
sqs_handler = SqsHandler(mock_app, lifespan=False)
sqs_handler(
{
"Records": [
{
"messageId": "059f36b4-87a3-44ab-83d2-661975830a7d",
"receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a...",
"body": "Test message.",
"attributes": {
"ApproximateReceiveCount": "1",
"SentTimestamp": "1545082649183",
"SenderId": "AIDAIENQZJOLO23YVJ4VO",
"ApproximateFirstReceiveTimestamp": "1545082649185",
},
"messageAttributes": {
"myAttribute": {
"stringValue": "myValue",
"stringListValues": [],
"binaryListValues": [],
"dataType": "String",
}
},
"md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3",
"eventSource": "aws:sqs",
"eventSourceARN": "invalid-arn:aws:sqs:us-east-2:123456789012:my-queue",
"awsRegion": "us-east-2",
}
]
},
Mock(),
)

mock_app.assert_not_awaited()