Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
.ruff_cache/
.coverage
coverage.xml
dist/
__pycache__/
*.py[cod]
3 changes: 3 additions & 0 deletions aws_lambda_opentelemetry/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from aws_lambda_opentelemetry.trace.helpers import instrument_handler

__all__ = ["instrument_handler"]
159 changes: 159 additions & 0 deletions aws_lambda_opentelemetry/trace/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import base64
import enum
import gzip
import logging
import os
import threading
import zlib
from collections.abc import Sequence
from io import BytesIO
from typing import Any

from opentelemetry.exporter.otlp.proto.common.trace_encoder import encode_spans
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_COMPRESSION,
OTEL_EXPORTER_OTLP_TRACES_COMPRESSION,
)
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
SpanExporter,
SpanExportResult,
)
from uuid_utils import uuid7

logger = logging.getLogger(__name__)


class Compression(enum.Enum):
    """Payload compression schemes supported for serialized spans."""

    NoCompression = "none"
    Deflate = "deflate"
    Gzip = "gzip"

    @classmethod
    def from_env(cls) -> "Compression":
        """Resolve the compression scheme from the standard OTEL env vars.

        ``OTEL_EXPORTER_OTLP_TRACES_COMPRESSION`` takes precedence over
        ``OTEL_EXPORTER_OTLP_COMPRESSION``; both unset means "none".
        """
        fallback = os.getenv(OTEL_EXPORTER_OTLP_COMPRESSION, "none")
        raw = os.getenv(OTEL_EXPORTER_OTLP_TRACES_COMPRESSION, fallback)
        return cls(raw.strip().lower())


class Base64SpanSerializer:
    """Encodes spans to OTLP protobuf bytes, optionally compresses them,
    and returns the result as a base64 string suitable for an SQS body."""

    def __init__(self, compression: Compression):
        self._compression = compression

    def serialize(self, spans: Sequence[ReadableSpan]) -> str:
        """Return *spans* as a base64 string of (optionally compressed) OTLP bytes."""
        payload = encode_spans(spans).SerializeToString()

        if self._compression is Compression.Gzip:
            buffer = BytesIO()
            with gzip.GzipFile(fileobj=buffer, mode="w") as stream:
                stream.write(payload)
            payload = buffer.getvalue()
        elif self._compression is Compression.Deflate:
            payload = zlib.compress(payload)

        return base64.b64encode(payload).decode("utf-8")


class SQSTraceExporter(SpanExporter):
    """
    Implements OpenTelemetry SpanExporter interface
    which can be used in combination with a SpanProcessor
    to publish traces to Amazon SQS.

    ```
    provider = TracerProvider()
    processor = SimpleSpanProcessor(SQSTraceExporter())
    provider.add_span_processor(processor)
    trace.set_tracer_provider(provider)
    ```
    """

    # SQS SendMessageBatch accepts at most 10 entries per call.
    _MAX_BATCH_ENTRIES = 10

    def __init__(
        self,
        queue_url: str,
        sqs_client: Any,
        compression: Compression | None = None,
    ) -> None:
        """
        :param queue_url: URL of the destination SQS queue.
        :param sqs_client: A boto3 SQS client (or compatible object).
        :param compression: Payload compression; resolved from the OTEL
            environment variables when not given.
        """
        self._compression = compression or Compression.from_env()
        self._serializer = Base64SpanSerializer(self._compression)
        self._queue_url = queue_url
        self._sqs_client = sqs_client
        self._shutdown_in_progress = threading.Event()
        self._shutdown = False

    def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
        """
        Serialize each span into one SQS message and publish them via
        SendMessageBatch, chunked to the 10-entry SQS batch limit.

        Returns FAILURE if the exporter is already shut down, if the SQS
        call raises, or if SQS reports any per-entry failure.
        """
        if self._shutdown:
            logger.warning("Exporter already shutdown, ignoring batch")
            return SpanExportResult.FAILURE

        entries = []
        for span in spans:
            serialized_span = self._serializer.serialize([span])
            # Batch entry Ids only need to be unique within one request;
            # fall back to a fresh uuid when the span carries no context.
            id_ = str(span.context.span_id) if span.context else uuid7().hex
            entries.append({"Id": id_, "MessageBody": serialized_span})

        # SendMessageBatch rejects an empty Entries list; nothing to do.
        if not entries:
            return SpanExportResult.SUCCESS

        try:
            # Chunk to the SQS limit so exports of >10 spans don't error.
            for start in range(0, len(entries), self._MAX_BATCH_ENTRIES):
                chunk = entries[start : start + self._MAX_BATCH_ENTRIES]
                response = self._sqs_client.send_message_batch(
                    QueueUrl=self._queue_url, Entries=chunk
                )
                # SendMessageBatch reports per-entry failures without raising.
                failed = response.get("Failed")
                if failed:
                    logger.error(
                        "Failed to send %d span message(s) to SQS: %s",
                        len(failed),
                        failed,
                    )
                    return SpanExportResult.FAILURE
            return SpanExportResult.SUCCESS
        except Exception:
            # logger.exception records the active traceback; use lazy
            # %-style args rather than an f-string.
            logger.exception("Unexpected error exporting spans")
            return SpanExportResult.FAILURE

    def shutdown(self) -> None:
        """Mark the exporter as shut down and close the SQS client.

        Safe to call multiple times; subsequent calls are no-ops.
        """
        if self._shutdown:
            logger.warning("Exporter already shutdown, ignoring call")
            return

        self._shutdown = True
        self._shutdown_in_progress.set()
        self._sqs_client.close()

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """Nothing is buffered in this exporter, so this method does nothing."""
        return True


class SQSBatchSpanProcessor(BatchSpanProcessor):
    """
    BatchSpanProcessor configured for SQS limits.

    Automatically sets max_export_batch_size to 10 (SQS batch limit).

    ```
    provider = TracerProvider()
    exporter = SQSTraceExporter(queue_url="your-sqs-queue-url")
    processor = SQSBatchSpanProcessor(exporter)
    provider.add_span_processor(processor)
    trace.set_tracer_provider(provider)
    ```
    """

    MAX_SQS_BATCH_SIZE = 10

    def __init__(
        self,
        span_exporter: SpanExporter,
        max_export_batch_size: int = MAX_SQS_BATCH_SIZE,
        **kwargs,
    ) -> None:
        """
        :param span_exporter: The exporter that receives each batch.
        :param max_export_batch_size: Batch size; must not exceed the
            SQS SendMessageBatch limit of 10.
        :raises ValueError: If max_export_batch_size exceeds the SQS limit.
        """
        # Validate explicitly: `assert` is stripped under `python -O`,
        # which would silently allow oversized batches.
        if max_export_batch_size > self.MAX_SQS_BATCH_SIZE:
            raise ValueError(
                f"max_export_batch_size must be <= {self.MAX_SQS_BATCH_SIZE} "
                f"(SQS SendMessageBatch limit), got {max_export_batch_size}"
            )
        super().__init__(
            span_exporter=span_exporter,
            max_export_batch_size=max_export_batch_size,
            **kwargs,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,22 @@
from aws_lambda_opentelemetry.utils import set_lambda_handler_attributes


def instrument_lambda_handler(**kwargs):
def instrument_handler(**kwargs):
"""
Decorate a Lambda handler function to automatically create and manage
an OpenTelemetry span for the function invocation.

Accepts all keyword arguments from Tracer.start_as_current_span():

:param name: Span name (defaults to function name if not provided)
:param kind: SpanKind (defaults to SERVER if not provided)
:param context: Parent span context
:param kind: SpanKind (defaults to SERVER if not provided)
:param attributes: Initial span attributes dict
:param links: Span links
:param start_time: Span start timestamp
:param record_exception: Whether to record exceptions (default True)
:param set_status_on_exception: Whether to set error status on exception (default True)
:param end_on_exit: Whether to end the span on exit (default True)
:return: The decorated handler function.
"""

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ classifiers = [
]
dependencies = [
"opentelemetry-api>=1.0.0",
"opentelemetry-exporter-otlp-proto-common>=1.0.0",
"opentelemetry-sdk>=1.0.0",
"uuid-utils>=0.12.0",
]

[dependency-groups]
dev = [
"boto3>=1.42.14",
"moto>=5.1.18",
"pytest>=9.0.2",
"pytest-cov>=7.0.0",
"ruff>=0.14.9",
Expand Down
Empty file added tests/test_trace/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions tests/test_trace/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import boto3
import pytest
from moto import mock_aws


@pytest.fixture
def mock_sqs_client():
    """Yield a moto-mocked boto3 SQS client with a pre-created queue named
    'test-queue' in us-east-1. The mock is torn down when the fixture exits."""
    with mock_aws():
        sqs = boto3.client("sqs", region_name="us-east-1")
        sqs.create_queue(QueueName="test-queue")
        yield sqs
160 changes: 160 additions & 0 deletions tests/test_trace/test_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from unittest.mock import MagicMock

import pytest
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
SpanExportResult,
)

from aws_lambda_opentelemetry.trace.export import (
Base64SpanSerializer,
Compression,
SQSBatchSpanProcessor,
SQSTraceExporter,
)
from tests.utils import generate_span


class TestSpanSerializer:
    @pytest.mark.parametrize(
        "compression",
        [Compression.NoCompression, Compression.Gzip, Compression.Deflate],
    )
    def test_base64_span_serializer(self, compression):
        """Serialized output must round-trip: base64-decode, decompress,
        and match the uncompressed encoding of the same spans.

        (Asserting exact compressed lengths is brittle: gzip/deflate output
        can vary between zlib versions.)
        """
        import base64
        import gzip
        import zlib

        serializer = Base64SpanSerializer(compression)
        spans = [generate_span()]
        result = serializer.serialize(spans)
        assert isinstance(result, str)

        raw = base64.b64decode(result)
        if compression == Compression.Gzip:
            raw = gzip.decompress(raw)
        elif compression == Compression.Deflate:
            raw = zlib.decompress(raw)

        # Compare against the plain (uncompressed) serialization of the
        # same span objects — deterministic for identical inputs.
        plain = base64.b64decode(
            Base64SpanSerializer(Compression.NoCompression).serialize(spans)
        )
        assert raw == plain
        assert len(raw) > 0

    @pytest.mark.parametrize(
        "compression_name, expected_compression",
        [
            ("gzip", Compression.Gzip),
            ("deflate", Compression.Deflate),
            ("none", Compression.NoCompression),
        ],
    )
    def test_compression_from_env_var(
        self, monkeypatch, compression_name, expected_compression
    ):
        """The traces-specific env var selects the compression scheme."""
        monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_COMPRESSION", compression_name)
        assert Compression.from_env() == expected_compression


class TestSqsTraceExporter:
    QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/123456789/test-queue"

    def _make_exporter(self, client):
        """Build an exporter against the shared test queue URL."""
        return SQSTraceExporter(queue_url=self.QUEUE_URL, sqs_client=client)

    def test_export_when_shutdown_is_called(self, mock_sqs_client):
        """Exporting after shutdown is rejected with FAILURE."""
        exporter = self._make_exporter(mock_sqs_client)
        exporter.shutdown()

        assert exporter.export([]) == SpanExportResult.FAILURE

    def test_export_sends_messages_to_sqs(self, mock_sqs_client):
        """Each exported span becomes one SQS message."""
        exporter = self._make_exporter(mock_sqs_client)
        batch = [generate_span() for _ in range(2)]

        assert exporter.export(batch) == SpanExportResult.SUCCESS

        received = mock_sqs_client.receive_message(
            QueueUrl=self.QUEUE_URL,
            MaxNumberOfMessages=10,
        )
        assert len(received.get("Messages", [])) == 2

    def test_export_handles_sqs_client_exception(self):
        """A client without send_message_batch raises; export reports FAILURE."""
        exporter = self._make_exporter(object())
        batch = [generate_span() for _ in range(2)]

        assert exporter.export(batch) == SpanExportResult.FAILURE

    def test_export_shutdown(self, mock_sqs_client):
        """shutdown() closes the client and flips the shutdown flags."""
        mock_sqs_client.close = MagicMock()
        exporter = self._make_exporter(mock_sqs_client)

        exporter.shutdown()

        mock_sqs_client.close.assert_called_once()
        assert exporter._shutdown is True
        assert exporter._shutdown_in_progress.is_set()

    def test_export_shutdown_successive_calls(self, mock_sqs_client):
        """A second shutdown() is a no-op and does not close the client again."""
        mock_sqs_client.close = MagicMock()
        exporter = self._make_exporter(mock_sqs_client)

        exporter.shutdown()
        exporter.shutdown()

        mock_sqs_client.close.assert_called_once()
        assert exporter._shutdown is True
        assert exporter._shutdown_in_progress.is_set()

    def test_export_force_flush(self, mock_sqs_client):
        """force_flush() always succeeds — the exporter buffers nothing."""
        exporter = self._make_exporter(mock_sqs_client)

        assert exporter.force_flush() is True


class TestSqsBatchSpanProcessor:
    QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/123456789/test-queue"

    def test_sqs_batch_span_processor_exports_in_batches(self, mock_sqs_client):
        """15 spans are flushed as one full batch of 10, then 5 on shutdown."""
        exporter = SQSTraceExporter(
            queue_url=self.QUEUE_URL,
            sqs_client=mock_sqs_client,
        )
        processor = SQSBatchSpanProcessor(span_exporter=exporter)
        provider = TracerProvider()
        provider.add_span_processor(processor)

        # Use the provider-local tracer instead of trace.set_tracer_provider():
        # the global provider can only be set once per process, so setting it
        # here makes the test order-dependent across the suite.
        tracer = provider.get_tracer("test-sqs-batch-span-processor")
        for i in range(15):
            with tracer.start_as_current_span(f"test-span-{i}"):
                ...

        # The first full batch of 10 is exported once the batch size is hit.
        response = mock_sqs_client.receive_message(
            QueueUrl=self.QUEUE_URL,
            MaxNumberOfMessages=10,
            WaitTimeSeconds=1,
        )
        assert len(response.get("Messages", [])) == 10

        # shutdown() flushes the remaining buffered spans.
        processor.shutdown()

        response = mock_sqs_client.receive_message(
            QueueUrl=self.QUEUE_URL,
            MaxNumberOfMessages=10,
            WaitTimeSeconds=1,
        )
        assert len(response.get("Messages", [])) == 5
Loading