import json
import os
import tempfile
import threading
from typing import Optional


class KeyStore:
    """Thread-safe per-device master key storage backed by a JSON file.

    File format (keys.json)::

        {
            "device_id_1": "hex_encoded_32_byte_master_key",
            "device_id_2": "hex_encoded_32_byte_master_key",
            ...
        }

    Every access to the in-memory map *and* to the backing file happens under
    one non-reentrant ``threading.Lock``. Public methods acquire it; the
    ``_*_locked`` helpers assume the caller already holds it. That lets the
    mutating methods update memory and persist in a single critical section
    without re-acquiring the lock, which would deadlock.
    """

    KEY_SIZE = 32  # required master key length in bytes

    def __init__(self, path: str = "keys.json"):
        self.path = os.path.abspath(path)
        self._keys: dict[str, bytes] = {}
        self._lock = threading.Lock()
        self.load()

    def load(self) -> None:
        """Reload keys from the JSON file, replacing the in-memory map.

        If the file does not exist, the current keys are left untouched and no
        file is created. The file is rejected as a whole if it cannot be read
        or parsed, is not a JSON object, or holds a value that is not a hex
        string of a 32-byte key. In that case a warning is printed and the
        current keys are kept.
        """
        with self._lock:
            self._load_locked()

    def _load_locked(self) -> None:
        """Implementation of :meth:`load`; the caller must hold ``self._lock``."""
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                data = json.load(f)
            keys = self._parse(data)
        except FileNotFoundError:
            # Nothing on disk yet: keep whatever is in memory.
            return
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError, UnicodeDecodeError, bad hex,
            # and the structural checks raised by _parse.
            print(f"[KeyStore] Warning: failed to load {self.path}: {e}")
            return
        # Assign only after the whole file parsed, so a bad file never
        # leaves a half-updated map behind.
        self._keys = keys

    @classmethod
    def _parse(cls, data: object) -> dict[str, bytes]:
        """Convert decoded JSON into a device-id -> key map.

        Raises:
            ValueError: if ``data`` is not an object, or any value is not a hex
                string encoding exactly ``KEY_SIZE`` bytes.
        """
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        keys: dict[str, bytes] = {}
        for did, hex_key in data.items():
            if not isinstance(hex_key, str):
                raise ValueError(f"key for {did!r} is not a hex string")
            key = bytes.fromhex(hex_key)
            if len(key) != cls.KEY_SIZE:
                raise ValueError(
                    f"key for {did!r} must be {cls.KEY_SIZE} bytes, got {len(key)}"
                )
            keys[did] = key
        return keys

    def save(self) -> None:
        """Atomically persist the current keys to the JSON file."""
        with self._lock:
            self._save_locked()

    def _save_locked(self) -> None:
        """Implementation of :meth:`save`; the caller must hold ``self._lock``.

        Writes to a temporary file in the same directory, fsyncs it, then
        renames it over the target. Readers therefore see either the old file
        or the new one, never a truncated or partial write.
        """
        data = {did: key.hex() for did, key in self._keys.items()}
        # Resolve symlinks so a symlinked keys file is updated at its real
        # location instead of being replaced by a regular file.
        target = os.path.realpath(self.path)
        # mkstemp creates the file with mode 0o600, so master keys are not
        # left readable by other users (open(..., "w") would honour the umask).
        fd, tmp_path = tempfile.mkstemp(
            dir=os.path.dirname(target),
            prefix=os.path.basename(target) + ".",
            suffix=".tmp",
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, target)
        except BaseException:
            # Best-effort cleanup of the temp file; re-raise the original error.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def get(self, device_id: str) -> Optional[bytes]:
        """Return the 32-byte master key for a device, or None if unknown.

        On a miss, keys are reloaded from disk once before giving up
        (hot-reload support).
        """
        with self._lock:
            key = self._keys.get(device_id)
            if key is None:
                # Key not in memory: deploy.py may have added it to the file.
                # NOTE(review): every miss rereads the file, so a flood of
                # unknown device IDs means repeated disk reads.
                self._load_locked()
                key = self._keys.get(device_id)
        return key

    def add(self, device_id: str, master_key: bytes) -> None:
        """Register a device's master key and persist it.

        Raises:
            ValueError: if ``master_key`` is not exactly 32 bytes long.
            OSError: if the file cannot be written. The in-memory map has
                already been updated by then.
        """
        if len(master_key) != self.KEY_SIZE:
            raise ValueError(f"master_key must be 32 bytes, got {len(master_key)}")
        with self._lock:
            # Store an immutable copy so later mutation of a caller's
            # bytearray cannot change the stored key.
            self._keys[device_id] = bytes(master_key)
            self._save_locked()

    def remove(self, device_id: str) -> bool:
        """Remove a device's key and persist. Returns True if it existed."""
        with self._lock:
            if device_id not in self._keys:
                return False
            del self._keys[device_id]
            # Use the locked helper: calling save() here would try to
            # re-acquire the non-reentrant lock and deadlock.
            self._save_locked()
        return True

    def list_devices(self) -> list[str]:
        """Return list of known device IDs."""
        with self._lock:
            return list(self._keys.keys())

    def __len__(self) -> int:
        with self._lock:
            return len(self._keys)

    def __contains__(self, device_id: str) -> bool:
        with self._lock:
            return device_id in self._keys