Skip to content

Commit 54547d8

Browse files
authored
feat: Add a CLI for exercising the asyncio MQTT session (#396)
* feat: Add a CLI for exercising the asyncio MQTT session * feat: Share mqtt url parsing code with original client * feat: remove unused import * feat: Update bytes dump * feat: Fix lint error
1 parent b31ce69 commit 54547d8

File tree

4 files changed

+92
-20
lines changed

4 files changed

+92
-20
lines changed

roborock/cli.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
import logging
56
from pathlib import Path
@@ -12,7 +13,8 @@
1213

1314
from roborock import RoborockException
1415
from roborock.containers import DeviceData, HomeDataProduct, LoginData
15-
from roborock.protocol import MessageParser
16+
from roborock.mqtt.roborock_session import create_mqtt_session
17+
from roborock.protocol import MessageParser, create_mqtt_params
1618
from roborock.util import run_sync
1719
from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1
1820
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
@@ -45,7 +47,8 @@ def validate(self):
4547
if self._login_data is None:
4648
raise RoborockException("You must login first")
4749

48-
def login_data(self):
50+
def login_data(self) -> LoginData:
51+
"""Get the login data."""
4952
self.validate()
5053
return self._login_data
5154

@@ -90,6 +93,54 @@ async def login(ctx, email, password):
9093
context.update(LoginData(user_data=user_data, email=email))
9194

9295

96+
@click.command()
97+
@click.pass_context
98+
@click.option("--duration", default=10, help="Duration to run the MQTT session in seconds")
99+
@run_sync()
100+
async def session(ctx, duration: int):
101+
context: RoborockContext = ctx.obj
102+
login_data = context.login_data()
103+
104+
# Discovery devices if not already available
105+
if not login_data.home_data:
106+
await _discover(ctx)
107+
login_data = context.login_data()
108+
if not login_data.home_data or not login_data.home_data.devices:
109+
raise RoborockException("Unable to discover devices")
110+
111+
all_devices = login_data.home_data.devices + login_data.home_data.received_devices
112+
click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}")
113+
114+
rriot = login_data.user_data.rriot
115+
params = create_mqtt_params(rriot)
116+
117+
mqtt_session = await create_mqtt_session(params)
118+
click.echo("Starting MQTT session...")
119+
if not mqtt_session.connected:
120+
raise RoborockException("Failed to connect to MQTT broker")
121+
122+
def on_message(bytes: bytes):
123+
"""Callback function to handle incoming MQTT messages."""
124+
# Decode the first 20 bytes of the message for display
125+
bytes = bytes[:20]
126+
127+
click.echo(f"Received message: {bytes}...")
128+
129+
unsubs = []
130+
for device in all_devices:
131+
device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}"
132+
unsub = await mqtt_session.subscribe(device_topic, on_message)
133+
unsubs.append(unsub)
134+
135+
click.echo("MQTT session started. Listening for messages...")
136+
await asyncio.sleep(duration)
137+
138+
click.echo("Stopping MQTT session...")
139+
for unsub in unsubs:
140+
unsub()
141+
await mqtt_session.close()
142+
143+
93144
async def _discover(ctx):
94145
context: RoborockContext = ctx.obj
95146
login_data = context.login_data()
@@ -264,6 +315,7 @@ def on_package(packet: Packet):
264315
cli.add_command(status)
265316
cli.add_command(command)
266317
cli.add_command(parser)
318+
cli.add_command(session)
267319

268320

269321
def main():

roborock/cloud_api.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,19 @@
66
from abc import ABC
77
from asyncio import Lock
88
from typing import Any
9-
from urllib.parse import urlparse
109

1110
import paho.mqtt.client as mqtt
1211

1312
from .api import KEEPALIVE, RoborockClient
1413
from .containers import DeviceData, UserData
1514
from .exceptions import RoborockException, VacuumError
16-
from .protocol import Decoder, Encoder, create_mqtt_decoder, create_mqtt_encoder, md5hex
15+
from .protocol import (
16+
Decoder,
17+
Encoder,
18+
create_mqtt_decoder,
19+
create_mqtt_encoder,
20+
create_mqtt_params,
21+
)
1722
from .roborock_future import RoborockFuture
1823

1924
_LOGGER = logging.getLogger(__name__)
@@ -53,25 +58,20 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
5358
if rriot is None:
5459
raise RoborockException("Got no rriot data from user_data")
5560
RoborockClient.__init__(self, device_info)
61+
mqtt_params = create_mqtt_params(rriot)
5662
self._mqtt_user = rriot.u
57-
self._hashed_user = md5hex(self._mqtt_user + ":" + rriot.k)[2:10]
58-
url = urlparse(rriot.r.m)
59-
if not isinstance(url.hostname, str):
60-
raise RoborockException("Url parsing returned an invalid hostname")
61-
self._mqtt_host = str(url.hostname)
62-
self._mqtt_port = url.port
63-
self._mqtt_ssl = url.scheme == "ssl"
63+
self._hashed_user = mqtt_params.username
64+
self._mqtt_host = mqtt_params.host
65+
self._mqtt_port = mqtt_params.port
6466

6567
self._mqtt_client = _Mqtt()
6668
self._mqtt_client.on_connect = self._mqtt_on_connect
6769
self._mqtt_client.on_message = self._mqtt_on_message
6870
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect
69-
if self._mqtt_ssl:
71+
if mqtt_params.tls:
7072
self._mqtt_client.tls_set()
7173

72-
self._mqtt_password = rriot.s
73-
self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:]
74-
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
74+
self._mqtt_client.username_pw_set(mqtt_params.username, mqtt_params.password)
7575
self._waiting_queue: dict[int, RoborockFuture] = {}
7676
self._mutex = Lock()
7777
self._decoder: Decoder = create_mqtt_decoder(device_info.device.local_key)

roborock/mqtt/roborock_session.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
116116
_LOGGER.info("MQTT error: %s", err)
117117
except asyncio.CancelledError as err:
118118
if start_future:
119-
_LOGGER.debug("MQTT loop was cancelled")
119+
_LOGGER.debug("MQTT loop was cancelled while starting")
120120
start_future.set_exception(err)
121-
_LOGGER.debug("MQTT loop was cancelled while starting")
121+
_LOGGER.debug("MQTT loop was cancelled")
122122
return
123123
# Catch exceptions to avoid crashing the loop
124124
# and to allow the loop to retry.
@@ -171,8 +171,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
171171
self._client = None
172172

173173
async def _process_message_loop(self, client: aiomqtt.Client) -> None:
174-
_LOGGER.debug("client=%s", client)
175-
_LOGGER.debug("Processing MQTT messages: %s", client.messages)
174+
_LOGGER.debug("Processing MQTT messages")
176175
async for message in client.messages:
177176
_LOGGER.debug("Received message: %s", message)
178177
for listener in self._listeners.get(message.topic.value, []):

roborock/protocol.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
from asyncio import BaseTransport, Lock
1010
from collections.abc import Callable
11+
from urllib.parse import urlparse
1112

1213
from construct import ( # type: ignore
1314
Bytes,
@@ -30,7 +31,9 @@
3031
from Crypto.Cipher import AES
3132
from Crypto.Util.Padding import pad, unpad
3233

33-
from roborock import BroadcastMessage, RoborockException
34+
from roborock.containers import BroadcastMessage, RRiot
35+
from roborock.exceptions import RoborockException
36+
from roborock.mqtt.session import MqttParams
3437
from roborock.roborock_message import RoborockMessage
3538

3639
_LOGGER = logging.getLogger(__name__)
@@ -361,6 +364,24 @@ def build(
361364
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)
362365

363366

367+
def create_mqtt_params(rriot: RRiot) -> MqttParams:
368+
"""Return the MQTT parameters for this user."""
369+
url = urlparse(rriot.r.m)
370+
if not isinstance(url.hostname, str):
371+
raise RoborockException(f"Url parsing '{rriot.r.m}' returned an invalid hostname")
372+
if not url.port:
373+
raise RoborockException(f"Url parsing '{rriot.r.m}' returned an invalid port")
374+
hashed_user = md5hex(rriot.u + ":" + rriot.k)[2:10]
375+
hashed_password = md5hex(rriot.s + ":" + rriot.k)[16:]
376+
return MqttParams(
377+
host=str(url.hostname),
378+
port=url.port,
379+
tls=(url.scheme == "ssl"),
380+
username=hashed_user,
381+
password=hashed_password,
382+
)
383+
384+
364385
Decoder = Callable[[bytes], list[RoborockMessage]]
365386
Encoder = Callable[[RoborockMessage], bytes]
366387

0 commit comments

Comments
 (0)