From 22b9a3f24add6642f75c0ccc54ecbc39b90ab819 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 12 Apr 2025 16:43:21 -0700 Subject: [PATCH 1/2] feat: Add an aiomqtt based MQTT session module --- poetry.lock | 18 ++- pyproject.toml | 1 + roborock/mqtt/__init__.py | 7 + roborock/mqtt/roborock_session.py | 213 ++++++++++++++++++++++++++++ roborock/mqtt/session.py | 57 ++++++++ tests/conftest.py | 17 ++- tests/mqtt/test_roborock_session.py | 142 +++++++++++++++++++ tests/mqtt_packet.py | 4 +- 8 files changed, 451 insertions(+), 8 deletions(-) create mode 100644 roborock/mqtt/__init__.py create mode 100644 roborock/mqtt/roborock_session.py create mode 100644 roborock/mqtt/session.py create mode 100644 tests/mqtt/test_roborock_session.py diff --git a/poetry.lock b/poetry.lock index bc9b1278..0d79c176 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -113,6 +113,20 @@ yarl = ">=1.17.0,<2.0" [package.extras] speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] +[[package]] +name = "aiomqtt" +version = "2.3.2" +description = "The idiomatic asyncio MQTT client" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "aiomqtt-2.3.2-py3-none-any.whl", hash = "sha256:e67f877454b04437732a7eb005d8f8751df1ba7931b2eb1b7a7d8bf7e50c96e7"}, + {file = "aiomqtt-2.3.2.tar.gz", hash = "sha256:96d979aeac930f031b0efa4c8e71ab337b3b330cf175b9329905419a38d8509c"}, +] + +[package.dependencies] +paho-mqtt = ">=2.1.0,<3.0.0" + [[package]] name = "aioresponses" version = "0.7.8" @@ -1476,4 +1490,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "6b418ac16ea0c5d153d97187f3d5330eeb43d914e3001a6ff58a05978880b8ab" +content-hash = "539acd1831188994429cea11425e53ec1b851f17f7f7e92ed466865190fa80c4" diff --git a/pyproject.toml b/pyproject.toml index a64f5d69..cc19720a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ paho-mqtt = ">=1.6.1,<3.0.0" construct = "^2.10.57" vacuum-map-parser-roborock = "*" pyrate-limiter = "^3.7.0" +aiomqtt = "^2.3.2" [build-system] diff --git a/roborock/mqtt/__init__.py b/roborock/mqtt/__init__.py new file mode 100644 index 00000000..b9fb6d25 --- /dev/null +++ b/roborock/mqtt/__init__.py @@ -0,0 +1,7 @@ +"""This module contains the low level MQTT client for the Roborock vacuum cleaner. + +This is not meant to be used directly, but rather as a base for the higher level +modules. +""" + +__all__: list[str] = [] diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py new file mode 100644 index 00000000..6e77e575 --- /dev/null +++ b/roborock/mqtt/roborock_session.py @@ -0,0 +1,213 @@ +"""An MQTT session for sending and receiving messages. + +See create_mqtt_session for a factory function to create an MQTT session. + +This is a thin wrapper around the async MQTT client that handles dispatching messages +from a topic to a callback function, since the async MQTT client does not +support this out of the box. It also handles the authentication process and +receiving messages from the vacuum cleaner. +""" + +import asyncio +import datetime +import logging +from collections.abc import Callable +from contextlib import asynccontextmanager + +import aiomqtt +from aiomqtt import MqttError, TLSParameters + +from .. import RoborockException +from .session import MqttParams, MqttSession + +_LOGGER = logging.getLogger(__name__) +_MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt") + +KEEPALIVE = 60 + +# Exponential backoff parameters +MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10) +MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30) +BACKOFF_MULTIPLIER = 1.5 + + +class RoborockMqttSession(MqttSession): + """An MQTT session for sending and receiving messages. + + You can start a session invoking the start() method which will connect to + the MQTT broker. A caller may subscribe to a topic, and the session keeps + track of which callbacks to invoke for each topic. + + The client is run as a background task that will run until shutdown. Once + connected, the client will wait for messages to be received in a loop. If + the connection is lost, the client will be re-created and reconnected. There + is backoff to avoid spamming the broker with connection attempts. The client + will automatically re-establish any subscriptions when the connection is + re-established. + """ + + def __init__(self, params: MqttParams): + self._params = params + self._background_task: asyncio.Task[None] | None = None + self._healthy = False + self._backoff = MIN_BACKOFF_INTERVAL + self._client: aiomqtt.Client | None = None + self._client_lock = asyncio.Lock() + self._listeners: dict[str, list[Callable[[bytes], None]]] = {} + + @property + def connected(self) -> bool: + """True if the session is connected to the broker.""" + return self._healthy + + async def start(self) -> None: + """Start the MQTT session. + + This has special behavior for the first connection attempt where any + failures are raised immediately. This is to allow the caller to + handle the failure and retry if desired itself. Once connected, + the session will retry connecting in the background. + """ + start_future: asyncio.Future[None] = asyncio.Future() + loop = asyncio.get_event_loop() + self._background_task = loop.create_task(self._run_task(start_future)) + await start_future + + async def close(self) -> None: + """Cancels the MQTT loop and shutdown the client library.""" + if self._background_task: + self._background_task.cancel() + try: + await self._background_task + except asyncio.CancelledError: + pass + async with self._client_lock: + if self._client: + await self._client.close() + + self._healthy = False + + async def _run_task(self, start_future: asyncio.Future[None] | None) -> None: + """Run the MQTT loop.""" + _LOGGER.info("Starting MQTT session") + while True: + try: + async with self._mqtt_client(self._params) as client: + # Reset backoff once we've successfully connected + self._backoff = MIN_BACKOFF_INTERVAL + self._healthy = True + if start_future: + start_future.set_result(None) + start_future = None + + await self._process_message_loop(client) + + except asyncio.CancelledError: + _LOGGER.debug("MQTT loop was cancelled") + return + except MqttError as err: + _LOGGER.info("MQTT error: %s", err) + if start_future: + start_future.set_exception(err) + return + # Catch exceptions to avoid crashing the loop + # and to allow the loop to retry. + except Exception as err: + # This error is thrown when the MQTT loop is cancelled + # and the generator is not stopped. + if "generator didn't stop" in str(err): + _LOGGER.debug("MQTT loop was cancelled") + return + _LOGGER.error("Uncaught error in MQTT session: %s", err) + if start_future: + start_future.set_exception(err) + return + + self._healthy = False + _LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds()) + await asyncio.sleep(self._backoff.total_seconds()) + self._backoff = min(self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL) + + @asynccontextmanager + async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client: + """Connect to the MQTT broker and listen for messages.""" + _LOGGER.debug("Connecting to %s:%s for %s", params.host, params.port, params.username) + try: + async with aiomqtt.Client( + hostname=params.host, + port=params.port, + username=params.username, + password=params.password, + keepalive=KEEPALIVE, + protocol=aiomqtt.ProtocolVersion.V5, + tls_params=TLSParameters() if params.tls else None, + timeout=params.timeout, + logger=_MQTT_LOGGER, + ) as client: + _LOGGER.debug("Connected to MQTT broker") + # Re-establish any existing subscriptions + async with self._client_lock: + self._client = client + for topic in self._listeners: + _LOGGER.debug("Re-establising subscription to topic %s", topic) + await client.subscribe(topic) + + yield client + finally: + async with self._client_lock: + self._client = None + + async def _process_message_loop(self, client: aiomqtt.Client) -> None: + _LOGGER.debug("Processing MQTT messages") + async for message in client.messages: + _LOGGER.debug("Received message: %s", message) + for listener in self._listeners.get(message.topic.value) or []: + try: + listener(message.payload) + except asyncio.CancelledError: + raise + except Exception as e: + _LOGGER.error("Uncaught exception in subscriber callback: %s", e) + + async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]: + """Subscribe to messages on the specified topic and invoke the callback for new messages. + + The callback will be called with the message payload as a bytes object. The callback + should not block since it runs in the async loop. It should not raise any exceptions. + + The returned callable unsubscribes from the topic when called. + """ + _LOGGER.debug("Subscribing to topic %s", topic) + if topic not in self._listeners: + self._listeners[topic] = [] + self._listeners[topic].append(callback) + + async with self._client_lock: + if self._client: + _LOGGER.debug("Establishing subscription to topic %s", topic) + await self._client.subscribe(topic) + else: + _LOGGER.debug("Client not connected, will establish subscription later") + + return lambda: self._listeners[topic].remove(callback) + + async def publish(self, topic: str, message: bytes) -> None: + """Publish a message on the topic.""" + _LOGGER.debug("Sending message to topic %s: %s", topic, message) + async with self._client_lock: + if not self._client: + raise RoborockException("MQTT client not connected") + coro = self._client.publish(topic, message) + await coro + + +async def create_mqtt_session(params: MqttParams) -> MqttSession: + """Create an MQTT session. + + This function is a factory for creating an MQTT session. This will + raise an exception if initial attempt to connect fails. Once connected, + the session will retry connecting on failure in the background. + """ + session = RoborockMqttSession(params) + await session.start() + return session diff --git a/roborock/mqtt/session.py b/roborock/mqtt/session.py new file mode 100644 index 00000000..75c971db --- /dev/null +++ b/roborock/mqtt/session.py @@ -0,0 +1,57 @@ +"""An MQTT session for sending and receiving messages.""" + +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass + +DEFAULT_TIMEOUT = 30.0 + + +@dataclass +class MqttParams: + """MQTT parameters for the connection.""" + + host: str + """MQTT host to connect to.""" + + port: int + """MQTT port to connect to.""" + + tls: bool + """Use TLS for the connection.""" + + username: str + """MQTT username to use for authentication.""" + + password: str + """MQTT password to use for authentication.""" + + timeout: float = DEFAULT_TIMEOUT + """Timeout for communications with the broker in seconds.""" + + +class MqttSession(ABC): + """An MQTT session for sending and receiving messages.""" + + @property + @abstractmethod + def connected(self) -> bool: + """True if the session is connected to the broker.""" + + @abstractmethod + async def subscribe(self, device_id: str, callback: Callable[[bytes], None]) -> Callable[[], None]: + """Invoke the callback when messages are received on the topic. + + The returned callable unsubscribes from the topic when called. + """ + + @abstractmethod + async def publish(self, topic: str, message: bytes) -> None: + """Publish a message on the specified topic. + + This will raise an exception if the message could not be sent. + """ + + @abstractmethod + async def close(self) -> None: + """Cancels the mqtt loop""" diff --git a/tests/conftest.py b/tests/conftest.py index 906f9cf9..383acc64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,9 +33,10 @@ class FakeSocketHandler: handle request callback handles the incoming requests and prepares the responses. """ - def __init__(self, handle_request: RequestHandler) -> None: + def __init__(self, handle_request: RequestHandler, response_queue: Queue[bytes]) -> None: self.response_buf = io.BytesIO() self.handle_request = handle_request + self.response_queue = response_queue def pending(self) -> int: """Return the number of bytes in the response buffer.""" @@ -62,9 +63,17 @@ def handle_socket_send(self, client_request: bytes) -> int: # The buffer will be emptied when the client calls recv() on the socket _LOGGER.debug("Queued: 0x%s", response.hex()) self.response_buf.write(response) - return len(client_request) + def push_response(self) -> None: + """Push a response to the client.""" + if not self.response_queue.empty(): + response = self.response_queue.get() + # Enqueue a response to be sent back to the client in the buffer. + # The buffer will be emptied when the client calls recv() on the socket + _LOGGER.debug("Queued: 0x%s", response.hex()) + self.response_buf.write(response) + @pytest.fixture(name="received_requests") def received_requests_fixture() -> Queue[bytes]: @@ -97,9 +106,9 @@ def handle_request(client_request: bytes) -> bytes | None: @pytest.fixture(name="fake_socket_handler") -def fake_socket_handler_fixture(request_handler: RequestHandler) -> FakeSocketHandler: +def fake_socket_handler_fixture(request_handler: RequestHandler, response_queue: Queue[bytes]) -> FakeSocketHandler: """Fixture that creates a fake MQTT broker.""" - return FakeSocketHandler(request_handler) + return FakeSocketHandler(request_handler, response_queue) @pytest.fixture(name="mock_sock") diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py new file mode 100644 index 00000000..6091e955 --- /dev/null +++ b/tests/mqtt/test_roborock_session.py @@ -0,0 +1,142 @@ +"""Tests for the MQTT session module.""" + +import asyncio +from collections.abc import Callable, Generator +from queue import Queue +from typing import Any +from unittest.mock import patch + +import paho.mqtt.client as mqtt +import pytest + +from roborock.mqtt.roborock_session import create_mqtt_session +from roborock.mqtt.session import MqttParams +from tests import mqtt_packet +from tests.conftest import FakeSocketHandler + +# We mock out the connection so these params are not used/verified +FAKE_PARAMS = MqttParams( + host="localhost", + port=1883, + tls=False, + username="username", + password="password", + timeout=10.0, +) + + +@pytest.fixture(autouse=True) +def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None: + """Fixture to prepare a fake MQTT server.""" + + +@pytest.fixture(autouse=True) +def mock_client_fixture(event_loop: asyncio.AbstractEventLoop) -> Generator[None, None, None]: + """Fixture to patch the MQTT underlying sync client. + + The tests use fake sockets, so this ensures that the async mqtt client does not + attempt to listen on them directly. We instead just poll the socket for + data ourselves. + """ + + orig_class = mqtt.Client + + async def poll_sockets(client: mqtt.Client) -> None: + """Poll the mqtt client sockets in a loop to pick up new data.""" + while True: + event_loop.call_soon_threadsafe(client.loop_read) + event_loop.call_soon_threadsafe(client.loop_write) + await asyncio.sleep(0.1) + + task: asyncio.Task[None] | None = None + + def new_client(*args: Any, **kwargs: Any) -> mqtt.Client: + """Create a new mqtt client and start the socket polling task.""" + nonlocal task + client = orig_class(*args, **kwargs) + task = event_loop.create_task(poll_sockets(client)) + return client + + with patch("aiomqtt.client.Client._on_socket_open"), patch("aiomqtt.client.Client._on_socket_close"), patch( + "aiomqtt.client.Client._on_socket_register_write" + ), patch("aiomqtt.client.Client._on_socket_unregister_write"), patch( + "aiomqtt.client.mqtt.Client", side_effect=new_client + ): + yield + if task: + task.cancel() + + +@pytest.fixture +def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]: + """Fixtures to push messages.""" + + def push(message: bytes) -> None: + response_queue.put(message) + fake_socket_handler.push_response() + + return push + + +class Subscriber: + """Mock subscriber class.""" + + def __init__(self) -> None: + """Initialize the subscriber.""" + self.messages: list[bytes] = [] + self.event: asyncio.Event = asyncio.Event() + + def append(self, message: bytes) -> None: + """Append a message to the subscriber.""" + self.messages.append(message) + self.event.set() + + async def wait(self) -> None: + """Wait for a message to be received.""" + await self.event.wait() + self.event.clear() + + +async def test_session(push_response: Callable[[bytes], None]) -> None: + """Test the MQTT session.""" + + push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + session = await create_mqtt_session(FAKE_PARAMS) + + push_response(mqtt_packet.gen_suback(mid=1)) + subscriber1 = Subscriber() + unsub1 = await session.subscribe("topic-1", subscriber1.append) + + push_response(mqtt_packet.gen_suback(mid=2)) + subscriber2 = Subscriber() + await session.subscribe("topic-2", subscriber2.append) + + push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + await subscriber1.wait() + assert subscriber1.messages == [b"12345"] + assert not subscriber2.messages + + push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) + await subscriber2.wait() + assert subscriber2.messages == [b"67890"] + + push_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC")) + await subscriber1.wait() + assert subscriber1.messages == [b"12345", b"ABC"] + assert subscriber2.messages == [b"67890"] + + # Messages are no longer received after unsubscribing + unsub1() + push_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) + assert subscriber1.messages == [b"12345", b"ABC"] + + await session.close() + + +async def test_publish_command(push_response: Callable[[bytes], None]) -> None: + """Test publishing during an MQTT session.""" + + push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + session = await create_mqtt_session(FAKE_PARAMS) + + await session.publish("topic-1", message=b"payload") diff --git a/tests/mqtt_packet.py b/tests/mqtt_packet.py index 139671a6..2fc19a12 100644 --- a/tests/mqtt_packet.py +++ b/tests/mqtt_packet.py @@ -56,7 +56,7 @@ def gen_connack(flags=0, rc=0, properties=b"", property_helper=True): return packet -def gen_suback(mid: int, qos: int) -> bytes: +def gen_suback(mid: int, qos: int = 0) -> bytes: """Generate a SUBACK packet.""" return struct.pack("!BBHBB", 144, 2 + 1 + 1, mid, 0, qos) @@ -74,7 +74,7 @@ def _gen_command_with_mid(cmd: int, mid: int, reason_code: int = 0) -> bytes: return struct.pack("!BBHB", cmd, 3, mid, reason_code) -def gen_puback(mid: int, reason_code: int = -1) -> bytes: +def gen_puback(mid: int, reason_code: int = 0) -> bytes: """Generate a PUBACK packet.""" return _gen_command_with_mid(64, mid, reason_code) From 0cf621fb4bc69aec632891a82d5cb6e0a3d62468 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 14 Apr 2025 12:49:44 -0700 Subject: [PATCH 2/2] feat: Add exception handling and increased test coverage --- roborock/mqtt/roborock_session.py | 51 +++++++++++----- roborock/mqtt/session.py | 6 ++ tests/mqtt/test_roborock_session.py | 92 ++++++++++++++++++++++++++++- 3 files changed, 131 insertions(+), 18 deletions(-) diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py index 6e77e575..bf48970a 100644 --- a/roborock/mqtt/roborock_session.py +++ b/roborock/mqtt/roborock_session.py @@ -17,8 +17,7 @@ import aiomqtt from aiomqtt import MqttError, TLSParameters -from .. import RoborockException -from .session import MqttParams, MqttSession +from .session import MqttParams, MqttSession, MqttSessionException _LOGGER = logging.getLogger(__name__) _MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt") @@ -71,7 +70,14 @@ async def start(self) -> None: start_future: asyncio.Future[None] = asyncio.Future() loop = asyncio.get_event_loop() self._background_task = loop.create_task(self._run_task(start_future)) - await start_future + try: + await start_future + except MqttError as err: + raise MqttSessionException(f"Error starting MQTT session: {err}") from err + except Exception as err: + raise MqttSessionException(f"Unexpected error starting session: {err}") from err + else: + _LOGGER.debug("MQTT session started successfully") async def close(self) -> None: """Cancels the MQTT loop and shutdown the client library.""" @@ -102,14 +108,18 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None: await self._process_message_loop(client) - except asyncio.CancelledError: - _LOGGER.debug("MQTT loop was cancelled") - return except MqttError as err: - _LOGGER.info("MQTT error: %s", err) if start_future: + _LOGGER.info("MQTT error starting session: %s", err) start_future.set_exception(err) return + _LOGGER.info("MQTT error: %s", err) + except asyncio.CancelledError as err: + if start_future: + _LOGGER.debug("MQTT loop was cancelled") + start_future.set_exception(err) + _LOGGER.debug("MQTT loop was cancelled whiel starting") + return # Catch exceptions to avoid crashing the loop # and to allow the loop to retry. except Exception as err: @@ -118,10 +128,11 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None: if "generator didn't stop" in str(err): _LOGGER.debug("MQTT loop was cancelled") return - _LOGGER.error("Uncaught error in MQTT session: %s", err) if start_future: + _LOGGER.error("Uncaught error starting MQTT session: %s", err) start_future.set_exception(err) return + _LOGGER.error("Uncaught error during MQTT session: %s", err) self._healthy = False _LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds()) @@ -150,6 +161,8 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client: self._client = client for topic in self._listeners: _LOGGER.debug("Re-establising subscription to topic %s", topic) + # TODO: If this fails it will break the whole connection. Make + # this retry again in the background with backoff. await client.subscribe(topic) yield client @@ -158,10 +171,11 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client: self._client = None async def _process_message_loop(self, client: aiomqtt.Client) -> None: - _LOGGER.debug("Processing MQTT messages") + _LOGGER.debug("client=%s", client) + _LOGGER.debug("Processing MQTT messages: %s", client.messages) async for message in client.messages: _LOGGER.debug("Received message: %s", message) - for listener in self._listeners.get(message.topic.value) or []: + for listener in self._listeners.get(message.topic.value, []): try: listener(message.payload) except asyncio.CancelledError: @@ -185,7 +199,10 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call async with self._client_lock: if self._client: _LOGGER.debug("Establishing subscription to topic %s", topic) - await self._client.subscribe(topic) + try: + await self._client.subscribe(topic) + except MqttError as err: + raise MqttSessionException(f"Error subscribing to topic: {err}") from err else: _LOGGER.debug("Client not connected, will establish subscription later") @@ -194,11 +211,15 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call async def publish(self, topic: str, message: bytes) -> None: """Publish a message on the topic.""" _LOGGER.debug("Sending message to topic %s: %s", topic, message) + client: aiomqtt.Client async with self._client_lock: - if not self._client: - raise RoborockException("MQTT client not connected") - coro = self._client.publish(topic, message) - await coro + if self._client is None: + raise MqttSessionException("Could not publish message, MQTT client not connected") + client = self._client + try: + await client.publish(topic, message) + except MqttError as err: + raise MqttSessionException(f"Error publishing message: {err}") from err async def create_mqtt_session(params: MqttParams) -> MqttSession: diff --git a/roborock/mqtt/session.py b/roborock/mqtt/session.py index 75c971db..c72e3294 100644 --- a/roborock/mqtt/session.py +++ b/roborock/mqtt/session.py @@ -4,6 +4,8 @@ from collections.abc import Callable from dataclasses import dataclass +from roborock.exceptions import RoborockException + DEFAULT_TIMEOUT = 30.0 @@ -55,3 +57,7 @@ async def publish(self, topic: str, message: bytes) -> None: @abstractmethod async def close(self) -> None: """Cancels the mqtt loop""" + + +class MqttSessionException(RoborockException): + """ "Raised when there is an error communicating with MQTT.""" diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py index 6091e955..4b3ad0ef 100644 --- a/tests/mqtt/test_roborock_session.py +++ b/tests/mqtt/test_roborock_session.py @@ -4,13 +4,14 @@ from collections.abc import Callable, Generator from queue import Queue from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, Mock, patch +import aiomqtt import paho.mqtt.client as mqtt import pytest from roborock.mqtt.roborock_session import create_mqtt_session -from roborock.mqtt.session import MqttParams +from roborock.mqtt.session import MqttParams, MqttSessionException from tests import mqtt_packet from tests.conftest import FakeSocketHandler @@ -79,7 +80,11 @@ def push(message: bytes) -> None: class Subscriber: - """Mock subscriber class.""" + """Mock subscriber class. + + This will capture messages published on the session so the tests can verify + they were received. + """ def __init__(self) -> None: """Initialize the subscriber.""" @@ -102,6 +107,7 @@ async def test_session(push_response: Callable[[bytes], None]) -> None: push_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) + assert session.connected push_response(mqtt_packet.gen_suback(mid=1)) subscriber1 = Subscriber() @@ -130,7 +136,22 @@ async def test_session(push_response: Callable[[bytes], None]) -> None: push_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) assert subscriber1.messages == [b"12345", b"ABC"] + assert session.connected await session.close() + assert not session.connected + + +async def test_session_no_subscribers(push_response: Callable[[bytes], None]) -> None: + """Test the MQTT session.""" + + push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) + session = await create_mqtt_session(FAKE_PARAMS) + assert session.connected + + await session.close() + assert not session.connected async def test_publish_command(push_response: Callable[[bytes], None]) -> None: @@ -139,4 +160,69 @@ async def test_publish_command(push_response: Callable[[bytes], None]) -> None: push_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) + push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) await session.publish("topic-1", message=b"payload") + + assert session.connected + await session.close() + assert not session.connected + + +class FakeAsyncIterator: + """Fake async iterator that waits for messages to arrive, but they never do. + + This is used for testing exceptions in other client functions. + """ + + def __aiter__(self): + return self + + async def __anext__(self) -> None: + """Iterator that does not generate any messages.""" + while True: + await asyncio.sleep(1) + + +async def test_publish_failure() -> None: + """Test an MQTT error is received when publishing a message.""" + + mock_client = AsyncMock() + mock_client.messages = FakeAsyncIterator() + + mock_aenter = AsyncMock() + mock_aenter.return_value = mock_client + + with patch("roborock.mqtt.roborock_session.aiomqtt.Client.__aenter__", mock_aenter): + session = await create_mqtt_session(FAKE_PARAMS) + assert session.connected + + mock_client.publish.side_effect = aiomqtt.MqttError + + with pytest.raises(MqttSessionException, match="Error publishing message"): + await session.publish("topic-1", message=b"payload") + + +async def test_subscribe_failure() -> None: + """Test an MQTT error while subscribing.""" + + mock_client = AsyncMock() + mock_client.messages = FakeAsyncIterator() + + mock_aenter = AsyncMock() + mock_aenter.return_value = mock_client + + mock_shim = Mock() + mock_shim.return_value.__aenter__ = mock_aenter + mock_shim.return_value.__aexit__ = AsyncMock() + + with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim): + session = await create_mqtt_session(FAKE_PARAMS) + assert session.connected + + mock_client.subscribe.side_effect = aiomqtt.MqttError + + subscriber1 = Subscriber() + with pytest.raises(MqttSessionException, match="Error subscribing to topic"): + await session.subscribe("topic-1", subscriber1.append) + + assert not subscriber1.messages