Skip to content

Add an aiomqtt based MQTT session module #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ paho-mqtt = ">=1.6.1,<3.0.0"
construct = "^2.10.57"
vacuum-map-parser-roborock = "*"
pyrate-limiter = "^3.7.0"
aiomqtt = "^2.3.2"


[build-system]
Expand Down
7 changes: 7 additions & 0 deletions roborock/mqtt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""This module contains the low level MQTT client for the Roborock vacuum cleaner.

This is not meant to be used directly, but rather as a base for the higher level
modules.
"""

__all__: list[str] = []
234 changes: 234 additions & 0 deletions roborock/mqtt/roborock_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
"""An MQTT session for sending and receiving messages.

See create_mqtt_session for a factory function to create an MQTT session.

This is a thin wrapper around the async MQTT client that handles dispatching messages
from a topic to a callback function, since the async MQTT client does not
support this out of the box. It also handles the authentication process and
receiving messages from the vacuum cleaner.
"""

import asyncio
import datetime
import logging
from collections.abc import Callable
from contextlib import asynccontextmanager

import aiomqtt
from aiomqtt import MqttError, TLSParameters

from .session import MqttParams, MqttSession, MqttSessionException

_LOGGER = logging.getLogger(__name__)
_MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt")

KEEPALIVE = 60

# Exponential backoff parameters
MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10)
MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30)
BACKOFF_MULTIPLIER = 1.5


class RoborockMqttSession(MqttSession):
"""An MQTT session for sending and receiving messages.

You can start a session invoking the start() method which will connect to
the MQTT broker. A caller may subscribe to a topic, and the session keeps
track of which callbacks to invoke for each topic.

The client is run as a background task that will run until shutdown. Once
connected, the client will wait for messages to be received in a loop. If
the connection is lost, the client will be re-created and reconnected. There
is backoff to avoid spamming the broker with connection attempts. The client
will automatically re-establish any subscriptions when the connection is
re-established.
"""

def __init__(self, params: MqttParams):
self._params = params
self._background_task: asyncio.Task[None] | None = None
self._healthy = False
self._backoff = MIN_BACKOFF_INTERVAL
self._client: aiomqtt.Client | None = None
self._client_lock = asyncio.Lock()
self._listeners: dict[str, list[Callable[[bytes], None]]] = {}

@property
def connected(self) -> bool:
"""True if the session is connected to the broker."""
return self._healthy

async def start(self) -> None:
"""Start the MQTT session.

This has special behavior for the first connection attempt where any
failures are raised immediately. This is to allow the caller to
handle the failure and retry if desired itself. Once connected,
the session will retry connecting in the background.
"""
start_future: asyncio.Future[None] = asyncio.Future()
loop = asyncio.get_event_loop()
self._background_task = loop.create_task(self._run_task(start_future))
try:
await start_future
except MqttError as err:
raise MqttSessionException(f"Error starting MQTT session: {err}") from err
except Exception as err:
raise MqttSessionException(f"Unexpected error starting session: {err}") from err
else:
_LOGGER.debug("MQTT session started successfully")

async def close(self) -> None:
"""Cancels the MQTT loop and shutdown the client library."""
if self._background_task:
self._background_task.cancel()
try:
await self._background_task
except asyncio.CancelledError:
pass
async with self._client_lock:
if self._client:
await self._client.close()

self._healthy = False

async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
"""Run the MQTT loop."""
_LOGGER.info("Starting MQTT session")
while True:
try:
async with self._mqtt_client(self._params) as client:
# Reset backoff once we've successfully connected
self._backoff = MIN_BACKOFF_INTERVAL
self._healthy = True
if start_future:
start_future.set_result(None)
start_future = None

await self._process_message_loop(client)

except MqttError as err:
if start_future:
_LOGGER.info("MQTT error starting session: %s", err)
start_future.set_exception(err)
return
_LOGGER.info("MQTT error: %s", err)
except asyncio.CancelledError as err:
if start_future:
_LOGGER.debug("MQTT loop was cancelled")
start_future.set_exception(err)
_LOGGER.debug("MQTT loop was cancelled whiel starting")
return
# Catch exceptions to avoid crashing the loop
# and to allow the loop to retry.
except Exception as err:
# This error is thrown when the MQTT loop is cancelled
# and the generator is not stopped.
if "generator didn't stop" in str(err):
_LOGGER.debug("MQTT loop was cancelled")
return
if start_future:
_LOGGER.error("Uncaught error starting MQTT session: %s", err)
start_future.set_exception(err)
return
_LOGGER.error("Uncaught error during MQTT session: %s", err)

self._healthy = False
_LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds())
await asyncio.sleep(self._backoff.total_seconds())
self._backoff = min(self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL)

@asynccontextmanager
async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
"""Connect to the MQTT broker and listen for messages."""
_LOGGER.debug("Connecting to %s:%s for %s", params.host, params.port, params.username)
try:
async with aiomqtt.Client(
hostname=params.host,
port=params.port,
username=params.username,
password=params.password,
keepalive=KEEPALIVE,
protocol=aiomqtt.ProtocolVersion.V5,
tls_params=TLSParameters() if params.tls else None,
timeout=params.timeout,
logger=_MQTT_LOGGER,
) as client:
_LOGGER.debug("Connected to MQTT broker")
# Re-establish any existing subscriptions
async with self._client_lock:
self._client = client
for topic in self._listeners:
_LOGGER.debug("Re-establising subscription to topic %s", topic)
# TODO: If this fails it will break the whole connection. Make
# this retry again in the background with backoff.
await client.subscribe(topic)

yield client
finally:
async with self._client_lock:
self._client = None

async def _process_message_loop(self, client: aiomqtt.Client) -> None:
_LOGGER.debug("client=%s", client)
_LOGGER.debug("Processing MQTT messages: %s", client.messages)
async for message in client.messages:
_LOGGER.debug("Received message: %s", message)
for listener in self._listeners.get(message.topic.value, []):
try:
listener(message.payload)
except asyncio.CancelledError:
raise
except Exception as e:
_LOGGER.error("Uncaught exception in subscriber callback: %s", e)

async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
"""Subscribe to messages on the specified topic and invoke the callback for new messages.

The callback will be called with the message payload as a bytes object. The callback
should not block since it runs in the async loop. It should not raise any exceptions.

The returned callable unsubscribes from the topic when called.
"""
_LOGGER.debug("Subscribing to topic %s", topic)
if topic not in self._listeners:
self._listeners[topic] = []
self._listeners[topic].append(callback)

async with self._client_lock:
if self._client:
_LOGGER.debug("Establishing subscription to topic %s", topic)
try:
await self._client.subscribe(topic)
except MqttError as err:
raise MqttSessionException(f"Error subscribing to topic: {err}") from err
else:
_LOGGER.debug("Client not connected, will establish subscription later")

return lambda: self._listeners[topic].remove(callback)

async def publish(self, topic: str, message: bytes) -> None:
"""Publish a message on the topic."""
_LOGGER.debug("Sending message to topic %s: %s", topic, message)
client: aiomqtt.Client
async with self._client_lock:
if self._client is None:
raise MqttSessionException("Could not publish message, MQTT client not connected")
client = self._client
try:
await client.publish(topic, message)
except MqttError as err:
raise MqttSessionException(f"Error publishing message: {err}") from err


async def create_mqtt_session(params: MqttParams) -> MqttSession:
"""Create an MQTT session.

This function is a factory for creating an MQTT session. This will
raise an exception if initial attempt to connect fails. Once connected,
the session will retry connecting on failure in the background.
"""
session = RoborockMqttSession(params)
await session.start()
return session
63 changes: 63 additions & 0 deletions roborock/mqtt/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""An MQTT session for sending and receiving messages."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass

from roborock.exceptions import RoborockException

DEFAULT_TIMEOUT = 30.0


@dataclass
class MqttParams:
"""MQTT parameters for the connection."""

host: str
"""MQTT host to connect to."""

port: int
"""MQTT port to connect to."""

tls: bool
"""Use TLS for the connection."""

username: str
"""MQTT username to use for authentication."""

password: str
"""MQTT password to use for authentication."""

timeout: float = DEFAULT_TIMEOUT
"""Timeout for communications with the broker in seconds."""


class MqttSession(ABC):
"""An MQTT session for sending and receiving messages."""

@property
@abstractmethod
def connected(self) -> bool:
"""True if the session is connected to the broker."""

@abstractmethod
async def subscribe(self, device_id: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
"""Invoke the callback when messages are received on the topic.

The returned callable unsubscribes from the topic when called.
"""

@abstractmethod
async def publish(self, topic: str, message: bytes) -> None:
"""Publish a message on the specified topic.

This will raise an exception if the message could not be sent.
"""

@abstractmethod
async def close(self) -> None:
"""Cancels the mqtt loop"""


class MqttSessionException(RoborockException):
""" "Raised when there is an error communicating with MQTT."""
17 changes: 13 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ class FakeSocketHandler:
handle request callback handles the incoming requests and prepares the responses.
"""

def __init__(self, handle_request: RequestHandler) -> None:
def __init__(self, handle_request: RequestHandler, response_queue: Queue[bytes]) -> None:
self.response_buf = io.BytesIO()
self.handle_request = handle_request
self.response_queue = response_queue

def pending(self) -> int:
"""Return the number of bytes in the response buffer."""
Expand All @@ -62,9 +63,17 @@ def handle_socket_send(self, client_request: bytes) -> int:
# The buffer will be emptied when the client calls recv() on the socket
_LOGGER.debug("Queued: 0x%s", response.hex())
self.response_buf.write(response)

return len(client_request)

def push_response(self) -> None:
"""Push a response to the client."""
if not self.response_queue.empty():
response = self.response_queue.get()
# Enqueue a response to be sent back to the client in the buffer.
# The buffer will be emptied when the client calls recv() on the socket
_LOGGER.debug("Queued: 0x%s", response.hex())
self.response_buf.write(response)


@pytest.fixture(name="received_requests")
def received_requests_fixture() -> Queue[bytes]:
Expand Down Expand Up @@ -97,9 +106,9 @@ def handle_request(client_request: bytes) -> bytes | None:


@pytest.fixture(name="fake_socket_handler")
def fake_socket_handler_fixture(request_handler: RequestHandler) -> FakeSocketHandler:
def fake_socket_handler_fixture(request_handler: RequestHandler, response_queue: Queue[bytes]) -> FakeSocketHandler:
"""Fixture that creates a fake MQTT broker."""
return FakeSocketHandler(request_handler)
return FakeSocketHandler(request_handler, response_queue)


@pytest.fixture(name="mock_sock")
Expand Down
Loading