From f541de5f7af67fc08a14c4993537dc946acfc219 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Thu, 7 Aug 2025 07:53:08 -0700 Subject: [PATCH 1/7] feat: Support both a01 and v1 device types with traits --- roborock/devices/a01_channel.py | 43 +++++++++ roborock/devices/channel.py | 27 ++++++ roborock/devices/device.py | 88 ++++++------------- roborock/devices/device_manager.py | 46 +++++++--- roborock/devices/local_channel.py | 4 +- roborock/devices/mqtt_channel.py | 28 +++++- roborock/devices/traits/dyad.py | 36 ++++++++ roborock/devices/traits/status.py | 48 ++++++++++ roborock/devices/traits/trait.py | 10 +++ roborock/devices/traits/zeo.py | 36 ++++++++ roborock/devices/v1_channel.py | 8 +- roborock/protocols/a01_protocol.py | 4 +- tests/devices/test_device_manager.py | 3 - tests/devices/test_mqtt_channel.py | 4 +- .../{test_device.py => test_v1_device.py} | 42 ++++++--- 15 files changed, 335 insertions(+), 92 deletions(-) create mode 100644 roborock/devices/a01_channel.py create mode 100644 roborock/devices/channel.py create mode 100644 roborock/devices/traits/dyad.py create mode 100644 roborock/devices/traits/status.py create mode 100644 roborock/devices/traits/trait.py create mode 100644 roborock/devices/traits/zeo.py rename tests/devices/{test_device.py => test_v1_device.py} (57%) diff --git a/roborock/devices/a01_channel.py b/roborock/devices/a01_channel.py new file mode 100644 index 00000000..7f73201b --- /dev/null +++ b/roborock/devices/a01_channel.py @@ -0,0 +1,43 @@ +"""Thin wrapper around the MQTT channel for Roborock A01 devices.""" + +from __future__ import annotations + +import logging +from typing import Any, overload + +from roborock.protocols.a01_protocol import ( + decode_rpc_response, + encode_mqtt_payload, +) +from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol + +from .mqtt_channel import MqttChannel + +_LOGGER = logging.getLogger(__name__) + + +@overload +async def send_decoded_command( + mqtt_channel: MqttChannel, + params: dict[RoborockDyadDataProtocol, Any], +) -> dict[RoborockDyadDataProtocol, Any]: + ... + + +@overload +async def send_decoded_command( + mqtt_channel: MqttChannel, + params: dict[RoborockZeoProtocol, Any], +) -> dict[RoborockZeoProtocol, Any]: + ... + + +async def send_decoded_command( + mqtt_channel: MqttChannel, + params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any], +) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]: + """Send a command on the MQTT channel and get a decoded response.""" + _LOGGER.debug("Sending MQTT command: %s", params) + roborock_message = encode_mqtt_payload(params) + response = await mqtt_channel.send_message(roborock_message) + return decode_rpc_response(response) # type: ignore[return-value] diff --git a/roborock/devices/channel.py b/roborock/devices/channel.py new file mode 100644 index 00000000..5474ad35 --- /dev/null +++ b/roborock/devices/channel.py @@ -0,0 +1,27 @@ +"""Low-level interface for connections to Roborock devices.""" + +import logging +from collections.abc import Callable +from typing import Protocol + +from roborock.roborock_message import RoborockMessage + +_LOGGER = logging.getLogger(__name__) + + +class Channel(Protocol): + """A generic channel for establishing a connection with a Roborock device. + + Individual channel implementations have their own methods for speaking to + the device that hide some of the protocol specific complexity, but they + are still specialized for the device type and protocol. + """ + + @property + def is_connected(self) -> bool: + """Return true if the channel is connected.""" + ... + + async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: + """Subscribe to messages from the device.""" + ... diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 623e9e68..157b3bca 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -4,49 +4,37 @@ until the API is stable. """ -import enum import logging -from collections.abc import Callable -from functools import cached_property - -from roborock.containers import ( - HomeDataDevice, - HomeDataProduct, - ModelStatus, - S7MaxVStatus, - Status, - UserData, -) +from abc import ABC +from collections.abc import Callable, Mapping +from types import MappingProxyType + +from roborock.containers import HomeDataDevice from roborock.roborock_message import RoborockMessage -from roborock.roborock_typing import RoborockCommand -from .v1_channel import V1Channel +from .channel import Channel +from .traits.trait import Trait _LOGGER = logging.getLogger(__name__) __all__ = [ "RoborockDevice", - "DeviceVersion", ] -class DeviceVersion(enum.StrEnum): - """Enum for device versions.""" - - V1 = "1.0" - A01 = "A01" - UNKNOWN = "unknown" - +class RoborockDevice(ABC): + """A generic channel for establishing a connection with a Roborock device. -class RoborockDevice: - """Unified Roborock device class with automatic connection setup.""" + Individual channel implementations have their own methods for speaking to + the device that hide some of the protocol specific complexity, but they + are still specialized for the device type and protocol. + """ def __init__( self, - user_data: UserData, device_info: HomeDataDevice, - product_info: HomeDataProduct, - v1_channel: V1Channel, + channel: Channel, + traits: list[Trait], ) -> None: """Initialize the RoborockDevice. @@ -54,51 +42,32 @@ def __init__( Use `connect()` to establish the connection, which will set up the appropriate protocol channel. Use `close()` to clean up all connections. """ - self._user_data = user_data - self._device_info = device_info - self._product_info = product_info - self._v1_channel = v1_channel + self._duid = device_info.duid + self._name = device_info.name + self._channel = channel self._unsub: Callable[[], None] | None = None + self._trait_map = {trait.name: trait for trait in traits} @property def duid(self) -> str: """Return the device unique identifier (DUID).""" - return self._device_info.duid + return self._duid @property def name(self) -> str: """Return the device name.""" - return self._device_info.name - - @cached_property - def device_version(self) -> str: - """Return the device version. - - At the moment this is a simple check against the product version (pv) of the device - and used as a placeholder for upcoming functionality for devices that will behave - differently based on the version and capabilities. - """ - if self._device_info.pv == DeviceVersion.V1.value: - return DeviceVersion.V1 - elif self._device_info.pv == DeviceVersion.A01.value: - return DeviceVersion.A01 - _LOGGER.warning( - "Unknown device version %s for device %s, using default UNKNOWN", - self._device_info.pv, - self._device_info.name, - ) - return DeviceVersion.UNKNOWN + return self._name @property def is_connected(self) -> bool: """Return whether the device is connected.""" - return self._v1_channel.is_mqtt_connected or self._v1_channel.is_local_connected + return self._channel.is_connected async def connect(self) -> None: """Connect to the device using the appropriate protocol channel.""" if self._unsub: raise ValueError("Already connected to the device") - self._unsub = await self._v1_channel.subscribe(self._on_message) + self._unsub = await self._channel.subscribe(self._on_message) _LOGGER.info("Connected to V1 device %s", self.name) async def close(self) -> None: @@ -111,10 +80,7 @@ def _on_message(self, message: RoborockMessage) -> None: """Handle incoming messages from the device.""" _LOGGER.debug("Received message from device: %s", message) - async def get_status(self) -> Status: - """Get the current status of the device. - - This is a placeholder command and will likely be changed/moved in the future. - """ - status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus) - return await self._v1_channel.rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type) + @property + def traits(self) -> Mapping[str, Trait]: + """Return the traits of the device.""" + return MappingProxyType(self._trait_map) diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 5f0d8cc1..b9276340 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -1,21 +1,29 @@ """Module for discovering Roborock devices.""" import asyncio +import enum import logging from collections.abc import Awaitable, Callable +from roborock.code_mappings import RoborockCategory from roborock.containers import ( HomeData, HomeDataDevice, HomeDataProduct, UserData, ) -from roborock.devices.device import DeviceVersion, RoborockDevice +from roborock.devices.device import RoborockDevice from roborock.mqtt.roborock_session import create_mqtt_session from roborock.mqtt.session import MqttSession from roborock.protocol import create_mqtt_params from roborock.web_api import RoborockApiClient +from .channel import Channel +from .mqtt_channel import create_mqtt_channel +from .traits.dyad import DyadApi +from .traits.status import StatusTrait +from .traits.trait import Trait +from .traits.zeo import ZeoApi from .v1_channel import create_v1_channel _LOGGER = logging.getLogger(__name__) @@ -33,6 +41,14 @@ DeviceCreator = Callable[[HomeDataDevice, HomeDataProduct], RoborockDevice] +class DeviceVersion(enum.StrEnum): + """Enum for device versions.""" + + V1 = "1.0" + A01 = "A01" + UNKNOWN = "unknown" + + class DeviceManager: """Central manager for Roborock device discovery and connections.""" @@ -114,15 +130,25 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi) mqtt_session = await create_mqtt_session(mqtt_params) def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice: - # Check device version and only support V1 for now - if device.pv != DeviceVersion.V1.value: - raise NotImplementedError( - f"Device {device.name} has version {device.pv}, but only V1 devices " - f"are supported through the unified interface." - ) - # Create V1 channel that handles both MQTT and local connections - v1_channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device) - return RoborockDevice(user_data, device, product, v1_channel) + channel: Channel + traits: list[Trait] = [] + # TODO: Define a registration mechanism/factory for v1 traits + match device.pv: + case DeviceVersion.V1: + channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device) + traits.append(StatusTrait(product, channel.rpc_channel)) + case DeviceVersion.A01: + mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device) + match product.category: + case RoborockCategory.WET_DRY_VAC: + traits.append(DyadApi(mqtt_channel)) + case RoborockCategory.WASHING_MACHINE: + traits.append(ZeoApi(mqtt_channel)) + case _: + raise NotImplementedError(f"Device {device.name} has unsupported category {product.category}") + case _: + raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}") + return RoborockDevice(device, channel, traits) manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session) await manager.discover_devices() diff --git a/roborock/devices/local_channel.py b/roborock/devices/local_channel.py index c4bb20ea..13906893 100644 --- a/roborock/devices/local_channel.py +++ b/roborock/devices/local_channel.py @@ -10,6 +10,8 @@ from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder from roborock.roborock_message import RoborockMessage +from .channel import Channel + _LOGGER = logging.getLogger(__name__) _PORT = 58867 @@ -30,7 +32,7 @@ def connection_lost(self, exc: Exception | None) -> None: self.connection_lost_cb(exc) -class LocalChannel: +class LocalChannel(Channel): """Simple RPC-style channel for communicating with a device over a local network. Handles request/response correlation and timeouts, but leaves message diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index eb147436..c7be8c12 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -5,16 +5,18 @@ from collections.abc import Callable from json import JSONDecodeError -from roborock.containers import RRiot +from roborock.containers import HomeDataDevice, RRiot, UserData from roborock.exceptions import RoborockException from roborock.mqtt.session import MqttParams, MqttSession from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder from roborock.roborock_message import RoborockMessage +from .channel import Channel + _LOGGER = logging.getLogger(__name__) -class MqttChannel: +class MqttChannel(Channel): """Simple RPC-style channel for communicating with a device over MQTT. Handles request/response correlation and timeouts, but leaves message @@ -33,6 +35,12 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: self._decoder = create_mqtt_decoder(local_key) self._encoder = create_mqtt_encoder(local_key) self._queue_lock = asyncio.Lock() + self._mqtt_unsub: Callable[[], None] | None = None + + @property + def is_connected(self) -> bool: + """Return true if the channel is connected.""" + return (self._mqtt_unsub is not None) and self._mqtt_session.connected @property def _publish_topic(self) -> str: @@ -67,7 +75,14 @@ def message_handler(payload: bytes) -> None: except Exception as e: _LOGGER.exception("Uncaught error in message handler callback: %s", e) - return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler) + self._mqtt_unsub = await self._mqtt_session.subscribe(self._subscribe_topic, message_handler) + + def unsub_wrapper() -> None: + if self._mqtt_unsub is not None: + self._mqtt_unsub() + self._mqtt_unsub = None + + return unsub_wrapper async def _resolve_future_with_lock(self, message: RoborockMessage) -> None: """Resolve waiting future with proper locking.""" @@ -113,3 +128,10 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> async with self._queue_lock: self._waiting_queue.pop(request_id, None) raise + + +def create_mqtt_channel( + user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice +) -> MqttChannel: + """Create a V1Channel for the given device.""" + return MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params) diff --git a/roborock/devices/traits/dyad.py b/roborock/devices/traits/dyad.py new file mode 100644 index 00000000..55ef6fbc --- /dev/null +++ b/roborock/devices/traits/dyad.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import logging +from typing import Any + +from roborock.roborock_message import RoborockDyadDataProtocol + +from ..a01_channel import send_decoded_command +from ..mqtt_channel import MqttChannel +from .trait import Trait + +_LOGGER = logging.getLogger(__name__) + +__all__ = [ + "DyadApi", +] + + +class DyadApi(Trait): + """API for interacting with Dyad devices.""" + + name = "dyad" + + def __init__(self, channel: MqttChannel) -> None: + """Initialize the Dyad API.""" + self._channel = channel + + async def query_values(self, protocols: list[RoborockDyadDataProtocol]) -> dict[RoborockDyadDataProtocol, Any]: + """Query the device for the values of the given Dyad protocols.""" + params = {RoborockDyadDataProtocol.ID_QUERY: [int(p) for p in protocols]} + return await send_decoded_command(self._channel, params) + + async def set_value(self, protocol: RoborockDyadDataProtocol, value: Any) -> dict[RoborockDyadDataProtocol, Any]: + """Set a value for a specific protocol on the device.""" + params = {protocol: value} + return await send_decoded_command(self._channel, params) diff --git a/roborock/devices/traits/status.py b/roborock/devices/traits/status.py new file mode 100644 index 00000000..86fc6623 --- /dev/null +++ b/roborock/devices/traits/status.py @@ -0,0 +1,48 @@ +"""Module for Roborock V1 devices. + +This interface is experimental and subject to breaking changes without notice +until the API is stable. +""" + +import logging + +from roborock.containers import ( + HomeDataProduct, + ModelStatus, + S7MaxVStatus, + Status, +) +from roborock.roborock_typing import RoborockCommand + +from ..v1_rpc_channel import V1RpcChannel +from .trait import Trait + +_LOGGER = logging.getLogger(__name__) + +__all__ = [ + "Status", +] + + +class StatusTrait(Trait): + """Unified Roborock device class with automatic connection setup.""" + + name = "status" + + def __init__(self, product_info: HomeDataProduct, rpc_channel: V1RpcChannel) -> None: + """Initialize the RoborockDevice. + + The device takes ownership of the V1 channel for communication with the device. + Use `connect()` to establish the connection, which will set up the appropriate + protocol channel. Use `close()` to clean up all connections. + """ + self._product_info = product_info + self._rpc_channel = rpc_channel + + async def get_status(self) -> Status: + """Get the current status of the device. + + This is a placeholder command and will likely be changed/moved in the future. + """ + status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus) + return await self._rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type) diff --git a/roborock/devices/traits/trait.py b/roborock/devices/traits/trait.py new file mode 100644 index 00000000..dca29b85 --- /dev/null +++ b/roborock/devices/traits/trait.py @@ -0,0 +1,10 @@ +"""Trait module for Roborock devices.""" + +from abc import ABC + + +class Trait(ABC): + """API for interacting with Roborock devices.""" + + name: str + """Name of the API.""" diff --git a/roborock/devices/traits/zeo.py b/roborock/devices/traits/zeo.py new file mode 100644 index 00000000..4e6bb6b0 --- /dev/null +++ b/roborock/devices/traits/zeo.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import logging +from typing import Any + +from roborock.roborock_message import RoborockZeoProtocol + +from ..a01_channel import send_decoded_command +from ..mqtt_channel import MqttChannel +from .trait import Trait + +_LOGGER = logging.getLogger(__name__) + +__all__ = [ + "ZeoApi", +] + + +class ZeoApi(Trait): + """API for interacting with Zeo devices.""" + + name = "zeo" + + def __init__(self, channel: MqttChannel) -> None: + """Initialize the Zeo API.""" + self._channel = channel + + async def query_values(self, protocols: list[RoborockZeoProtocol]) -> dict[RoborockZeoProtocol, Any]: + """Query the device for the values of the given protocols.""" + params = {RoborockZeoProtocol.ID_QUERY: [int(p) for p in protocols]} + return await send_decoded_command(self._channel, params) + + async def set_value(self, protocol: RoborockZeoProtocol, value: Any) -> dict[RoborockZeoProtocol, Any]: + """Set a value for a specific protocol on the device.""" + params = {protocol: value} + return await send_decoded_command(self._channel, params) diff --git a/roborock/devices/v1_channel.py b/roborock/devices/v1_channel.py index dd2e7a14..e37719f7 100644 --- a/roborock/devices/v1_channel.py +++ b/roborock/devices/v1_channel.py @@ -18,6 +18,7 @@ from roborock.roborock_message import RoborockMessage from roborock.roborock_typing import RoborockCommand +from .channel import Channel from .local_channel import LocalChannel, LocalSession, create_local_session from .mqtt_channel import MqttChannel from .v1_rpc_channel import V1RpcChannel, create_combined_rpc_channel, create_mqtt_rpc_channel @@ -31,7 +32,7 @@ _T = TypeVar("_T", bound=RoborockBase) -class V1Channel: +class V1Channel(Channel): """Unified V1 protocol channel with automatic MQTT/local connection handling. This channel abstracts away the complexity of choosing between MQTT and local @@ -63,6 +64,11 @@ def __init__( self._callback: Callable[[RoborockMessage], None] | None = None self._networking_info: NetworkInfo | None = None + @property + def is_connected(self) -> bool: + """Return whether MQTT connection is available.""" + return self.is_mqtt_connected or self.is_local_connected + @property def is_local_connected(self) -> bool: """Return whether local connection is available.""" diff --git a/roborock/protocols/a01_protocol.py b/roborock/protocols/a01_protocol.py index 71f7a2d6..1955cba9 100644 --- a/roborock/protocols/a01_protocol.py +++ b/roborock/protocols/a01_protocol.py @@ -20,7 +20,9 @@ A01_VERSION = b"A01" -def encode_mqtt_payload(data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any]) -> RoborockMessage: +def encode_mqtt_payload( + data: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any], +) -> RoborockMessage: """Encode payload for A01 commands over MQTT.""" dps_data = {"dps": data} payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size) diff --git a/tests/devices/test_device_manager.py b/tests/devices/test_device_manager.py index 26a4bc0f..d549bc16 100644 --- a/tests/devices/test_device_manager.py +++ b/tests/devices/test_device_manager.py @@ -6,7 +6,6 @@ import pytest from roborock.containers import HomeData, UserData -from roborock.devices.device import DeviceVersion from roborock.devices.device_manager import create_device_manager, create_home_data_api from roborock.exceptions import RoborockException @@ -63,13 +62,11 @@ async def test_with_device() -> None: assert len(devices) == 1 assert devices[0].duid == "abc123" assert devices[0].name == "Roborock S7 MaxV" - assert devices[0].device_version == DeviceVersion.V1 device = await device_manager.get_device("abc123") assert device is not None assert device.duid == "abc123" assert device.name == "Roborock S7 MaxV" - assert device.device_version == DeviceVersion.V1 await device_manager.close() diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py index 37a7c27c..e171e33d 100644 --- a/tests/devices/test_mqtt_channel.py +++ b/tests/devices/test_mqtt_channel.py @@ -119,7 +119,9 @@ async def test_mqtt_channel(mqtt_session: Mock, mqtt_channel: MqttChannel) -> No assert mqtt_session.subscribe.called assert mqtt_session.subscribe.call_args[0][0] == "rr/m/o/user123/username/abc123" - assert result == unsub + unsub.assert_not_called() + result() + unsub.assert_called_once() async def test_send_message_success( diff --git a/tests/devices/test_device.py b/tests/devices/test_v1_device.py similarity index 57% rename from tests/devices/test_device.py rename to tests/devices/test_v1_device.py index 52c941ca..10ef8073 100644 --- a/tests/devices/test_device.py +++ b/tests/devices/test_v1_device.py @@ -5,7 +5,9 @@ import pytest from roborock.containers import HomeData, S7MaxVStatus, UserData -from roborock.devices.device import DeviceVersion, RoborockDevice +from roborock.devices.device import RoborockDevice +from roborock.devices.traits.status import StatusTrait +from roborock.devices.traits.trait import Trait from .. import mock_data @@ -15,22 +17,38 @@ @pytest.fixture(autouse=True, name="channel") -def channel_fixture() -> AsyncMock: +def device_channel_fixture() -> AsyncMock: + """Fixture to set up the channel for tests.""" + return AsyncMock() + + +@pytest.fixture(autouse=True, name="rpc_channel") +def rpc_channel_fixture() -> AsyncMock: """Fixture to set up the channel for tests.""" return AsyncMock() @pytest.fixture(autouse=True, name="device") -def device_fixture(channel: AsyncMock) -> RoborockDevice: +def device_fixture(channel: AsyncMock, traits: list[Trait]) -> RoborockDevice: """Fixture to set up the device for tests.""" return RoborockDevice( - USER_DATA, device_info=HOME_DATA.devices[0], - product_info=HOME_DATA.products[0], - v1_channel=channel, + channel=channel, + traits=traits, ) +@pytest.fixture(autouse=True, name="traits") +def traits_fixture(rpc_channel: AsyncMock) -> list[Trait]: + """Fixture to set up the V1 API for tests.""" + return [ + StatusTrait( + product_info=HOME_DATA.products[0], + rpc_channel=rpc_channel, + ) + ] + + async def test_device_connection(device: RoborockDevice, channel: AsyncMock) -> None: """Test the Device connection setup.""" @@ -41,7 +59,6 @@ async def test_device_connection(device: RoborockDevice, channel: AsyncMock) -> assert device.duid == "abc123" assert device.name == "Roborock S7 MaxV" - assert device.device_version == DeviceVersion.V1 assert not subscribe.called @@ -53,14 +70,17 @@ async def test_device_connection(device: RoborockDevice, channel: AsyncMock) -> assert unsub.called -async def test_device_get_status_command(device: RoborockDevice, channel: AsyncMock) -> None: +async def test_device_get_status_command(device: RoborockDevice, rpc_channel: AsyncMock) -> None: """Test the device get_status command.""" # Mock response for get_status command - channel.rpc_channel.send_command.return_value = STATUS + rpc_channel.send_command.return_value = STATUS # Test get_status and verify the command was sent - status = await device.get_status() - assert channel.rpc_channel.send_command.called + status_api = device.traits["status"] + assert isinstance(status_api, StatusTrait) + assert status_api is not None + status = await status_api.get_status() + assert rpc_channel.send_command.called # Verify the result assert status is not None From 4370d30a701cfc2beaa13e835dc1d6ab339f80d7 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Thu, 7 Aug 2025 08:03:26 -0700 Subject: [PATCH 2/7] fix: add safety check for trait creation --- roborock/devices/device.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 157b3bca..846e6486 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -47,6 +47,8 @@ def __init__( self._channel = channel self._unsub: Callable[[], None] | None = None self._trait_map = {trait.name: trait for trait in traits} + if len(self._trait_map) != len(traits): + raise ValueError("Duplicate trait names found in traits list") @property def duid(self) -> str: From 1a0d3df7959d15b57799c566451dce71473dd936 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Thu, 7 Aug 2025 08:09:34 -0700 Subject: [PATCH 3/7] feat: Update cli with v1 status trait --- roborock/cli.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/roborock/cli.py b/roborock/cli.py index 64c0e224..0852352d 100644 --- a/roborock/cli.py +++ b/roborock/cli.py @@ -117,8 +117,11 @@ async def home_data_cache() -> HomeData: click.echo("MQTT session started. Querying devices...") for device in devices: + if not (status_trait := device.traits.get("status")): + click.echo(f"Device {device.name} does not have a status trait") + continue try: - status = await device.get_status() + status = await status_trait.get_status() except RoborockException as e: click.echo(f"Failed to get status for {device.name}: {e}") else: From c5899253d556f980e9234dd508b3f7e4d3468e10 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Thu, 7 Aug 2025 08:25:25 -0700 Subject: [PATCH 4/7] chore: address code review feedback --- roborock/devices/device.py | 2 +- roborock/devices/traits/status.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 846e6486..b1c5fbda 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -38,7 +38,7 @@ def __init__( ) -> None: """Initialize the RoborockDevice. - The device takes ownership of the V1 channel for communication with the device. + The device takes ownership of the channel for communication with the device. Use `connect()` to establish the connection, which will set up the appropriate protocol channel. Use `close()` to clean up all connections. """ diff --git a/roborock/devices/traits/status.py b/roborock/devices/traits/status.py index 86fc6623..d7d622d9 100644 --- a/roborock/devices/traits/status.py +++ b/roborock/devices/traits/status.py @@ -30,12 +30,7 @@ class StatusTrait(Trait): name = "status" def __init__(self, product_info: HomeDataProduct, rpc_channel: V1RpcChannel) -> None: - """Initialize the RoborockDevice. - - The device takes ownership of the V1 channel for communication with the device. - Use `connect()` to establish the connection, which will set up the appropriate - protocol channel. Use `close()` to clean up all connections. - """ + """Initialize the StatusTrait.""" self._product_info = product_info self._rpc_channel = rpc_channel From 9165448ff08424d0d438ab2dd46419ba63d67ac7 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Thu, 7 Aug 2025 08:21:30 -0700 Subject: [PATCH 5/7] chore: Update roborock/devices/v1_channel.py Co-authored-by: Luke Lashley --- roborock/devices/v1_channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roborock/devices/v1_channel.py b/roborock/devices/v1_channel.py index e37719f7..7ab9e301 100644 --- a/roborock/devices/v1_channel.py +++ b/roborock/devices/v1_channel.py @@ -66,7 +66,7 @@ def __init__( @property def is_connected(self) -> bool: - """Return whether MQTT connection is available.""" + """Return whether any connection is available.""" return self.is_mqtt_connected or self.is_local_connected @property From 440b0e26a786e2dc4170173bd135f2b09bba56b4 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Thu, 7 Aug 2025 20:33:28 -0700 Subject: [PATCH 6/7] chore: Revert encode_mqtt_payload typing change --- roborock/protocols/a01_protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roborock/protocols/a01_protocol.py b/roborock/protocols/a01_protocol.py index 1955cba9..2f02203d 100644 --- a/roborock/protocols/a01_protocol.py +++ b/roborock/protocols/a01_protocol.py @@ -21,7 +21,7 @@ def encode_mqtt_payload( - data: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any], + data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any], ) -> RoborockMessage: """Encode payload for A01 commands over MQTT.""" dps_data = {"dps": data} From ba06e49b5df994feb22b273edb3293e30032aad6 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Thu, 7 Aug 2025 20:37:39 -0700 Subject: [PATCH 7/7] fix: update mqtt payload encoding signature --- roborock/protocols/a01_protocol.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/roborock/protocols/a01_protocol.py b/roborock/protocols/a01_protocol.py index 2f02203d..5aa5ffb2 100644 --- a/roborock/protocols/a01_protocol.py +++ b/roborock/protocols/a01_protocol.py @@ -21,7 +21,9 @@ def encode_mqtt_payload( - data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any], + data: dict[RoborockDyadDataProtocol, Any] + | dict[RoborockZeoProtocol, Any] + | dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any], ) -> RoborockMessage: """Encode payload for A01 commands over MQTT.""" dps_data = {"dps": data}