Skip to content

Commit c1bdac0

Browse files
feat: Add a v1 protocol channel that can send messages across MQTT or Local connections, preferring local (#416)
* feat: Add a v1 protocol channel bridging across MQTT/Local channels * chore: Remove whitespace * feat: Fix tests referencing RoborockStateCode * feat: Fix tests reverted by co-pilot * fix: Update error message and add pydoc for exception handling on subscribe * chore(deps): bump click from 8.1.8 to 8.2.1 (#401) Bumps [click](https://github.com/pallets/click) from 8.1.8 to 8.2.1. - [Release notes](https://github.com/pallets/click/releases) - [Changelog](https://github.com/pallets/click/blob/main/CHANGES.rst) - [Commits](pallets/click@8.1.8...8.2.1) --- updated-dependencies: - dependency-name: click dependency-version: 8.2.1 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * fix: Simplify local connection handling * fix: Update pydoc for sending a raw command --------- Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
1 parent 0e681be commit c1bdac0

File tree

12 files changed

+1066
-43
lines changed

12 files changed

+1066
-43
lines changed

roborock/cli.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,16 @@ async def home_data_cache() -> HomeData:
115115
devices = await device_manager.get_devices()
116116
click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}")
117117

118-
click.echo("MQTT session started. Listening for messages...")
118+
click.echo("MQTT session started. Querying devices...")
119+
for device in devices:
120+
try:
121+
status = await device.get_status()
122+
except RoborockException as e:
123+
click.echo(f"Failed to get status for {device.name}: {e}")
124+
else:
125+
click.echo(f"Device {device.name} status: {status.as_dict()}")
126+
127+
click.echo("Listening for messages.")
119128
await asyncio.sleep(duration)
120129

121130
# Close the device manager (this will close all devices and MQTT session)

roborock/devices/device.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,18 @@
99
from collections.abc import Callable
1010
from functools import cached_property
1111

12-
from roborock.containers import HomeDataDevice, HomeDataProduct, UserData
12+
from roborock.containers import (
13+
HomeDataDevice,
14+
HomeDataProduct,
15+
ModelStatus,
16+
S7MaxVStatus,
17+
Status,
18+
UserData,
19+
)
1320
from roborock.roborock_message import RoborockMessage
21+
from roborock.roborock_typing import RoborockCommand
1422

15-
from .mqtt_channel import MqttChannel
23+
from .v1_channel import V1Channel
1624

1725
_LOGGER = logging.getLogger(__name__)
1826

@@ -38,19 +46,18 @@ def __init__(
3846
user_data: UserData,
3947
device_info: HomeDataDevice,
4048
product_info: HomeDataProduct,
41-
mqtt_channel: MqttChannel,
49+
v1_channel: V1Channel,
4250
) -> None:
4351
"""Initialize the RoborockDevice.
4452
45-
The device takes ownership of the MQTT channel for communication with the device.
46-
Use `connect()` to establish the connection, which will set up the MQTT channel
47-
for receiving messages from the device. Use `close()` to unsubscribe from the MQTT
48-
channel.
53+
The device takes ownership of the V1 channel for communication with the device.
54+
Use `connect()` to establish the connection, which will set up the appropriate
55+
protocol channel. Use `close()` to clean up all connections.
4956
"""
5057
self._user_data = user_data
5158
self._device_info = device_info
5259
self._product_info = product_info
53-
self._mqtt_channel = mqtt_channel
60+
self._v1_channel = v1_channel
5461
self._unsub: Callable[[], None] | None = None
5562

5663
@property
@@ -82,27 +89,32 @@ def device_version(self) -> str:
8289
)
8390
return DeviceVersion.UNKNOWN
8491

85-
async def connect(self) -> None:
86-
"""Connect to the device using MQTT.
92+
@property
93+
def is_connected(self) -> bool:
94+
"""Return whether the device is connected."""
95+
return self._v1_channel.is_mqtt_connected or self._v1_channel.is_local_connected
8796

88-
This method will set up the MQTT channel for communication with the device.
89-
"""
97+
async def connect(self) -> None:
98+
"""Connect to the device using the appropriate protocol channel."""
9099
if self._unsub:
91100
raise ValueError("Already connected to the device")
92-
self._unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
101+
self._unsub = await self._v1_channel.subscribe(self._on_message)
102+
_LOGGER.info("Connected to V1 device %s", self.name)
93103

94104
async def close(self) -> None:
95-
"""Close the MQTT connection to the device.
96-
97-
This method will unsubscribe from the MQTT channel and clean up resources.
98-
"""
105+
"""Close all connections to the device."""
99106
if self._unsub:
100107
self._unsub()
101108
self._unsub = None
102109

103-
def _on_mqtt_message(self, message: RoborockMessage) -> None:
104-
"""Handle incoming MQTT messages from the device.
110+
def _on_message(self, message: RoborockMessage) -> None:
111+
"""Handle incoming messages from the device."""
112+
_LOGGER.debug("Received message from device: %s", message)
113+
114+
async def get_status(self) -> Status:
115+
"""Get the current status of the device.
105116
106-
This method should be overridden in subclasses to handle specific device messages.
117+
This is a placeholder command and will likely be changed/moved in the future.
107118
"""
108-
_LOGGER.debug("Received message from device %s: %s", self.duid, message)
119+
status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus)
120+
return await self._v1_channel.send_decoded_command(RoborockCommand.GET_STATUS, response_type=status_type)

roborock/devices/device_manager.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
HomeDataProduct,
1111
UserData,
1212
)
13-
from roborock.devices.device import RoborockDevice
13+
from roborock.devices.device import DeviceVersion, RoborockDevice
1414
from roborock.mqtt.roborock_session import create_mqtt_session
1515
from roborock.mqtt.session import MqttSession
1616
from roborock.protocol import create_mqtt_params
1717
from roborock.web_api import RoborockApiClient
1818

19-
from .mqtt_channel import MqttChannel
19+
from .v1_channel import create_v1_channel
2020

2121
_LOGGER = logging.getLogger(__name__)
2222

@@ -114,8 +114,15 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi)
114114
mqtt_session = await create_mqtt_session(mqtt_params)
115115

116116
def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
117-
mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
118-
return RoborockDevice(user_data, device, product, mqtt_channel)
117+
# Check device version and only support V1 for now
118+
if device.pv != DeviceVersion.V1.value:
119+
raise NotImplementedError(
120+
f"Device {device.name} has version {device.pv}, but only V1 devices "
121+
f"are supported through the unified interface."
122+
)
123+
# Create V1 channel that handles both MQTT and local connections
124+
v1_channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device)
125+
return RoborockDevice(user_data, device, product, v1_channel)
119126

120127
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
121128
await manager.discover_devices()

roborock/devices/local_channel.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def connect(self) -> None:
6464
except OSError as e:
6565
raise RoborockConnectionException(f"Failed to connect to {self._host}:{_PORT}") from e
6666

67-
async def close(self) -> None:
67+
def close(self) -> None:
6868
"""Disconnect from the device."""
6969
if self._transport:
7070
self._transport.close()
@@ -144,3 +144,25 @@ async def send_command(self, message: RoborockMessage, timeout: float = 10.0) ->
144144
async with self._queue_lock:
145145
self._waiting_queue.pop(request_id, None)
146146
raise
147+
148+
149+
# This module provides a factory function to create LocalChannel instances.
150+
#
151+
# TODO: Make a separate LocalSession and use it to manage retries with the host,
152+
# similar to how MqttSession works. For now this is a simple factory function
153+
# for creating channels.
154+
LocalSession = Callable[[str], LocalChannel]
155+
156+
157+
def create_local_session(local_key: str) -> LocalSession:
158+
"""Creates a local session which can create local channels.
159+
160+
This plays a role similar to the MqttSession but is really just a factory
161+
for creating LocalChannel instances with the same local key.
162+
"""
163+
164+
def create_local_channel(host: str) -> LocalChannel:
165+
"""Create a LocalChannel instance for the given host."""
166+
return LocalChannel(host, local_key)
167+
168+
return create_local_channel

roborock/devices/v1_channel.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""V1 Channel for Roborock devices.
2+
3+
This module provides a unified channel interface for V1 protocol devices,
4+
handling both MQTT and local connections with automatic fallback.
5+
"""
6+
7+
import logging
8+
from collections.abc import Callable
9+
from typing import Any, TypeVar
10+
11+
from roborock.containers import HomeDataDevice, NetworkInfo, RoborockBase, UserData
12+
from roborock.exceptions import RoborockException
13+
from roborock.mqtt.session import MqttParams, MqttSession
14+
from roborock.protocols.v1_protocol import (
15+
CommandType,
16+
ParamsType,
17+
SecurityData,
18+
create_mqtt_payload_encoder,
19+
create_security_data,
20+
decode_rpc_response,
21+
encode_local_payload,
22+
)
23+
from roborock.roborock_message import RoborockMessage
24+
from roborock.roborock_typing import RoborockCommand
25+
26+
from .local_channel import LocalChannel, LocalSession, create_local_session
27+
from .mqtt_channel import MqttChannel
28+
29+
_LOGGER = logging.getLogger(__name__)
30+
31+
__all__ = [
32+
"V1Channel",
33+
]
34+
35+
_T = TypeVar("_T", bound=RoborockBase)
36+
37+
38+
class V1Channel:
39+
"""Unified V1 protocol channel with automatic MQTT/local connection handling.
40+
41+
This channel abstracts away the complexity of choosing between MQTT and local
42+
connections, and provides high-level V1 protocol methods. It automatically
43+
handles connection setup, fallback logic, and protocol encoding/decoding.
44+
"""
45+
46+
def __init__(
47+
self,
48+
device_uid: str,
49+
security_data: SecurityData,
50+
mqtt_channel: MqttChannel,
51+
local_session: LocalSession,
52+
) -> None:
53+
"""Initialize the V1Channel.
54+
55+
Args:
56+
mqtt_channel: MQTT channel for cloud communication
57+
local_session: Factory that creates LocalChannels for a hostname.
58+
"""
59+
self._device_uid = device_uid
60+
self._mqtt_channel = mqtt_channel
61+
self._mqtt_payload_encoder = create_mqtt_payload_encoder(security_data)
62+
self._local_session = local_session
63+
self._local_channel: LocalChannel | None = None
64+
self._mqtt_unsub: Callable[[], None] | None = None
65+
self._local_unsub: Callable[[], None] | None = None
66+
self._callback: Callable[[RoborockMessage], None] | None = None
67+
self._networking_info: NetworkInfo | None = None
68+
69+
@property
70+
def is_local_connected(self) -> bool:
71+
"""Return whether local connection is available."""
72+
return self._local_unsub is not None
73+
74+
@property
75+
def is_mqtt_connected(self) -> bool:
76+
"""Return whether MQTT connection is available."""
77+
return self._mqtt_unsub is not None
78+
79+
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
80+
"""Subscribe to all messages from the device.
81+
82+
This will establish MQTT connection first, and also attempt to set up
83+
local connection if possible. Any failures to subscribe to MQTT will raise
84+
a RoborockException. A local connection failure will not raise an exception,
85+
since the local connection is optional.
86+
"""
87+
88+
if self._mqtt_unsub:
89+
raise ValueError("Already connected to the device")
90+
self._callback = callback
91+
92+
# First establish MQTT connection
93+
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
94+
_LOGGER.debug("V1Channel connected to device %s via MQTT", self._device_uid)
95+
96+
# Try to establish an optional local connection as well.
97+
try:
98+
self._local_unsub = await self._local_connect()
99+
except RoborockException as err:
100+
_LOGGER.warning("Could not establish local connection for device %s: %s", self._device_uid, err)
101+
else:
102+
_LOGGER.debug("Local connection established for device %s", self._device_uid)
103+
104+
def unsub() -> None:
105+
"""Unsubscribe from all messages."""
106+
if self._mqtt_unsub:
107+
self._mqtt_unsub()
108+
self._mqtt_unsub = None
109+
if self._local_unsub:
110+
self._local_unsub()
111+
self._local_unsub = None
112+
_LOGGER.debug("Unsubscribed from device %s", self._device_uid)
113+
114+
return unsub
115+
116+
async def _get_networking_info(self) -> NetworkInfo:
117+
"""Retrieve networking information for the device.
118+
119+
This is a cloud only command used to get the local device's IP address.
120+
"""
121+
try:
122+
return await self._send_mqtt_decoded_command(RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo)
123+
except RoborockException as e:
124+
raise RoborockException(f"Network info failed for device {self._device_uid}") from e
125+
126+
async def _local_connect(self) -> Callable[[], None]:
127+
"""Set up local connection if possible."""
128+
_LOGGER.debug("Attempting to connect to local channel for device %s", self._device_uid)
129+
if self._networking_info is None:
130+
self._networking_info = await self._get_networking_info()
131+
host = self._networking_info.ip
132+
_LOGGER.debug("Connecting to local channel at %s", host)
133+
self._local_channel = self._local_session(host)
134+
try:
135+
await self._local_channel.connect()
136+
except RoborockException as e:
137+
self._local_channel = None
138+
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
139+
140+
return await self._local_channel.subscribe(self._on_local_message)
141+
142+
async def send_decoded_command(
143+
self,
144+
method: CommandType,
145+
*,
146+
response_type: type[_T],
147+
params: ParamsType = None,
148+
) -> _T:
149+
"""Send a command using the best available transport.
150+
151+
Will prefer local connection if available, falling back to MQTT.
152+
"""
153+
connection = "local" if self.is_local_connected else "mqtt"
154+
_LOGGER.debug("Sending command (%s): %s, params=%s", connection, method, params)
155+
if self._local_channel:
156+
return await self._send_local_decoded_command(method, response_type=response_type, params=params)
157+
return await self._send_mqtt_decoded_command(method, response_type=response_type, params=params)
158+
159+
async def _send_mqtt_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]:
160+
"""Send a raw command and return a raw unparsed response."""
161+
message = self._mqtt_payload_encoder(method, params)
162+
_LOGGER.debug("Sending MQTT message for device %s: %s", self._device_uid, message)
163+
response = await self._mqtt_channel.send_command(message)
164+
return decode_rpc_response(response)
165+
166+
async def _send_mqtt_decoded_command(
167+
self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None
168+
) -> _T:
169+
"""Send a command over MQTT and decode the response."""
170+
decoded_response = await self._send_mqtt_raw_command(method, params)
171+
return response_type.from_dict(decoded_response)
172+
173+
async def _send_local_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]:
174+
"""Send a raw command over local connection."""
175+
if not self._local_channel:
176+
raise RoborockException("Local channel is not connected")
177+
178+
message = encode_local_payload(method, params)
179+
_LOGGER.debug("Sending local message for device %s: %s", self._device_uid, message)
180+
response = await self._local_channel.send_command(message)
181+
return decode_rpc_response(response)
182+
183+
async def _send_local_decoded_command(
184+
self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None
185+
) -> _T:
186+
"""Send a command over local connection and decode the response."""
187+
if not self._local_channel:
188+
raise RoborockException("Local channel is not connected")
189+
decoded_response = await self._send_local_raw_command(method, params)
190+
return response_type.from_dict(decoded_response)
191+
192+
def _on_mqtt_message(self, message: RoborockMessage) -> None:
193+
"""Handle incoming MQTT messages."""
194+
_LOGGER.debug("V1Channel received MQTT message from device %s: %s", self._device_uid, message)
195+
if self._callback:
196+
self._callback(message)
197+
198+
def _on_local_message(self, message: RoborockMessage) -> None:
199+
"""Handle incoming local messages."""
200+
_LOGGER.debug("V1Channel received local message from device %s: %s", self._device_uid, message)
201+
if self._callback:
202+
self._callback(message)
203+
204+
205+
def create_v1_channel(
206+
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
207+
) -> V1Channel:
208+
"""Create a V1Channel for the given device."""
209+
security_data = create_security_data(user_data.rriot)
210+
mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
211+
local_session = create_local_session(device.local_key)
212+
return V1Channel(device.duid, security_data, mqtt_channel, local_session=local_session)

0 commit comments

Comments
 (0)