Skip to content
Open
12 changes: 10 additions & 2 deletions awscrt/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from awscrt.http import HttpProxyOptions, HttpRequest
from awscrt.io import ClientBootstrap, ClientTlsContext, SocketOptions
from dataclasses import dataclass
from awscrt.mqtt5 import Client as Mqtt5Client
from awscrt.mqtt5 import Client as Mqtt5Client, _get_awsiot_metrics_str


class QoS(IntEnum):
Expand Down Expand Up @@ -330,6 +330,8 @@ class Connection(NativeResource):

proxy_options (Optional[awscrt.http.HttpProxyOptions]):
Optional proxy options for all connections.

enable_aws_metrics (bool): If true, append AWS IoT metrics to the username. (Default to true)
"""

def __init__(self,
Expand All @@ -355,7 +357,8 @@ def __init__(self,
proxy_options=None,
on_connection_success=None,
on_connection_failure=None,
on_connection_closed=None
on_connection_closed=None,
enable_aws_metrics=True
):

assert isinstance(client, Client) or isinstance(client, Mqtt5Client)
Expand Down Expand Up @@ -404,6 +407,11 @@ def __init__(self,
self.ping_timeout_ms = ping_timeout_ms
self.protocol_operation_timeout_ms = protocol_operation_timeout_ms
self.will = will

if enable_aws_metrics:
username = username if username else ""
username += _get_awsiot_metrics_str(username)

self.username = username
self.password = password
self.socket_options = socket_options if socket_options else SocketOptions()
Expand Down
45 changes: 45 additions & 0 deletions awscrt/mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,43 @@
from collections.abc import Sequence
from inspect import signature

# Global variable to cache metrics string
_sdk_str = None
_platform_str = None


def _get_awsiot_metrics_str(current_username=""):
global _sdk_str
global _platform_str

_metrics_str = ""
if _sdk_str is None:
try:
import importlib.metadata
try:
version = importlib.metadata.version("awscrt")
_sdk_str = "SDK=CRTPython&Version={}".format(version)
except importlib.metadata.PackageNotFoundError:
_sdk_str = "SDK=CRTPython&Version=dev"
except BaseException:
_sdk_str = ""

if _platform_str is None:
_platform_str = "Platform={}".format(_awscrt.get_platform_build_os_string())

if current_username.find("SDK=") == -1:
_metrics_str += _sdk_str
if current_username.find("Platform=") == -1:
_metrics_str += ("&" if len(_metrics_str) > 0 else "") + _platform_str

if not _metrics_str == "":
if current_username.find("?") != -1:
return "&" + _metrics_str
else:
return "?" + _metrics_str
else:
return ""


class QoS(IntEnum):
"""MQTT message delivery quality of service.
Expand Down Expand Up @@ -1338,6 +1375,7 @@ class ClientOptions:
on_lifecycle_event_connection_success_fn (Callable[[LifecycleConnectSuccessData],]): Callback for Lifecycle Event Connection Success.
on_lifecycle_event_connection_failure_fn (Callable[[LifecycleConnectFailureData],]): Callback for Lifecycle Event Connection Failure.
on_lifecycle_event_disconnection_fn (Callable[[LifecycleDisconnectData],]): Callback for Lifecycle Event Disconnection.
enable_aws_metrics (bool): Whether to append AWS IoT metrics to the username field during CONNECT. Default: True
"""
host_name: str
port: int = None
Expand All @@ -1364,6 +1402,7 @@ class ClientOptions:
on_lifecycle_event_connection_success_fn: Callable[[LifecycleConnectSuccessData], None] = None
on_lifecycle_event_connection_failure_fn: Callable[[LifecycleConnectFailureData], None] = None
on_lifecycle_event_disconnection_fn: Callable[[LifecycleDisconnectData], None] = None
enable_aws_metrics: bool = True


def _check_callback(callback):
Expand Down Expand Up @@ -1753,6 +1792,12 @@ def __init__(self, client_options: ClientOptions):
is_will_none = False
will = connect_options.will

username = connect_options.username
if client_options.enable_aws_metrics:
username = username if username else ""
username += _get_awsiot_metrics_str(username)

connect_options.username = username
websocket_is_none = client_options.websocket_handshake_transform is None
self.tls_ctx = client_options.tls_ctx
self._binding = _awscrt.mqtt5_client_new(self,
Expand Down
8 changes: 8 additions & 0 deletions source/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ PyObject *aws_py_get_cpu_count_for_group(PyObject *self, PyObject *args) {
return PyLong_FromSize_t(count);
}

PyObject *aws_py_get_platform_build_os_string(PyObject *self, PyObject *args) {
(void)self;
(void)args;

struct aws_byte_cursor os_string = aws_get_platform_build_os_string();
return PyUnicode_FromAwsByteCursor(&os_string);
}

PyObject *aws_py_thread_join_all_managed(PyObject *self, PyObject *args) {
(void)self;

Expand Down
1 change: 1 addition & 0 deletions source/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

PyObject *aws_py_get_cpu_group_count(PyObject *self, PyObject *args);
PyObject *aws_py_get_cpu_count_for_group(PyObject *self, PyObject *args);
PyObject *aws_py_get_platform_build_os_string(PyObject *self, PyObject *args);

PyObject *aws_py_thread_join_all_managed(PyObject *self, PyObject *args);

Expand Down
1 change: 1 addition & 0 deletions source/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ static PyMethodDef s_module_methods[] = {
AWS_PY_METHOD_DEF(get_corresponding_builtin_exception, METH_VARARGS),
AWS_PY_METHOD_DEF(get_cpu_group_count, METH_VARARGS),
AWS_PY_METHOD_DEF(get_cpu_count_for_group, METH_VARARGS),
AWS_PY_METHOD_DEF(get_platform_build_os_string, METH_VARARGS),
AWS_PY_METHOD_DEF(native_memory_usage, METH_NOARGS),
AWS_PY_METHOD_DEF(native_memory_dump, METH_NOARGS),
AWS_PY_METHOD_DEF(thread_join_all_managed, METH_VARARGS),
Expand Down
6 changes: 4 additions & 2 deletions test/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,8 @@ def _test_mqtt311_direct_connect_basic_auth(self):
host_name=input_host_name,
port=input_port,
username=input_username,
password=input_password)
password=input_password,
enable_aws_metrics=False) # Disable AWS metrics for basic auth on non-AWS broker
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

Expand Down Expand Up @@ -760,7 +761,8 @@ def sign_function(transform_args, **kwargs):
username=input_username,
password=input_password,
use_websockets=True,
websocket_handshake_transform=sign_function)
websocket_handshake_transform=sign_function,
enable_aws_metrics=False) # Disable AWS metrics for basic auth on non-AWS broker
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

Expand Down
9 changes: 6 additions & 3 deletions test/test_mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def _test_direct_connect_basic_auth(self):
client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=input_port,
connect_options=connect_options
connect_options=connect_options,
enable_aws_metrics=False # Disable AWS metrics for basic auth on non-AWS broker
)
callbacks = Mqtt5TestCallbacks()
client = self._create_client(client_options=client_options, callbacks=callbacks)
Expand Down Expand Up @@ -416,7 +417,8 @@ def _test_websocket_connect_basic_auth(self):
client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=input_port,
connect_options=connect_options
connect_options=connect_options,
enable_aws_metrics=False # Disable AWS metrics for basic auth on non-AWS broker
)
callbacks = Mqtt5TestCallbacks()
client_options.websocket_handshake_transform = callbacks.ws_handshake_transform
Expand Down Expand Up @@ -615,7 +617,8 @@ def test_connect_with_incorrect_basic_authentication_credentials(self):
client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=input_port,
connect_options=connect_options
connect_options=connect_options,
enable_aws_metrics=False # Disable AWS metrics for basic auth on non-AWS broker
)
callbacks = Mqtt5TestCallbacks()
client = self._create_client(client_options=client_options, callbacks=callbacks)
Expand Down
Loading