Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/infuse_iot/tools/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def __init__(
self.ddb = ddb
self.rpc = rpc_server

def notification_broadcast(self, notification: ClientNotification):
if self.server:
self.server.broadcast(notification)

def query_device_key(self, cb_event: threading.Event | None = None):
def security_state_done(pkt: PacketReceived, _: int, response: bytes):
cloud_key = response[:32]
Expand Down Expand Up @@ -204,7 +208,7 @@ def _handle_local_tdf(self, pkt: PacketReceived):
if_addr = interface.Address.BluetoothLeAddr.from_tdf_struct(reading.data[0].address)
infuse_id = self._common.ddb.infuse_id_from_bluetooth(if_addr)
if infuse_id:
self._common.server.broadcast(ClientNotificationConnectionDropped(infuse_id))
self._common.notification_broadcast(ClientNotificationConnectionDropped(infuse_id))

def _handle_serial_frame(self, frame: bytearray):
try:
Expand Down Expand Up @@ -244,10 +248,9 @@ def _handle_serial_frame(self, frame: bytearray):
elif pkt.ptype == InfuseType.KEY_IDS:
self._common.query_device_key(None)

# Forward to clients
notification = ClientNotificationEpacketReceived(pkt)
if self._common.server:
# Forward to clients
self._common.server.broadcast(notification)
self._common.notification_broadcast(notification)
except (ValueError, KeyError) as e:
print(f"Decode failed ({e})")

Expand Down Expand Up @@ -322,19 +325,19 @@ def _bt_connect_cb(self, pkt: PacketReceived, rc: int, response: bytes):
else:
self._connected[infuse_id] = 1
rsp = ClientNotificationConnectionCreated(infuse_id, 244 - ctypes.sizeof(CtypeBtGattFrame) - 16)
self._common.server.broadcast(rsp)
self._common.notification_broadcast(rsp)

def _handle_conn_request(self, req: GatewayRequestConnectionRequest):
assert self._common.server is not None

if req.infuse_id == InfuseID.GATEWAY or req.infuse_id == self._common.ddb.gateway:
# Local gateway always connected
self._common.server.broadcast(ClientNotificationConnectionCreated(req.infuse_id, 512))
self._common.notification_broadcast(ClientNotificationConnectionCreated(req.infuse_id, 512))
return

state = self._common.ddb.devices.get(req.infuse_id, None)
if state is None or state.bt_addr is None:
self._common.server.broadcast(ClientNotificationConnectionFailed(req.infuse_id))
self._common.notification_broadcast(ClientNotificationConnectionFailed(req.infuse_id))
return

subs = 0
Expand Down Expand Up @@ -403,7 +406,7 @@ def _handle_observed_devices(self):
if self._common.ddb.gateway == device:
info["gateway"] = True
observed_devices[device] = info
self._common.server.broadcast(ClientNotificationObservedDevices(observed_devices))
self._common.notification_broadcast(ClientNotificationObservedDevices(observed_devices))

def _iter(self) -> None:
if self._common.server is None:
Expand Down
28 changes: 23 additions & 5 deletions src/infuse_iot/tools/native_bt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
PacketReceived,
)
from infuse_iot.socket_comms import (
ClientNotification,
ClientNotificationConnectionCreated,
ClientNotificationConnectionFailed,
ClientNotificationEpacketReceived,
Expand Down Expand Up @@ -64,6 +65,12 @@ def __init__(self, database: DeviceDatabase, server: LocalServer, bleak_mapping:
self._queues: dict[int, asyncio.Queue] = {}
self._tasks: dict[int, asyncio.Task] = {}

def wrapped_broadcast(self, notifcation: ClientNotification):
try:
self._server.broadcast(notifcation)
except OSError as e:
Console.log_error(f"Failed to broadcast notification: {str(e)}")

def notification_handler(self, _characteristic: BleakGATTCharacteristic, data: bytearray):
try:
hdr, decr = CtypeBtGattFrame.decrypt(self._db, None, bytes(data))
Expand All @@ -89,9 +96,11 @@ def notification_handler(self, _characteristic: BleakGATTCharacteristic, data: b
bytes(decr),
)
Console.log_rx(pkt.ptype, len(data))
self._server.broadcast(ClientNotificationEpacketReceived(pkt))
self.wrapped_broadcast(ClientNotificationEpacketReceived(pkt))

async def create_connection(self, request: GatewayRequestConnectionRequest, dev: BLEDevice, queue: asyncio.Queue):
async def create_connection_internal(
self, request: GatewayRequestConnectionRequest, dev: BLEDevice, queue: asyncio.Queue
):
Console.log_info(f"{dev}: Initiating connection")
async with BleakClient(dev, timeout=request.timeout_ms / 1000) as client:
# Modified from bleak example code
Expand All @@ -116,7 +125,7 @@ async def create_connection(self, request: GatewayRequestConnectionRequest, dev:

Console.log_info(f"{dev}: Connected (MTU {client.mtu_size})")

self._server.broadcast(
self.wrapped_broadcast(
ClientNotificationConnectionCreated(
request.infuse_id,
# ATT header uses 3 bytes of the MTU
Expand Down Expand Up @@ -147,6 +156,12 @@ async def create_connection(self, request: GatewayRequestConnectionRequest, dev:
self._tasks.pop(request.infuse_id)
Console.log_info(f"{dev}: Terminating connection")

async def create_connection(self, request: GatewayRequestConnectionRequest, dev: BLEDevice, queue: asyncio.Queue):
try:
await self.create_connection_internal(request, dev, queue)
except TimeoutError as e:
Console.log_info(f"Timeout: {str(e)}")

def datagram_received(self, data: bytes, addr: tuple[str | Any, int]):
loop = asyncio.get_event_loop()
request = GatewayRequest.from_json(json.loads(data.decode("utf-8")))
Expand All @@ -166,7 +181,7 @@ def datagram_received(self, data: bytes, addr: tuple[str | Any, int]):

ble_dev = self._mapping.get(request.infuse_id, None)
if ble_dev is None:
self._server.broadcast(ClientNotificationConnectionFailed(request.infuse_id))
self.wrapped_broadcast(ClientNotificationConnectionFailed(request.infuse_id))
return

# Create queue for further data transfer
Expand Down Expand Up @@ -243,7 +258,10 @@ def simple_callback(self, device: BLEDevice, data: AdvertisementData):
Console.log_rx(hdr.type, len(payload))
pkt = PacketReceived([hop], hdr.type, decr)
notification = ClientNotificationEpacketReceived(pkt)
self.server.broadcast(notification)
try:
self.server.broadcast(notification)
except OSError as e:
Console.log_error(f"Failed to broadcast notification: {str(e)}")

async def async_bt_receiver(self):
loop = asyncio.get_event_loop()
Expand Down