Skip to content

Commit f7d1a55

Browse files
allenporterLash-L
andauthored
feat: Support both a01 and v1 device types with traits (#425)
* feat: Support both a01 and v1 device types with traits * fix: add safety check for trait creation * feat: Update cli with v1 status trait * chore: address code review feedback * chore: Update roborock/devices/v1_channel.py Co-authored-by: Luke Lashley <[email protected]> * chore: Revert encode_mqtt_payload typing change * fix: update mqtt payload encoding signature --------- Co-authored-by: Luke Lashley <[email protected]>
1 parent 636268d commit f7d1a55

File tree

16 files changed

+339
-94
lines changed

16 files changed

+339
-94
lines changed

roborock/cli.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,11 @@ async def home_data_cache() -> HomeData:
117117

118118
click.echo("MQTT session started. Querying devices...")
119119
for device in devices:
120+
if not (status_trait := device.traits.get("status")):
121+
click.echo(f"Device {device.name} does not have a status trait")
122+
continue
120123
try:
121-
status = await device.get_status()
124+
status = await status_trait.get_status()
122125
except RoborockException as e:
123126
click.echo(f"Failed to get status for {device.name}: {e}")
124127
else:

roborock/devices/a01_channel.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Thin wrapper around the MQTT channel for Roborock A01 devices."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import Any, overload
7+
8+
from roborock.protocols.a01_protocol import (
9+
decode_rpc_response,
10+
encode_mqtt_payload,
11+
)
12+
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
13+
14+
from .mqtt_channel import MqttChannel
15+
16+
_LOGGER = logging.getLogger(__name__)
17+
18+
19+
@overload
20+
async def send_decoded_command(
21+
mqtt_channel: MqttChannel,
22+
params: dict[RoborockDyadDataProtocol, Any],
23+
) -> dict[RoborockDyadDataProtocol, Any]:
24+
...
25+
26+
27+
@overload
28+
async def send_decoded_command(
29+
mqtt_channel: MqttChannel,
30+
params: dict[RoborockZeoProtocol, Any],
31+
) -> dict[RoborockZeoProtocol, Any]:
32+
...
33+
34+
35+
async def send_decoded_command(
36+
mqtt_channel: MqttChannel,
37+
params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any],
38+
) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]:
39+
"""Send a command on the MQTT channel and get a decoded response."""
40+
_LOGGER.debug("Sending MQTT command: %s", params)
41+
roborock_message = encode_mqtt_payload(params)
42+
response = await mqtt_channel.send_message(roborock_message)
43+
return decode_rpc_response(response) # type: ignore[return-value]

roborock/devices/channel.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Low-level interface for connections to Roborock devices."""
2+
3+
import logging
4+
from collections.abc import Callable
5+
from typing import Protocol
6+
7+
from roborock.roborock_message import RoborockMessage
8+
9+
_LOGGER = logging.getLogger(__name__)
10+
11+
12+
class Channel(Protocol):
13+
"""A generic channel for establishing a connection with a Roborock device.
14+
15+
Individual channel implementations have their own methods for speaking to
16+
the device that hide some of the protocol specific complexity, but they
17+
are still specialized for the device type and protocol.
18+
"""
19+
20+
@property
21+
def is_connected(self) -> bool:
22+
"""Return true if the channel is connected."""
23+
...
24+
25+
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
26+
"""Subscribe to messages from the device."""
27+
...

roborock/devices/device.py

Lines changed: 30 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,101 +4,72 @@
44
until the API is stable.
55
"""
66

7-
import enum
87
import logging
9-
from collections.abc import Callable
10-
from functools import cached_property
11-
12-
from roborock.containers import (
13-
HomeDataDevice,
14-
HomeDataProduct,
15-
ModelStatus,
16-
S7MaxVStatus,
17-
Status,
18-
UserData,
19-
)
8+
from abc import ABC
9+
from collections.abc import Callable, Mapping
10+
from types import MappingProxyType
11+
12+
from roborock.containers import HomeDataDevice
2013
from roborock.roborock_message import RoborockMessage
21-
from roborock.roborock_typing import RoborockCommand
2214

23-
from .v1_channel import V1Channel
15+
from .channel import Channel
16+
from .traits.trait import Trait
2417

2518
_LOGGER = logging.getLogger(__name__)
2619

2720
__all__ = [
2821
"RoborockDevice",
29-
"DeviceVersion",
3022
]
3123

3224

33-
class DeviceVersion(enum.StrEnum):
34-
"""Enum for device versions."""
35-
36-
V1 = "1.0"
37-
A01 = "A01"
38-
UNKNOWN = "unknown"
39-
25+
class RoborockDevice(ABC):
26+
"""A generic channel for establishing a connection with a Roborock device.
4027
41-
class RoborockDevice:
42-
"""Unified Roborock device class with automatic connection setup."""
28+
Individual channel implementations have their own methods for speaking to
29+
the device that hide some of the protocol specific complexity, but they
30+
are still specialized for the device type and protocol.
31+
"""
4332

4433
def __init__(
4534
self,
46-
user_data: UserData,
4735
device_info: HomeDataDevice,
48-
product_info: HomeDataProduct,
49-
v1_channel: V1Channel,
36+
channel: Channel,
37+
traits: list[Trait],
5038
) -> None:
5139
"""Initialize the RoborockDevice.
5240
53-
The device takes ownership of the V1 channel for communication with the device.
41+
The device takes ownership of the channel for communication with the device.
5442
Use `connect()` to establish the connection, which will set up the appropriate
5543
protocol channel. Use `close()` to clean up all connections.
5644
"""
57-
self._user_data = user_data
58-
self._device_info = device_info
59-
self._product_info = product_info
60-
self._v1_channel = v1_channel
45+
self._duid = device_info.duid
46+
self._name = device_info.name
47+
self._channel = channel
6148
self._unsub: Callable[[], None] | None = None
49+
self._trait_map = {trait.name: trait for trait in traits}
50+
if len(self._trait_map) != len(traits):
51+
raise ValueError("Duplicate trait names found in traits list")
6252

6353
@property
6454
def duid(self) -> str:
6555
"""Return the device unique identifier (DUID)."""
66-
return self._device_info.duid
56+
return self._duid
6757

6858
@property
6959
def name(self) -> str:
7060
"""Return the device name."""
71-
return self._device_info.name
72-
73-
@cached_property
74-
def device_version(self) -> str:
75-
"""Return the device version.
76-
77-
At the moment this is a simple check against the product version (pv) of the device
78-
and used as a placeholder for upcoming functionality for devices that will behave
79-
differently based on the version and capabilities.
80-
"""
81-
if self._device_info.pv == DeviceVersion.V1.value:
82-
return DeviceVersion.V1
83-
elif self._device_info.pv == DeviceVersion.A01.value:
84-
return DeviceVersion.A01
85-
_LOGGER.warning(
86-
"Unknown device version %s for device %s, using default UNKNOWN",
87-
self._device_info.pv,
88-
self._device_info.name,
89-
)
90-
return DeviceVersion.UNKNOWN
61+
return self._name
9162

9263
@property
9364
def is_connected(self) -> bool:
9465
"""Return whether the device is connected."""
95-
return self._v1_channel.is_mqtt_connected or self._v1_channel.is_local_connected
66+
return self._channel.is_connected
9667

9768
async def connect(self) -> None:
9869
"""Connect to the device using the appropriate protocol channel."""
9970
if self._unsub:
10071
raise ValueError("Already connected to the device")
101-
self._unsub = await self._v1_channel.subscribe(self._on_message)
72+
self._unsub = await self._channel.subscribe(self._on_message)
10273
_LOGGER.info("Connected to V1 device %s", self.name)
10374

10475
async def close(self) -> None:
@@ -111,10 +82,7 @@ def _on_message(self, message: RoborockMessage) -> None:
11182
"""Handle incoming messages from the device."""
11283
_LOGGER.debug("Received message from device: %s", message)
11384

114-
async def get_status(self) -> Status:
115-
"""Get the current status of the device.
116-
117-
This is a placeholder command and will likely be changed/moved in the future.
118-
"""
119-
status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus)
120-
return await self._v1_channel.rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type)
85+
@property
86+
def traits(self) -> Mapping[str, Trait]:
87+
"""Return the traits of the device."""
88+
return MappingProxyType(self._trait_map)

roborock/devices/device_manager.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
"""Module for discovering Roborock devices."""
22

33
import asyncio
4+
import enum
45
import logging
56
from collections.abc import Awaitable, Callable
67

8+
from roborock.code_mappings import RoborockCategory
79
from roborock.containers import (
810
HomeData,
911
HomeDataDevice,
1012
HomeDataProduct,
1113
UserData,
1214
)
13-
from roborock.devices.device import DeviceVersion, RoborockDevice
15+
from roborock.devices.device import RoborockDevice
1416
from roborock.mqtt.roborock_session import create_mqtt_session
1517
from roborock.mqtt.session import MqttSession
1618
from roborock.protocol import create_mqtt_params
1719
from roborock.web_api import RoborockApiClient
1820

21+
from .channel import Channel
22+
from .mqtt_channel import create_mqtt_channel
23+
from .traits.dyad import DyadApi
24+
from .traits.status import StatusTrait
25+
from .traits.trait import Trait
26+
from .traits.zeo import ZeoApi
1927
from .v1_channel import create_v1_channel
2028

2129
_LOGGER = logging.getLogger(__name__)
@@ -33,6 +41,14 @@
3341
DeviceCreator = Callable[[HomeDataDevice, HomeDataProduct], RoborockDevice]
3442

3543

44+
class DeviceVersion(enum.StrEnum):
45+
"""Enum for device versions."""
46+
47+
V1 = "1.0"
48+
A01 = "A01"
49+
UNKNOWN = "unknown"
50+
51+
3652
class DeviceManager:
3753
"""Central manager for Roborock device discovery and connections."""
3854

@@ -114,15 +130,25 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi)
114130
mqtt_session = await create_mqtt_session(mqtt_params)
115131

116132
def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
117-
# Check device version and only support V1 for now
118-
if device.pv != DeviceVersion.V1.value:
119-
raise NotImplementedError(
120-
f"Device {device.name} has version {device.pv}, but only V1 devices "
121-
f"are supported through the unified interface."
122-
)
123-
# Create V1 channel that handles both MQTT and local connections
124-
v1_channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device)
125-
return RoborockDevice(user_data, device, product, v1_channel)
133+
channel: Channel
134+
traits: list[Trait] = []
135+
# TODO: Define a registration mechanism/factory for v1 traits
136+
match device.pv:
137+
case DeviceVersion.V1:
138+
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device)
139+
traits.append(StatusTrait(product, channel.rpc_channel))
140+
case DeviceVersion.A01:
141+
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
142+
match product.category:
143+
case RoborockCategory.WET_DRY_VAC:
144+
traits.append(DyadApi(mqtt_channel))
145+
case RoborockCategory.WASHING_MACHINE:
146+
traits.append(ZeoApi(mqtt_channel))
147+
case _:
148+
raise NotImplementedError(f"Device {device.name} has unsupported category {product.category}")
149+
case _:
150+
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
151+
return RoborockDevice(device, channel, traits)
126152

127153
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
128154
await manager.discover_devices()

roborock/devices/local_channel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
1111
from roborock.roborock_message import RoborockMessage
1212

13+
from .channel import Channel
14+
1315
_LOGGER = logging.getLogger(__name__)
1416
_PORT = 58867
1517

@@ -30,7 +32,7 @@ def connection_lost(self, exc: Exception | None) -> None:
3032
self.connection_lost_cb(exc)
3133

3234

33-
class LocalChannel:
35+
class LocalChannel(Channel):
3436
"""Simple RPC-style channel for communicating with a device over a local network.
3537
3638
Handles request/response correlation and timeouts, but leaves message

roborock/devices/mqtt_channel.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@
55
from collections.abc import Callable
66
from json import JSONDecodeError
77

8-
from roborock.containers import RRiot
8+
from roborock.containers import HomeDataDevice, RRiot, UserData
99
from roborock.exceptions import RoborockException
1010
from roborock.mqtt.session import MqttParams, MqttSession
1111
from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
1212
from roborock.roborock_message import RoborockMessage
1313

14+
from .channel import Channel
15+
1416
_LOGGER = logging.getLogger(__name__)
1517

1618

17-
class MqttChannel:
19+
class MqttChannel(Channel):
1820
"""Simple RPC-style channel for communicating with a device over MQTT.
1921
2022
Handles request/response correlation and timeouts, but leaves message
@@ -33,6 +35,12 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot:
3335
self._decoder = create_mqtt_decoder(local_key)
3436
self._encoder = create_mqtt_encoder(local_key)
3537
self._queue_lock = asyncio.Lock()
38+
self._mqtt_unsub: Callable[[], None] | None = None
39+
40+
@property
41+
def is_connected(self) -> bool:
42+
"""Return true if the channel is connected."""
43+
return (self._mqtt_unsub is not None) and self._mqtt_session.connected
3644

3745
@property
3846
def _publish_topic(self) -> str:
@@ -67,7 +75,14 @@ def message_handler(payload: bytes) -> None:
6775
except Exception as e:
6876
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
6977

70-
return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
78+
self._mqtt_unsub = await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
79+
80+
def unsub_wrapper() -> None:
81+
if self._mqtt_unsub is not None:
82+
self._mqtt_unsub()
83+
self._mqtt_unsub = None
84+
85+
return unsub_wrapper
7186

7287
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
7388
"""Resolve waiting future with proper locking."""
@@ -113,3 +128,10 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) ->
113128
async with self._queue_lock:
114129
self._waiting_queue.pop(request_id, None)
115130
raise
131+
132+
133+
def create_mqtt_channel(
134+
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
135+
) -> MqttChannel:
136+
"""Create a V1Channel for the given device."""
137+
return MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)

0 commit comments

Comments
 (0)