diff --git a/micropython/bluetooth/aioble/aioble/security.py b/micropython/bluetooth/aioble/aioble/security.py index 8e04d5b7b..acf38aad3 100644 --- a/micropython/bluetooth/aioble/aioble/security.py +++ b/micropython/bluetooth/aioble/aioble/security.py @@ -5,7 +5,7 @@ import asyncio import binascii import json - +from . import core from .core import log_info, log_warn, ble, register_irq_handler from .device import DeviceConnection @@ -26,27 +26,58 @@ _DEFAULT_PATH = "ble_secrets.json" +# Maintain list of known keys, newest at the bottom / end. _secrets = {} _modified = False _path = None +# If set, limit the pairing db to this many peers +limit_peers = None + +SEC_TYPES_SELF = (10, ) +SEC_TYPES_PEER = (1, 2, 3, 4) + # Must call this before stack startup. def load_secrets(path=None): - global _path, _secrets + global _path, _secrets, limit_peers # Use path if specified, otherwise use previous path, otherwise use # default path. _path = path or _path or _DEFAULT_PATH # Reset old secrets. - _secrets = {} + _secrets.clear() try: with open(_path, "r") as f: entries = json.load(f) + # Newest entries at at the end, load them first for sec_type, key, value in entries: + if sec_type not in _secrets: + _secrets[sec_type] = [] # Decode bytes from hex. - _secrets[sec_type, binascii.a2b_base64(key)] = binascii.a2b_base64(value) + _secrets[sec_type].append((binascii.a2b_base64(key), binascii.a2b_base64(value))) + + if limit_peers: + # If we need to limit loaded keys, ensure the same addresses of each type are loaded + keep_keys = None + for sec_type in SEC_TYPES_PEER: + if sec_type not in _secrets: + continue + secrets = _secrets[sec_type] + if len(secrets) > limit_peers: + if not keep_keys: + keep_keys = [key for key, _ in secrets[-limit_peers:]] + log_warn("Limiting keys to", keep_keys) + + keep_entries = [entry for entry in secrets if entry[0] in keep_keys] + while len(keep_entries) < limit_peers: + for entry in reversed(secrets): + if entry not in keep_entries: + keep_entries.append(entry) + _secrets[sec_type] = keep_entries + _log_peers("loaded") + except: log_warn("No secrets available") @@ -61,17 +92,48 @@ def _save_secrets(arg=None): # Only save if the secrets changed. return + _log_peers('save_secrets') + with open(_path, "w") as f: # Convert bytes to hex strings (otherwise JSON will treat them like # strings). json_secrets = [ (sec_type, binascii.b2a_base64(key), binascii.b2a_base64(value)) - for (sec_type, key), value in _secrets.items() + for sec_type in _secrets for key, value in _secrets[sec_type] ] json.dump(json_secrets, f) _modified = False +def _remove_entry(sec_type, key): + secrets = _secrets[sec_type] + + # Delete existing secrets matching the type and key. + deleted = False + for to_delete in [ + entry for entry in secrets if entry[0] == key + ]: + log_info("Removing existing secret matching key") + secrets.remove(to_delete) + deleted = True + + return deleted + + +def _log_peers(heading=""): + if core.log_level <= 2: + return + log_info("secrets:", heading) + for sec_type in SEC_TYPES_PEER: + log_info("-", sec_type) + + if sec_type not in _secrets: + continue + secrets = _secrets[sec_type] + for key, value in secrets: + log_info(" - %s: %s..." % (key, value[0:16])) + + def _security_irq(event, data): global _modified @@ -90,20 +152,43 @@ def _security_irq(event, data): elif event == _IRQ_SET_SECRET: sec_type, key, value = data - key = sec_type, bytes(key) + key = bytes(key) value = bytes(value) if value else None - log_info("set secret:", key, value) - - if value is None: - # Delete secret. - if key not in _secrets: - return False - - del _secrets[key] - else: - # Save secret. - _secrets[key] = value + is_saving = value is not None + is_deleting = not is_saving + + if core.log_level > 2: + if is_deleting: + log_info("del secret:", key) + else: + shortval = value + if len(value) > 16: + shortval = value[0:16] + b"..." + log_info("set secret:", sec_type, key, shortval) + + if sec_type not in _secrets: + _secrets[sec_type] = [] + secrets = _secrets[sec_type] + + # Delete existing secrets matching the type and key. + removed = _remove_entry(sec_type, key) + + if is_deleting and not removed: + # Delete mode, but no entries were deleted + return False + + if is_saving: + # Save new secret. + if limit_peers and sec_type in SEC_TYPES_PEER and len(secrets) >= limit_peers: + addr, _ = secrets[0] + log_warn("Removing old peer to make space for new one") + ble.gap_unpair(addr) + log_info("Removed:", addr) + # Add new value to database + secrets.append((key, value)) + + _log_peers("set_secret") # Queue up a save (don't synchronously write to flash). _modified = True @@ -116,19 +201,23 @@ def _security_irq(event, data): log_info("get secret:", sec_type, index, bytes(key) if key else None) + secrets = _secrets.get(sec_type, []) if key is None: # Return the index'th secret of this type. - i = 0 - for (t, _key), value in _secrets.items(): - if t == sec_type: - if i == index: - return value - i += 1 + # This is used when loading "all" secrets at startup + if len(secrets) > index: + key, val = secrets[index] + return val + return None else: # Return the secret for this key (or None). - key = sec_type, bytes(key) - return _secrets.get(key, None) + key = bytes(key) + + for k, v in secrets: + if k == key: + return v + return None elif event == _IRQ_PASSKEY_ACTION: conn_handle, action, passkey = data