Skip to content

feat: Support both a01 and v1 device types with traits #425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion roborock/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,11 @@ async def home_data_cache() -> HomeData:

click.echo("MQTT session started. Querying devices...")
for device in devices:
if not (status_trait := device.traits.get("status")):
click.echo(f"Device {device.name} does not have a status trait")
continue
try:
status = await device.get_status()
status = await status_trait.get_status()
except RoborockException as e:
click.echo(f"Failed to get status for {device.name}: {e}")
else:
Expand Down
43 changes: 43 additions & 0 deletions roborock/devices/a01_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Thin wrapper around the MQTT channel for Roborock A01 devices."""

from __future__ import annotations

import logging
from typing import Any, overload

from roborock.protocols.a01_protocol import (
decode_rpc_response,
encode_mqtt_payload,
)
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol

from .mqtt_channel import MqttChannel

_LOGGER = logging.getLogger(__name__)


@overload
async def send_decoded_command(
mqtt_channel: MqttChannel,
params: dict[RoborockDyadDataProtocol, Any],
) -> dict[RoborockDyadDataProtocol, Any]:
...


@overload
async def send_decoded_command(
mqtt_channel: MqttChannel,
params: dict[RoborockZeoProtocol, Any],
) -> dict[RoborockZeoProtocol, Any]:
...


async def send_decoded_command(
mqtt_channel: MqttChannel,
params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any],
) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]:
"""Send a command on the MQTT channel and get a decoded response."""
_LOGGER.debug("Sending MQTT command: %s", params)
roborock_message = encode_mqtt_payload(params)
response = await mqtt_channel.send_message(roborock_message)
return decode_rpc_response(response) # type: ignore[return-value]
27 changes: 27 additions & 0 deletions roborock/devices/channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Low-level interface for connections to Roborock devices."""

import logging
from collections.abc import Callable
from typing import Protocol

from roborock.roborock_message import RoborockMessage

_LOGGER = logging.getLogger(__name__)


class Channel(Protocol):
"""A generic channel for establishing a connection with a Roborock device.

Individual channel implementations have their own methods for speaking to
the device that hide some of the protocol specific complexity, but they
are still specialized for the device type and protocol.
"""

@property
def is_connected(self) -> bool:
"""Return true if the channel is connected."""
...

async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
"""Subscribe to messages from the device."""
...
92 changes: 30 additions & 62 deletions roborock/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,101 +4,72 @@
until the API is stable.
"""

import enum
import logging
from collections.abc import Callable
from functools import cached_property

from roborock.containers import (
HomeDataDevice,
HomeDataProduct,
ModelStatus,
S7MaxVStatus,
Status,
UserData,
)
from abc import ABC
from collections.abc import Callable, Mapping
from types import MappingProxyType

from roborock.containers import HomeDataDevice
from roborock.roborock_message import RoborockMessage
from roborock.roborock_typing import RoborockCommand

from .v1_channel import V1Channel
from .channel import Channel
from .traits.trait import Trait

_LOGGER = logging.getLogger(__name__)

__all__ = [
"RoborockDevice",
"DeviceVersion",
]


class DeviceVersion(enum.StrEnum):
"""Enum for device versions."""

V1 = "1.0"
A01 = "A01"
UNKNOWN = "unknown"

class RoborockDevice(ABC):
"""A generic channel for establishing a connection with a Roborock device.

class RoborockDevice:
"""Unified Roborock device class with automatic connection setup."""
Individual channel implementations have their own methods for speaking to
the device that hide some of the protocol specific complexity, but they
are still specialized for the device type and protocol.
"""

def __init__(
self,
user_data: UserData,
device_info: HomeDataDevice,
product_info: HomeDataProduct,
v1_channel: V1Channel,
channel: Channel,
traits: list[Trait],
) -> None:
"""Initialize the RoborockDevice.

The device takes ownership of the V1 channel for communication with the device.
The device takes ownership of the channel for communication with the device.
Use `connect()` to establish the connection, which will set up the appropriate
protocol channel. Use `close()` to clean up all connections.
"""
self._user_data = user_data
self._device_info = device_info
self._product_info = product_info
self._v1_channel = v1_channel
self._duid = device_info.duid
self._name = device_info.name
self._channel = channel
self._unsub: Callable[[], None] | None = None
self._trait_map = {trait.name: trait for trait in traits}
if len(self._trait_map) != len(traits):
raise ValueError("Duplicate trait names found in traits list")

@property
def duid(self) -> str:
"""Return the device unique identifier (DUID)."""
return self._device_info.duid
return self._duid

@property
def name(self) -> str:
"""Return the device name."""
return self._device_info.name

@cached_property
def device_version(self) -> str:
"""Return the device version.

At the moment this is a simple check against the product version (pv) of the device
and used as a placeholder for upcoming functionality for devices that will behave
differently based on the version and capabilities.
"""
if self._device_info.pv == DeviceVersion.V1.value:
return DeviceVersion.V1
elif self._device_info.pv == DeviceVersion.A01.value:
return DeviceVersion.A01
_LOGGER.warning(
"Unknown device version %s for device %s, using default UNKNOWN",
self._device_info.pv,
self._device_info.name,
)
return DeviceVersion.UNKNOWN
return self._name

@property
def is_connected(self) -> bool:
"""Return whether the device is connected."""
return self._v1_channel.is_mqtt_connected or self._v1_channel.is_local_connected
return self._channel.is_connected

async def connect(self) -> None:
"""Connect to the device using the appropriate protocol channel."""
if self._unsub:
raise ValueError("Already connected to the device")
self._unsub = await self._v1_channel.subscribe(self._on_message)
self._unsub = await self._channel.subscribe(self._on_message)
_LOGGER.info("Connected to V1 device %s", self.name)

async def close(self) -> None:
Expand All @@ -111,10 +82,7 @@ def _on_message(self, message: RoborockMessage) -> None:
"""Handle incoming messages from the device."""
_LOGGER.debug("Received message from device: %s", message)

async def get_status(self) -> Status:
"""Get the current status of the device.

This is a placeholder command and will likely be changed/moved in the future.
"""
status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus)
return await self._v1_channel.rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type)
@property
def traits(self) -> Mapping[str, Trait]:
"""Return the traits of the device."""
return MappingProxyType(self._trait_map)
Comment on lines +86 to +88
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't this just stay as a dict as is? I don't really know what MappingProxyType is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this returns a dict its actually mutable so a caller could theoretically change it. It may be overkill here...

46 changes: 36 additions & 10 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
"""Module for discovering Roborock devices."""

import asyncio
import enum
import logging
from collections.abc import Awaitable, Callable

from roborock.code_mappings import RoborockCategory
from roborock.containers import (
HomeData,
HomeDataDevice,
HomeDataProduct,
UserData,
)
from roborock.devices.device import DeviceVersion, RoborockDevice
from roborock.devices.device import RoborockDevice
from roborock.mqtt.roborock_session import create_mqtt_session
from roborock.mqtt.session import MqttSession
from roborock.protocol import create_mqtt_params
from roborock.web_api import RoborockApiClient

from .channel import Channel
from .mqtt_channel import create_mqtt_channel
from .traits.dyad import DyadApi
from .traits.status import StatusTrait
from .traits.trait import Trait
from .traits.zeo import ZeoApi
from .v1_channel import create_v1_channel

_LOGGER = logging.getLogger(__name__)
Expand All @@ -33,6 +41,14 @@
DeviceCreator = Callable[[HomeDataDevice, HomeDataProduct], RoborockDevice]


class DeviceVersion(enum.StrEnum):
"""Enum for device versions."""

V1 = "1.0"
A01 = "A01"
UNKNOWN = "unknown"


class DeviceManager:
"""Central manager for Roborock device discovery and connections."""

Expand Down Expand Up @@ -114,15 +130,25 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi)
mqtt_session = await create_mqtt_session(mqtt_params)

def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
# Check device version and only support V1 for now
if device.pv != DeviceVersion.V1.value:
raise NotImplementedError(
f"Device {device.name} has version {device.pv}, but only V1 devices "
f"are supported through the unified interface."
)
# Create V1 channel that handles both MQTT and local connections
v1_channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device)
return RoborockDevice(user_data, device, product, v1_channel)
channel: Channel
traits: list[Trait] = []
# TODO: Define a registration mechanism/factory for v1 traits
match device.pv:
case DeviceVersion.V1:
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device)
traits.append(StatusTrait(product, channel.rpc_channel))
case DeviceVersion.A01:
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
match product.category:
case RoborockCategory.WET_DRY_VAC:
traits.append(DyadApi(mqtt_channel))
case RoborockCategory.WASHING_MACHINE:
traits.append(ZeoApi(mqtt_channel))
case _:
raise NotImplementedError(f"Device {device.name} has unsupported category {product.category}")
case _:
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
return RoborockDevice(device, channel, traits)

manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
await manager.discover_devices()
Expand Down
4 changes: 3 additions & 1 deletion roborock/devices/local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
from roborock.roborock_message import RoborockMessage

from .channel import Channel

_LOGGER = logging.getLogger(__name__)
_PORT = 58867

Expand All @@ -30,7 +32,7 @@ def connection_lost(self, exc: Exception | None) -> None:
self.connection_lost_cb(exc)


class LocalChannel:
class LocalChannel(Channel):
"""Simple RPC-style channel for communicating with a device over a local network.

Handles request/response correlation and timeouts, but leaves message
Expand Down
28 changes: 25 additions & 3 deletions roborock/devices/mqtt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
from collections.abc import Callable
from json import JSONDecodeError

from roborock.containers import RRiot
from roborock.containers import HomeDataDevice, RRiot, UserData
from roborock.exceptions import RoborockException
from roborock.mqtt.session import MqttParams, MqttSession
from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
from roborock.roborock_message import RoborockMessage

from .channel import Channel

_LOGGER = logging.getLogger(__name__)


class MqttChannel:
class MqttChannel(Channel):
"""Simple RPC-style channel for communicating with a device over MQTT.

Handles request/response correlation and timeouts, but leaves message
Expand All @@ -33,6 +35,12 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot:
self._decoder = create_mqtt_decoder(local_key)
self._encoder = create_mqtt_encoder(local_key)
self._queue_lock = asyncio.Lock()
self._mqtt_unsub: Callable[[], None] | None = None

@property
def is_connected(self) -> bool:
"""Return true if the channel is connected."""
return (self._mqtt_unsub is not None) and self._mqtt_session.connected

@property
def _publish_topic(self) -> str:
Expand Down Expand Up @@ -67,7 +75,14 @@ def message_handler(payload: bytes) -> None:
except Exception as e:
_LOGGER.exception("Uncaught error in message handler callback: %s", e)

return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
self._mqtt_unsub = await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)

def unsub_wrapper() -> None:
if self._mqtt_unsub is not None:
self._mqtt_unsub()
self._mqtt_unsub = None

return unsub_wrapper

async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
"""Resolve waiting future with proper locking."""
Expand Down Expand Up @@ -113,3 +128,10 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) ->
async with self._queue_lock:
self._waiting_queue.pop(request_id, None)
raise


def create_mqtt_channel(
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
) -> MqttChannel:
"""Create a V1Channel for the given device."""
return MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
Loading
Loading