Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aws_lambda_opentelemetry/trace/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)

from aws_lambda_opentelemetry.typing.context import LambdaContext
from aws_lambda_opentelemetry.utils import set_lambda_handler_attributes
from aws_lambda_opentelemetry.utils import set_handler_attributes


def instrument_handler(**kwargs):
Expand Down Expand Up @@ -54,7 +54,7 @@ def wrapper(event: dict, context: LambdaContext):
span.record_exception(exc)
raise
finally:
set_lambda_handler_attributes(event, context, span)
set_handler_attributes(event, context, span)
finally:
provider.force_flush()

Expand Down
116 changes: 88 additions & 28 deletions aws_lambda_opentelemetry/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import os

from opentelemetry.semconv._incubating.attributes.cloud_attributes import (
Expand All @@ -15,6 +16,13 @@
FaasInvokedProviderValues,
FaasTriggerValues,
)
from opentelemetry.semconv._incubating.attributes.messaging_attributes import (
MESSAGING_BATCH_MESSAGE_COUNT,
MESSAGING_DESTINATION_NAME,
MESSAGING_OPERATION,
MESSAGING_SYSTEM,
MessagingOperationTypeValues,
)
from opentelemetry.trace import Span

from aws_lambda_opentelemetry import constants
Expand All @@ -23,11 +31,28 @@
_is_cold_start = True


def set_lambda_handler_attributes(event: dict, context: LambdaContext, span: Span):
@enum.unique
class AwsDataSource(enum.Enum):
    """AWS service that triggered the current Lambda invocation.

    The string value is used directly as a span attribute value (for SQS it
    becomes the ``messaging.system`` attribute), so members must stay unique;
    ``@enum.unique`` guards against an accidental duplicate value aliasing
    two services together.
    """

    API_GATEWAY = "aws.api_gateway"
    HTTP_API = "aws.http_api"
    ELB = "aws.elb"
    SQS = "aws.sqs"
    SNS = "aws.sns"
    S3 = "aws.s3"
    DYNAMODB = "aws.dynamodb"
    KINESIS = "aws.kinesis"
    EVENT_BRIDGE = "aws.event_bridge"
    CLOUDWATCH_LOGS = "aws.cloudwatch_logs"
    OTHER = "aws.other"


def set_handler_attributes(event: dict, context: LambdaContext, span: Span):
"""
Set standard AWS Lambda attributes on the given span.
"""

data_source_mapper = DataSourceAttributeMapper(event)

span.set_attributes(data_source_mapper.attributes)
span.set_attributes(
{
FAAS_INVOCATION_ID: context.aws_request_id,
Expand All @@ -37,45 +62,80 @@ def set_lambda_handler_attributes(event: dict, context: LambdaContext, span: Spa
FAAS_MAX_MEMORY: context.memory_limit_in_mb,
FAAS_VERSION: context.function_version,
FAAS_COLDSTART: _check_cold_start(),
FAAS_TRIGGER: get_lambda_datasource(event).value,
FAAS_TRIGGER: data_source_mapper.faas_trigger.value,
CLOUD_RESOURCE_ID: context.invoked_function_arn,
}
)


def get_lambda_datasource(event: dict) -> FaasTriggerValues:
"""
Extract the data source from the Lambda event.
"""
class DataSourceAttributeMapper:
    """Classifies a Lambda event and derives source-specific span attributes.

    On construction the event is inspected once to determine both the
    originating AWS service (``data_source``) and the generic FaaS trigger
    category (``faas_trigger``). NOTE(review): this span previously contained
    interleaved stale lines from the pre-refactor ``get_lambda_datasource``
    function (diff residue); they are removed here.
    """

    def __init__(self, event: dict):
        self.event = event
        self.data_source, self.faas_trigger = self.get_sources()

    @property
    def attributes(self) -> dict:
        """Extra source-specific span attributes (currently SQS only)."""
        if self.data_source == AwsDataSource.SQS:
            return self._get_sqs_attributes()
        return {}

    def get_sources(self) -> tuple[AwsDataSource, FaasTriggerValues]:
        """Return ``(data source, FaaS trigger)`` inferred from the event shape."""
        # HTTP triggers: REST API Gateway ("apiId"), HTTP API ("http") and
        # ALB ("elb") all deliver a "requestContext" with a distinguishing key.
        request_context = self.event.get("requestContext")
        if request_context is not None:
            if "apiId" in request_context:
                return (AwsDataSource.API_GATEWAY, FaasTriggerValues.HTTP)
            if "http" in request_context:
                return (AwsDataSource.HTTP_API, FaasTriggerValues.HTTP)
            if "elb" in request_context:
                return (AwsDataSource.ELB, FaasTriggerValues.HTTP)

        # EventBridge: scheduled rules count as timers, anything else pub/sub.
        if "source" in self.event and "detail-type" in self.event:
            if self.event["detail-type"] == "Scheduled Event":
                return (AwsDataSource.EVENT_BRIDGE, FaasTriggerValues.TIMER)
            return (AwsDataSource.EVENT_BRIDGE, FaasTriggerValues.PUBSUB)

        # Record-based triggers: classify by the first record's "eventSource".
        records = self.event.get("Records")
        if records:
            record_sources = {
                "aws:sns": (AwsDataSource.SNS, FaasTriggerValues.PUBSUB),
                "aws:sqs": (AwsDataSource.SQS, FaasTriggerValues.PUBSUB),
                "aws:s3": (AwsDataSource.S3, FaasTriggerValues.DATASOURCE),
                "aws:dynamodb": (
                    AwsDataSource.DYNAMODB,
                    FaasTriggerValues.DATASOURCE,
                ),
                "aws:kinesis": (
                    AwsDataSource.KINESIS,
                    FaasTriggerValues.DATASOURCE,
                ),
            }
            event_source = records[0].get("eventSource")
            if event_source in record_sources:
                return record_sources[event_source]

        # CloudWatch Logs subscription-filter payloads.
        if "awslogs" in self.event and "data" in self.event["awslogs"]:
            return (AwsDataSource.CLOUDWATCH_LOGS, FaasTriggerValues.DATASOURCE)

        return (AwsDataSource.OTHER, FaasTriggerValues.OTHER)

    def _get_sqs_attributes(self) -> dict:
        """Build ``messaging.*`` attributes for an SQS-triggered invocation."""
        records = self.event.get("Records", [])
        message_count = len(records)
        queue_arn = records[0].get("eventSourceARN", "") if message_count > 0 else ""
        # The queue name is the last ":"-separated segment of the queue ARN;
        # for an empty ARN this degrades to an empty name rather than raising.
        queue_name = queue_arn.split(":")[-1]

        return {
            MESSAGING_SYSTEM: self.data_source.value,
            MESSAGING_OPERATION: MessagingOperationTypeValues.RECEIVE.value,
            MESSAGING_BATCH_MESSAGE_COUNT: message_count,
            MESSAGING_DESTINATION_NAME: queue_name,
            CLOUD_RESOURCE_ID: queue_arn,
        }


def _check_cold_start() -> bool:
Expand Down
104 changes: 75 additions & 29 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,6 @@
import pytest
from opentelemetry.sdk.trace import Span
from opentelemetry.semconv._incubating.attributes.faas_attributes import (
# FAAS_COLDSTART,
# FAAS_INVOCATION_ID,
# FAAS_INVOKED_NAME,
# FAAS_INVOKED_PROVIDER,
# FAAS_INVOKED_REGION,
# FAAS_MAX_MEMORY,
# FAAS_TRIGGER,
# FAAS_VERSION,
# FaasInvokedProviderValues,
FaasTriggerValues,
)

Expand Down Expand Up @@ -43,16 +34,23 @@ def test_cold_start_provisioned_concurrency(self, monkeypatch):

class TestLambdaDataSource:
@pytest.mark.parametrize(
"key",
[("apiId",), ("httpMethod",), ("elb",)],
"key,aws_data_source",
[
("apiId", utils.AwsDataSource.API_GATEWAY),
("http", utils.AwsDataSource.HTTP_API),
("elb", utils.AwsDataSource.ELB),
],
)
def test_http_trigger(self, key: str):
def test_http_trigger(self, key: str, aws_data_source: utils.AwsDataSource):
event = {
"requestContext": {
"apiId": "example-api-id",
key: "example-api-id",
}
}
assert utils.get_lambda_datasource(event) == utils.FaasTriggerValues.HTTP

mapper = utils.DataSourceAttributeMapper(event)
assert mapper.faas_trigger == utils.FaasTriggerValues.HTTP
assert mapper.data_source == aws_data_source

@pytest.mark.parametrize(
"detail_type, expected",
Expand All @@ -66,49 +64,73 @@ def test_eventbridge_trigger(self, detail_type: str, expected: FaasTriggerValues
"source": "aws.events",
"detail-type": detail_type,
}
assert utils.get_lambda_datasource(event) == expected

mapper = utils.DataSourceAttributeMapper(event)
assert mapper.faas_trigger == expected
assert mapper.data_source == utils.AwsDataSource.EVENT_BRIDGE

@pytest.mark.parametrize(
"event_source, expected",
"event_source, aws_data_source, faas_trigger",
[
("aws:sns", utils.FaasTriggerValues.PUBSUB),
("aws:sqs", utils.FaasTriggerValues.PUBSUB),
("aws:s3", utils.FaasTriggerValues.DATASOURCE),
("aws:dynamodb", utils.FaasTriggerValues.DATASOURCE),
("aws:kinesis", utils.FaasTriggerValues.DATASOURCE),
("aws:sns", utils.AwsDataSource.SNS, utils.FaasTriggerValues.PUBSUB),
("aws:sqs", utils.AwsDataSource.SQS, utils.FaasTriggerValues.PUBSUB),
("aws:s3", utils.AwsDataSource.S3, utils.FaasTriggerValues.DATASOURCE),
(
"aws:dynamodb",
utils.AwsDataSource.DYNAMODB,
utils.FaasTriggerValues.DATASOURCE,
),
(
"aws:kinesis",
utils.AwsDataSource.KINESIS,
utils.FaasTriggerValues.DATASOURCE,
),
],
)
def test_pubsub_trigger(self, event_source: str, expected: FaasTriggerValues):
def test_pubsub_trigger(
self,
event_source: str,
aws_data_source: utils.AwsDataSource,
faas_trigger: FaasTriggerValues,
):
event = {
"Records": [
{
"eventSource": event_source,
}
]
}
assert utils.get_lambda_datasource(event) == expected

mapper = utils.DataSourceAttributeMapper(event)
assert mapper.faas_trigger == faas_trigger
assert mapper.data_source == aws_data_source

def test_cloudwatch_logs_trigger(self):
event = {
"awslogs": {
"data": "example-data",
}
}
assert utils.get_lambda_datasource(event) == utils.FaasTriggerValues.DATASOURCE

mapper = utils.DataSourceAttributeMapper(event)
assert mapper.faas_trigger == utils.FaasTriggerValues.DATASOURCE
assert mapper.data_source == utils.AwsDataSource.CLOUDWATCH_LOGS

def test_unknown_trigger(self):
event = {}
assert utils.get_lambda_datasource(event) == utils.FaasTriggerValues.OTHER

mapper = utils.DataSourceAttributeMapper(event)
assert mapper.faas_trigger == utils.FaasTriggerValues.OTHER
assert mapper.data_source == utils.AwsDataSource.OTHER


class TestSetLambdaHandlerAttributes:
def test_set_attributes(self, lambda_context: LambdaContext):
def test_set_general_attributes(self, lambda_context: LambdaContext):
span = MagicMock(spec=Span)

utils.set_lambda_handler_attributes({}, lambda_context, span)
utils.set_handler_attributes({}, lambda_context, span)

span.set_attributes.assert_called_once()
attributes = span.set_attributes.call_args[0][0]
attributes = span.set_attributes.call_args_list[1][0][0]
assert attributes["faas.invocation_id"] == lambda_context.aws_request_id
assert attributes["faas.invoked_name"] == lambda_context.function_name
assert attributes["faas.invoked_region"] == lambda_context.region
Expand All @@ -118,3 +140,27 @@ def test_set_attributes(self, lambda_context: LambdaContext):
assert attributes["faas.coldstart"] is False
assert attributes["faas.trigger"] == "other"
assert attributes["cloud.resource_id"] == lambda_context.invoked_function_arn

def test_sqs_attributes_set(self, lambda_context: LambdaContext):
span = MagicMock(spec=Span)

event = {
"Records": [
{
"eventSource": "aws:sqs",
"eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue",
"awsRegion": "us-east-1",
}
]
}

utils.set_handler_attributes(event, lambda_context, span)

attributes = span.set_attributes.call_args_list[0][0][0]
assert attributes["messaging.system"] == "aws.sqs"
assert attributes["messaging.destination.name"] == "queue"
assert attributes["messaging.operation"] == "receive"
assert (
attributes["cloud.resource_id"]
== "arn:aws:sqs:us-east-1:123456789012:queue"
)