Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/infuse_iot/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def _bt_adv_key(self, base: bytes, time_idx: int) -> bytes:
def _bt_gatt_key(self, base: bytes, time_idx: int) -> bytes:
return hkdf_derive(base, time_idx.to_bytes(4, "little"), b"bt_gatt")

def _udp_key(self, base: bytes, time_idx: int) -> bytes:
return hkdf_derive(base, time_idx.to_bytes(4, "little"), b"udp")

def has_public_key(self, address: int) -> bool:
"""Does the database have the public key for this device?"""
if address not in self.devices:
Expand Down Expand Up @@ -219,3 +222,27 @@ def bt_gatt_device_key(self, address: int, gps_time: int) -> bytes:
time_idx = gps_time // (60 * 60 * 24)

return self._bt_gatt_key(base, time_idx)

def udp_network_key(self, address: int, gps_time: int) -> bytes:
"""Network key for UDP interface"""
if address not in self.devices:
raise DeviceUnknownNetworkKey
network_id = self.devices[address].network_id
if network_id is None:
raise DeviceUnknownNetworkKey

return self._network_key(network_id, b"udp", gps_time)

def udp_device_key(self, address: int, gps_time: int) -> bytes:
"""Device key for UDP interface"""
if address not in self.devices:
raise DeviceUnknownDeviceKey
d = self.devices[address]
if d.device_id is None:
raise DeviceUnknownDeviceKey
base = self.devices[address].shared_key
if base is None:
raise DeviceUnknownDeviceKey
time_idx = gps_time // (60 * 60 * 24)

return self._udp_key(base, time_idx)
58 changes: 46 additions & 12 deletions src/infuse_iot/epacket/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,18 +337,8 @@ class CtypeForwardHeaderBtGatt(ctypes.LittleEndianStructure):
_pack_ = 1


class CtypeV0VersionedFrame(ctypes.LittleEndianStructure):
_fields_ = [
("version", ctypes.c_uint8),
("_type", ctypes.c_uint8),
("flags", ctypes.c_uint16),
("_key_metadata", ctypes.c_uint8 * 3),
("_device_id_upper", ctypes.c_uint32),
("_device_id_lower", ctypes.c_uint32),
("gps_time", ctypes.c_uint32),
("sequence", ctypes.c_uint16),
("entropy", ctypes.c_uint16),
]
class CtypeV0Frame(ctypes.LittleEndianStructure):
_fields_ = []
_pack_ = 1

@property
Expand Down Expand Up @@ -381,6 +371,35 @@ def parse(cls, frame: bytes) -> tuple[Self, int]:
)


class CtypeV0VersionedFrame(CtypeV0Frame):
_fields_ = [
("version", ctypes.c_uint8),
("_type", ctypes.c_uint8),
("flags", ctypes.c_uint16),
("_key_metadata", ctypes.c_uint8 * 3),
("_device_id_upper", ctypes.c_uint32),
("_device_id_lower", ctypes.c_uint32),
("gps_time", ctypes.c_uint32),
("sequence", ctypes.c_uint16),
("entropy", ctypes.c_uint16),
]
_pack_ = 1


class CtypeV0UnversionedFrame(CtypeV0Frame):
_fields_ = [
("_type", ctypes.c_uint8),
("flags", ctypes.c_uint16),
("_key_metadata", ctypes.c_uint8 * 3),
("_device_id_upper", ctypes.c_uint32),
("_device_id_lower", ctypes.c_uint32),
("gps_time", ctypes.c_uint32),
("sequence", ctypes.c_uint16),
("entropy", ctypes.c_uint16),
]
_pack_ = 1


class CtypeSerialFrame(CtypeV0VersionedFrame):
"""Serial packet header"""

Expand Down Expand Up @@ -484,6 +503,21 @@ def decrypt(cls, database: DeviceDatabase, bt_addr: Address.BluetoothLeAddr | No
return header, decrypted


class CtypeUdpFrame(CtypeV0UnversionedFrame):
@classmethod
def decrypt(cls, database: DeviceDatabase, frame: bytes):
header = cls.from_buffer_copy(frame)
if header.flags & Flags.ENCR_DEVICE:
database.observe_device(header.device_id, device_id=header.key_metadata)
key = database.udp_device_key(header.device_id, header.gps_time)
else:
database.observe_device(header.device_id, network_id=header.key_metadata)
key = database.udp_network_key(header.device_id, header.gps_time)

decrypted = chachapoly_decrypt(key, frame[:10], frame[10:22], frame[22:])
return header, decrypted


class CtypePacketReceived:
class CommonHeader(ctypes.Structure):
_fields_ = [
Expand Down
20 changes: 13 additions & 7 deletions src/infuse_iot/serial_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import time
from abc import ABCMeta, abstractmethod
from io import BufferedWriter

import pylink
import serial
Expand Down Expand Up @@ -43,7 +44,7 @@ def reconstructor(cls):

class SerialLike(metaclass=ABCMeta):
@abstractmethod
def open(self) -> None:
def open(self, timeout: float | None = None) -> None:
"""Open serial port"""

@abstractmethod
Expand Down Expand Up @@ -79,7 +80,7 @@ def __init__(self, serial_port, baudrate=115200):
# receivers (STM32) time to wake up on RX before real data arrives.
self._prefix = b"\x00\x00" if baudrate > 115200 else b""

def open(self):
def open(self, timeout: float | None = None):
self._ser.open()

def read_bytes(self, num) -> bytes:
Expand Down Expand Up @@ -108,20 +109,25 @@ def __str__(self) -> str:
class RttPort(SerialLike):
"""Segger RTT handling"""

def __init__(self, rtt_device):
def __init__(self, rtt_device: str, serial_number: str | None = None):
self._jlink = pylink.JLink()
self._name = rtt_device
self._modem_trace = None
self._serial_number = serial_number
self._modem_trace: BufferedWriter | None = None
self._modem_trace_buf = 0

def open(self):
self._jlink.open()
def open(self, timeout: float | None = None):
self._jlink.open(serial_no=self._serial_number)
self._jlink.set_tif(pylink.enums.JLinkInterfaces.SWD)
self._jlink.connect(self._name, 4000)
self._jlink.rtt_start()

end_time = time.time() + timeout if timeout else None

# Loop until JLink initialised properly
while True:
if end_time and time.time() > end_time:
raise TimeoutError("RTT port never initialised")
try:
num_up = self._jlink.rtt_get_num_up_buffers()
_num_down = self._jlink.rtt_get_num_down_buffers()
Expand Down Expand Up @@ -179,7 +185,7 @@ def __init__(self, target: str):
self._target = self._session.target
self._rtt = GenericRTTControlBlock(self._target)

def open(self):
def open(self, timeout: float | None = None):
self._session.open()
self._target.resume()
self._rtt.start()
Expand Down
11 changes: 2 additions & 9 deletions src/infuse_iot/tools/provision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,6 @@
from infuse_iot.util.soc import nrf, soc, stm


class ProvisioningStruct(ctypes.LittleEndianStructure):
_fields_ = [
("device_id", ctypes.c_uint64),
]
_pack_ = 1


class SubCommand(InfuseCommand):
NAME = "provision"
HELP = "Provision device on Infuse Cloud"
Expand Down Expand Up @@ -171,8 +164,8 @@ def run(self):
assert isinstance(response.parsed.device_id, str)
# Compare current flash contents to desired flash contents
cloud_id = int(response.parsed.device_id, 16)
current_bytes = interface.read_provisioned_data(ctypes.sizeof(ProvisioningStruct))
desired = ProvisioningStruct(cloud_id)
current_bytes = interface.read_provisioned_data(ctypes.sizeof(interface.DefaultProvisioningStruct))
desired = interface.DefaultProvisioningStruct(cloud_id)
desired_bytes = bytes(desired)

if current_bytes == desired_bytes:
Expand Down
20 changes: 12 additions & 8 deletions src/infuse_iot/util/soc/nrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,22 @@ def __init__(self, snr: int | None):
self.snr = snr
devices = self._exec(["device-info"])
if len(devices) == 0:
sys.exit()
raise RuntimeError("No devices found")
devices_info = devices[0]["devices"]

if len(devices_info) > 1:
serials = ",".join([d["serialNumber"] for d in devices_info])
sys.exit(f"Multiple devices found without a SNR provided (Found: {serials})")
self.snr = devices_info[0]["serialNumber"]
self.device_info = devices_info[0]["deviceInfo"]
if snr is None:
if len(devices_info) > 1:
serials = ",".join([d["serialNumber"] for d in devices_info])
raise RuntimeError(f"Multiple devices found without a SNR provided (Found: {serials})")
self.snr = int(devices_info[0]["serialNumber"])
infos = [info for info in devices_info if int(info["serialNumber"]) == self.snr]
if len(infos) == 0:
raise RuntimeError(f"Devices with SNR {self.snr} not found")
self.device_info = infos[0]["deviceInfo"]
self.core_info = self._exec(["core-info"])
self.family = DEVICE_FAMILY_MAPPING[self.device_info["jlink"]["deviceFamily"]]
self.uicr_base = self.core_info[0]["devices"][0]["uicrAddress"]
self._soc_name = self.family.soc(self.device_info)
self.uicr_base: int = self.core_info[0]["devices"][0]["uicrAddress"]
self._soc_name: str = self.family.soc(self.device_info)

def _exec(self, args: list[str]):
jout_all = []
Expand Down
7 changes: 7 additions & 0 deletions src/infuse_iot/util/soc/soc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
#!/usr/bin/env

import ctypes
from abc import ABCMeta, abstractmethod


class ProvisioningInterface(metaclass=ABCMeta):
"Generic SoC provisioning interface"

class DefaultProvisioningStruct(ctypes.LittleEndianStructure):
_fields_ = [
("device_id", ctypes.c_uint64),
]
_pack_ = 1

@property
@abstractmethod
def soc_name(self) -> str:
Expand Down