Skip to content

Commit 9ccb258

Browse files
committed
feat: Add an aiomqtt based MQTT session module
1 parent ba422aa commit 9ccb258

File tree

8 files changed

+442
-8
lines changed

8 files changed

+442
-8
lines changed

poetry.lock

Lines changed: 16 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ paho-mqtt = ">=1.6.1,<3.0.0"
3131
construct = "^2.10.57"
3232
vacuum-map-parser-roborock = "*"
3333
pyrate-limiter = "^3.7.0"
34+
aiomqtt = "^2.3.2"
3435

3536

3637
[build-system]

roborock/mqtt/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""This module contains the low level MQTT client for the Roborock vacuum cleaner.
2+
3+
This is not meant to be used directly, but rather as a base for the higher level
4+
modules.
5+
"""
6+
7+
__all__: list[str] = []

roborock/mqtt/roborock_session.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""An MQTT session for sending and receiving messages.
2+
3+
See create_mqtt_session for a factory function to create an MQTT session.
4+
5+
This is a thin wrapper around the async MQTT client that handles dispatching messages
6+
from a topic to a callback function, since the async MQTT client does not
7+
support this out of the box. It also handles the authentication process and
8+
receiving messages from the vacuum cleaner.
9+
"""
10+
11+
import asyncio
12+
import datetime
13+
import logging
14+
from collections.abc import Callable
15+
from contextlib import asynccontextmanager
16+
17+
import aiomqtt
18+
from aiomqtt import MqttError, TLSParameters
19+
20+
from .. import RoborockException
21+
from .session import MqttParams, MqttSession
22+
23+
_LOGGER = logging.getLogger(__name__)
24+
_MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt")
25+
26+
KEEPALIVE = 60
27+
28+
# Exponential backoff parameters
29+
MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10)
30+
MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30)
31+
BACKOFF_MULTIPLIER = 1.5
32+
33+
34+
class RoorockMqttSession(MqttSession):
35+
"""An MQTT session for sending and receiving messages.
36+
37+
You can start a session invoking the start() method which will connect to
38+
the MQTT broker. A caller may subscribe to a topic, and the session keeps
39+
track of which callbacks to invoke for each topic.
40+
41+
The client is run as a background task that will run until shutdown. Once
42+
connected, the client will wait for messages to be received in a loop. If
43+
the connection is lost, the client will be re-created and reconnected. There
44+
is backoff to avoid spamming the broker with connection attempts. The client
45+
will automatically re-establish any subscriptions when the connection is
46+
re-established.
47+
"""
48+
49+
def __init__(self, params: MqttParams):
50+
self._params = params
51+
self._background_task: asyncio.Task[None] | None = None
52+
self._healthy = False
53+
self._backoff = MIN_BACKOFF_INTERVAL
54+
self._client: aiomqtt.Client | None = None
55+
self._client_lock = asyncio.Lock()
56+
self._listeners: dict[str, list[Callable[[bytes], None]]] = {}
57+
58+
@property
59+
def connected(self) -> bool:
60+
"""True if the session is connected to the broker."""
61+
return self._healthy
62+
63+
async def start(self) -> None:
64+
"""Start the MQTT session.
65+
66+
This has special behavior for the first connection attempt where any
67+
failures are raised immediately. This is to allow the caller to
68+
handle the failure and retry if desired itself. Once connected,
69+
the session will retry connecting in the background.
70+
"""
71+
start_future: asyncio.Future[None] = asyncio.Future()
72+
loop = asyncio.get_event_loop()
73+
self._background_task = loop.create_task(self._run_task(start_future))
74+
await start_future
75+
76+
async def close(self) -> None:
77+
"""Cancels the mqtt loop and shutdown the client library."""
78+
if self._background_task:
79+
self._background_task.cancel()
80+
try:
81+
await self._background_task
82+
except asyncio.CancelledError:
83+
pass
84+
async with self._client_lock:
85+
if self._client:
86+
await self._client.close()
87+
88+
self._healthy = False
89+
90+
async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
91+
"""Run the MQTT loop."""
92+
_LOGGER.info("Starting MQTT session")
93+
while True:
94+
try:
95+
async with self._mqtt_client(self._params) as client:
96+
# Reset backoff once we've successfully connected
97+
self._backoff = MIN_BACKOFF_INTERVAL
98+
self._healthy = True
99+
if start_future:
100+
start_future.set_result(None)
101+
start_future = None
102+
103+
await self._process_message_loop(client)
104+
105+
except asyncio.CancelledError:
106+
_LOGGER.debug("MQTT loop was cancelled")
107+
return
108+
except MqttError as err:
109+
_LOGGER.info("MQTT error: %s", err)
110+
if start_future:
111+
start_future.set_exception(err)
112+
return
113+
# Catch exceptions to avoid crashing the loop
114+
# and to allow the loop to retry.
115+
except Exception as err:
116+
# This error is thrown when the MQTT loop is cancelled
117+
# and the generator is not stopped.
118+
if "generator didn't stop" in str(err):
119+
_LOGGER.debug("MQTT loop was cancelled")
120+
return
121+
_LOGGER.error("Uncaught error in MQTT session: %s", err)
122+
if start_future:
123+
start_future.set_exception(err)
124+
return
125+
126+
self._healthy = False
127+
_LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds())
128+
await asyncio.sleep(self._backoff.total_seconds())
129+
self._backoff = min(self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL)
130+
131+
@asynccontextmanager
132+
async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
133+
"""Connect to the MQTT broker and listen for messages."""
134+
_LOGGER.debug("Connecting to %s:%s for %s", params.host, params.port, params.username)
135+
try:
136+
async with aiomqtt.Client(
137+
hostname=params.host,
138+
port=params.port,
139+
username=params.username,
140+
password=params.password,
141+
keepalive=KEEPALIVE,
142+
protocol=aiomqtt.ProtocolVersion.V5,
143+
tls_params=TLSParameters() if params.tls else None,
144+
timeout=params.timeout,
145+
logger=_MQTT_LOGGER,
146+
) as client:
147+
_LOGGER.debug("Connected to MQTT broker")
148+
# Re-establish any existing subscriptions
149+
async with self._client_lock:
150+
self._client = client
151+
for topic in self._listeners:
152+
_LOGGER.debug("Re-establising subscription to topic %s", topic)
153+
await client.subscribe(topic)
154+
155+
yield client
156+
finally:
157+
async with self._client_lock:
158+
self._client = None
159+
160+
async def _process_message_loop(self, client: aiomqtt.Client) -> None:
161+
_LOGGER.debug("Processing mqtt messages")
162+
async for message in client.messages:
163+
_LOGGER.debug("Received message: %s", message)
164+
for listener in self._listeners.get(message.topic.value) or []:
165+
try:
166+
listener(message.payload)
167+
except asyncio.CancelledError:
168+
raise
169+
except Exception as e:
170+
_LOGGER.error("Uncaught exception in subscriber callback: %s", e)
171+
172+
async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
173+
"""Subscribe to messages on the specified topic and invoke the callback for new messages.
174+
175+
The callback will be called with the message payload as a bytes object. The callback
176+
should not block since it runs in the async loop. It should not raise any exceptions.
177+
178+
The returned callable unsubscribes from the topic when called.
179+
"""
180+
_LOGGER.debug("Subscribing to topic %s", topic)
181+
if topic not in self._listeners:
182+
self._listeners[topic] = []
183+
self._listeners[topic].append(callback)
184+
185+
async with self._client_lock:
186+
if self._client:
187+
_LOGGER.debug("Establishing subscription to topic %s", topic)
188+
await self._client.subscribe(topic)
189+
else:
190+
_LOGGER.debug("Client not connected, will establish subscription later")
191+
192+
return lambda: self._listeners[topic].remove(callback)
193+
194+
async def publish(self, topic: str, message: bytes) -> None:
195+
"""Publish a message on the topic."""
196+
_LOGGER.debug("Sending message to topic %s: %s", topic, message)
197+
async with self._client_lock:
198+
if not self._client:
199+
raise RoborockException("MQTT client not connected")
200+
coro = self._client.publish(topic, message)
201+
await coro
202+
203+
204+
async def create_mqtt_session(params: MqttParams) -> MqttSession:
205+
"""Create an MQTT session.
206+
207+
This function is a factory for creating an MQTT session. This will
208+
raise an exception if initial attempt to connect fails. Once connected,
209+
the session will retry connecting on failure in the background.
210+
"""
211+
session = RoorockMqttSession(params)
212+
await session.start()
213+
return session

roborock/mqtt/session.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""An MQTT session for sending and receiving messages."""
2+
3+
from abc import ABC, abstractmethod
4+
from collections.abc import Callable
5+
from dataclasses import dataclass
6+
7+
DEFAULT_TIMEOUT = 30.0
8+
9+
10+
@dataclass
11+
class MqttParams:
12+
"""MQTT parameters for the connection."""
13+
14+
host: str
15+
"""MQTT host to connect to."""
16+
17+
port: int
18+
"""MQTT port to connect to."""
19+
20+
tls: bool
21+
"""Use TLS for the connection."""
22+
23+
username: str
24+
"""MQTT username to use for authentication."""
25+
26+
password: str
27+
"""MQTT password to use for authentication."""
28+
29+
timeout: float = DEFAULT_TIMEOUT
30+
"""Timeout for communications with the broker in seconds."""
31+
32+
33+
class MqttSession(ABC):
34+
"""An MQTT session for sending and receiving messages."""
35+
36+
@property
37+
@abstractmethod
38+
def connected(self) -> bool:
39+
"""True if the session is connected to the broker."""
40+
41+
@abstractmethod
42+
async def subscribe(self, device_id: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
43+
"""Invoke the callback when messages are received on the topic.
44+
45+
The returned callable unsubscribes from the topic when called.
46+
"""
47+
48+
@abstractmethod
49+
async def publish(self, topic: str, message: bytes) -> None:
50+
"""Publish a message on the specified topic.
51+
52+
This will raise an exception if the message could not be sent.
53+
"""
54+
55+
@abstractmethod
56+
async def close(self) -> None:
57+
"""Cancels the mqtt loop"""

tests/conftest.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ class FakeSocketHandler:
3333
handle request callback handles the incoming requests and prepares the responses.
3434
"""
3535

36-
def __init__(self, handle_request: RequestHandler) -> None:
36+
def __init__(self, handle_request: RequestHandler, response_queue: Queue[bytes]) -> None:
3737
self.response_buf = io.BytesIO()
3838
self.handle_request = handle_request
39+
self.response_queue = response_queue
3940

4041
def pending(self) -> int:
4142
"""Return the number of bytes in the response buffer."""
@@ -62,9 +63,17 @@ def handle_socket_send(self, client_request: bytes) -> int:
6263
# The buffer will be emptied when the client calls recv() on the socket
6364
_LOGGER.debug("Queued: 0x%s", response.hex())
6465
self.response_buf.write(response)
65-
6666
return len(client_request)
6767

68+
def push_response(self) -> None:
69+
"""Push a response to the client."""
70+
if not self.response_queue.empty():
71+
response = self.response_queue.get()
72+
# Enqueue a response to be sent back to the client in the buffer.
73+
# The buffer will be emptied when the client calls recv() on the socket
74+
_LOGGER.debug("Queued: 0x%s", response.hex())
75+
self.response_buf.write(response)
76+
6877

6978
@pytest.fixture(name="received_requests")
7079
def received_requests_fixture() -> Queue[bytes]:
@@ -97,9 +106,9 @@ def handle_request(client_request: bytes) -> bytes | None:
97106

98107

99108
@pytest.fixture(name="fake_socket_handler")
100-
def fake_socket_handler_fixture(request_handler: RequestHandler) -> FakeSocketHandler:
109+
def fake_socket_handler_fixture(request_handler: RequestHandler, response_queue: Queue[bytes]) -> FakeSocketHandler:
101110
"""Fixture that creates a fake MQTT broker."""
102-
return FakeSocketHandler(request_handler)
111+
return FakeSocketHandler(request_handler, response_queue)
103112

104113

105114
@pytest.fixture(name="mock_sock")
@@ -109,6 +118,7 @@ def mock_sock_fixture(fake_socket_handler: FakeSocketHandler) -> Mock:
109118
mock_sock.recv = fake_socket_handler.handle_socket_recv
110119
mock_sock.send = fake_socket_handler.handle_socket_send
111120
mock_sock.pending = fake_socket_handler.pending
121+
mock_sock.fileno = lambda: 1
112122
return mock_sock
113123

114124

0 commit comments

Comments
 (0)