Skip to content

Commit cb33def

Browse files
committed
Add exception handling and increased test coverage
1 parent 22b9a3f commit cb33def

File tree

3 files changed

+131
-18
lines changed

3 files changed

+131
-18
lines changed

roborock/mqtt/roborock_session.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
import aiomqtt
1818
from aiomqtt import MqttError, TLSParameters
1919

20-
from .. import RoborockException
21-
from .session import MqttParams, MqttSession
20+
from .session import MqttParams, MqttSession, MqttSessionException
2221

2322
_LOGGER = logging.getLogger(__name__)
2423
_MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt")
@@ -71,7 +70,14 @@ async def start(self) -> None:
7170
start_future: asyncio.Future[None] = asyncio.Future()
7271
loop = asyncio.get_event_loop()
7372
self._background_task = loop.create_task(self._run_task(start_future))
74-
await start_future
73+
try:
74+
await start_future
75+
except MqttError as err:
76+
raise MqttSessionException(f"Error starting MQTT session: {err}") from err
77+
except Exception as err:
78+
raise MqttSessionException(f"Unexpected error starting session: {err}") from err
79+
else:
80+
_LOGGER.debug("MQTT session started successfully")
7581

7682
async def close(self) -> None:
7783
"""Cancels the MQTT loop and shutdown the client library."""
@@ -102,14 +108,18 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
102108

103109
await self._process_message_loop(client)
104110

105-
except asyncio.CancelledError:
106-
_LOGGER.debug("MQTT loop was cancelled")
107-
return
108111
except MqttError as err:
109-
_LOGGER.info("MQTT error: %s", err)
110112
if start_future:
113+
_LOGGER.info("MQTT error starting session: %s", err)
111114
start_future.set_exception(err)
112115
return
116+
_LOGGER.info("MQTT error: %s", err)
117+
except asyncio.CancelledError as err:
118+
if start_future:
119+
_LOGGER.debug("MQTT loop was cancelled")
120+
start_future.set_exception(err)
121+
_LOGGER.debug("MQTT loop was cancelled whiel starting")
122+
return
113123
# Catch exceptions to avoid crashing the loop
114124
# and to allow the loop to retry.
115125
except Exception as err:
@@ -118,10 +128,11 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
118128
if "generator didn't stop" in str(err):
119129
_LOGGER.debug("MQTT loop was cancelled")
120130
return
121-
_LOGGER.error("Uncaught error in MQTT session: %s", err)
122131
if start_future:
132+
_LOGGER.error("Uncaught error starting MQTT session: %s", err)
123133
start_future.set_exception(err)
124134
return
135+
_LOGGER.error("Uncaught error during MQTT session: %s", err)
125136

126137
self._healthy = False
127138
_LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds())
@@ -150,6 +161,8 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
150161
self._client = client
151162
for topic in self._listeners:
152163
_LOGGER.debug("Re-establising subscription to topic %s", topic)
164+
# TODO: If this fails it will break the whole connection. Make
165+
# this retry again in the background with backoff.
153166
await client.subscribe(topic)
154167

155168
yield client
@@ -158,10 +171,11 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
158171
self._client = None
159172

160173
async def _process_message_loop(self, client: aiomqtt.Client) -> None:
161-
_LOGGER.debug("Processing MQTT messages")
174+
_LOGGER.debug("client=%s", client)
175+
_LOGGER.debug("Processing MQTT messages: %s", client.messages)
162176
async for message in client.messages:
163177
_LOGGER.debug("Received message: %s", message)
164-
for listener in self._listeners.get(message.topic.value) or []:
178+
for listener in self._listeners.get(message.topic.value, []):
165179
try:
166180
listener(message.payload)
167181
except asyncio.CancelledError:
@@ -185,7 +199,10 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
185199
async with self._client_lock:
186200
if self._client:
187201
_LOGGER.debug("Establishing subscription to topic %s", topic)
188-
await self._client.subscribe(topic)
202+
try:
203+
await self._client.subscribe(topic)
204+
except MqttError as err:
205+
raise MqttSessionException(f"Error subscribing to topic: {err}") from err
189206
else:
190207
_LOGGER.debug("Client not connected, will establish subscription later")
191208

@@ -194,11 +211,15 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
194211
async def publish(self, topic: str, message: bytes) -> None:
195212
"""Publish a message on the topic."""
196213
_LOGGER.debug("Sending message to topic %s: %s", topic, message)
214+
client: aiomqtt.Client
197215
async with self._client_lock:
198-
if not self._client:
199-
raise RoborockException("MQTT client not connected")
200-
coro = self._client.publish(topic, message)
201-
await coro
216+
if self._client is None:
217+
raise MqttSessionException("Could not publish message, MQTT client not connected")
218+
client = self._client
219+
try:
220+
await client.publish(topic, message)
221+
except MqttError as err:
222+
raise MqttSessionException(f"Error publishing message: {err}") from err
202223

203224

204225
async def create_mqtt_session(params: MqttParams) -> MqttSession:

roborock/mqtt/session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from collections.abc import Callable
55
from dataclasses import dataclass
66

7+
from roborock.exceptions import RoborockException
8+
79
DEFAULT_TIMEOUT = 30.0
810

911

@@ -55,3 +57,7 @@ async def publish(self, topic: str, message: bytes) -> None:
5557
@abstractmethod
5658
async def close(self) -> None:
5759
"""Cancels the mqtt loop"""
60+
61+
62+
class MqttSessionException(RoborockException):
63+
""" "Raised when there is an error communicating with MQTT."""

tests/mqtt/test_roborock_session.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
from collections.abc import Callable, Generator
55
from queue import Queue
66
from typing import Any
7-
from unittest.mock import patch
7+
from unittest.mock import AsyncMock, Mock, patch
88

9+
import aiomqtt
910
import paho.mqtt.client as mqtt
1011
import pytest
1112

1213
from roborock.mqtt.roborock_session import create_mqtt_session
13-
from roborock.mqtt.session import MqttParams
14+
from roborock.mqtt.session import MqttParams, MqttSessionException
1415
from tests import mqtt_packet
1516
from tests.conftest import FakeSocketHandler
1617

@@ -79,7 +80,11 @@ def push(message: bytes) -> None:
7980

8081

8182
class Subscriber:
82-
"""Mock subscriber class."""
83+
"""Mock subscriber class.
84+
85+
This will capture messages published on the session so the tests can verify
86+
they were received.
87+
"""
8388

8489
def __init__(self) -> None:
8590
"""Initialize the subscriber."""
@@ -102,6 +107,7 @@ async def test_session(push_response: Callable[[bytes], None]) -> None:
102107

103108
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
104109
session = await create_mqtt_session(FAKE_PARAMS)
110+
assert session.connected
105111

106112
push_response(mqtt_packet.gen_suback(mid=1))
107113
subscriber1 = Subscriber()
@@ -130,7 +136,22 @@ async def test_session(push_response: Callable[[bytes], None]) -> None:
130136
push_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored"))
131137
assert subscriber1.messages == [b"12345", b"ABC"]
132138

139+
assert session.connected
133140
await session.close()
141+
assert not session.connected
142+
143+
144+
async def test_session_no_subscribers(push_response: Callable[[bytes], None]) -> None:
145+
"""Test the MQTT session."""
146+
147+
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
148+
push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345"))
149+
push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890"))
150+
session = await create_mqtt_session(FAKE_PARAMS)
151+
assert session.connected
152+
153+
await session.close()
154+
assert not session.connected
134155

135156

136157
async def test_publish_command(push_response: Callable[[bytes], None]) -> None:
@@ -139,4 +160,69 @@ async def test_publish_command(push_response: Callable[[bytes], None]) -> None:
139160
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
140161
session = await create_mqtt_session(FAKE_PARAMS)
141162

163+
push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345"))
142164
await session.publish("topic-1", message=b"payload")
165+
166+
assert session.connected
167+
await session.close()
168+
assert not session.connected
169+
170+
171+
class FakeAsyncIterator:
172+
"""Fake async iterator that waits for messages to arrive, but they never do.
173+
174+
This is used for testing exceptions in other client functions.
175+
"""
176+
177+
def __aiter__(self):
178+
return self
179+
180+
async def __anext__(self) -> None:
181+
"""Iterator that does not generate any messages."""
182+
while True:
183+
await asyncio.sleep(1)
184+
185+
186+
async def test_publish_failure() -> None:
187+
"""Test an MQTT error is received when publishing a message."""
188+
189+
mock_client = AsyncMock()
190+
mock_client.messages = FakeAsyncIterator()
191+
192+
mock_aenter = AsyncMock()
193+
mock_aenter.return_value = mock_client
194+
195+
with patch("roborock.mqtt.roborock_session.aiomqtt.Client.__aenter__", mock_aenter):
196+
session = await create_mqtt_session(FAKE_PARAMS)
197+
assert session.connected
198+
199+
mock_client.publish.side_effect = aiomqtt.MqttError
200+
201+
with pytest.raises(MqttSessionException, match="Error publishing message"):
202+
await session.publish("topic-1", message=b"payload")
203+
204+
205+
async def test_subscribe_failure() -> None:
206+
"""Test an MQTT error while subscribing."""
207+
208+
mock_client = AsyncMock()
209+
mock_client.messages = FakeAsyncIterator()
210+
211+
mock_aenter = AsyncMock()
212+
mock_aenter.return_value = mock_client
213+
214+
mock_shim = Mock()
215+
mock_shim.return_value.__aenter__ = mock_aenter
216+
mock_shim.return_value.__aexit__ = AsyncMock()
217+
218+
with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim):
219+
session = await create_mqtt_session(FAKE_PARAMS)
220+
assert session.connected
221+
222+
mock_client.subscribe.side_effect = aiomqtt.MqttError
223+
224+
subscriber1 = Subscriber()
225+
with pytest.raises(MqttSessionException, match="Error subscribing to topic"):
226+
await session.subscribe("topic-1", subscriber1.append)
227+
228+
assert not subscriber1.messages

0 commit comments

Comments
 (0)