Skip to content

Commit 054b3fb

Browse files
committed
chore: Move a01 encoding and decoding to a separate module
1 parent 9dd8c22 commit 054b3fb

File tree

6 files changed

+393
-98
lines changed

6 files changed

+393
-98
lines changed

roborock/protocols/a01_protocol.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Roborock A01 Protocol encoding and decoding."""
2+
3+
import json
4+
import logging
5+
from typing import Any
6+
7+
from Crypto.Cipher import AES
8+
from Crypto.Util.Padding import pad, unpad
9+
10+
from roborock.exceptions import RoborockException
11+
from roborock.roborock_message import (
12+
RoborockDyadDataProtocol,
13+
RoborockMessage,
14+
RoborockMessageProtocol,
15+
RoborockZeoProtocol,
16+
)
17+
18+
_LOGGER = logging.getLogger(__name__)
19+
20+
A01_VERSION = b"A01"
21+
22+
23+
def encode_mqtt_payload(data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any]) -> RoborockMessage:
24+
"""Encode payload for A01 commands over MQTT."""
25+
dps_data = {"dps": data}
26+
payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size)
27+
return RoborockMessage(
28+
protocol=RoborockMessageProtocol.RPC_REQUEST,
29+
version=A01_VERSION,
30+
payload=payload,
31+
)
32+
33+
34+
def decode_rpc_response(message: RoborockMessage) -> dict[int, Any]:
35+
"""Decode a V1 RPC_RESPONSE message."""
36+
if not message.payload:
37+
raise RoborockException("Invalid A01 message format: missing payload")
38+
try:
39+
unpadded = unpad(message.payload, AES.block_size)
40+
except ValueError as err:
41+
raise RoborockException(f"Unable to unpad A01 payload: {err}")
42+
43+
try:
44+
payload = json.loads(unpadded.decode())
45+
except (json.JSONDecodeError, TypeError) as e:
46+
raise RoborockException(f"Invalid A01 message payload: {e} for {message.payload!r}") from e
47+
48+
datapoints = payload.get("dps", {})
49+
if not isinstance(datapoints, dict):
50+
raise RoborockException(f"Invalid A01 message format: 'dps' should be a dictionary for {message.payload!r}")
51+
try:
52+
return {int(key): value for key, value in datapoints.items()}
53+
except ValueError:
54+
raise RoborockException(f"Invalid A01 message format: 'dps' key should be an integer for {message.payload!r}")
Lines changed: 86 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
1-
import dataclasses
2-
import json
31
import logging
4-
import typing
52
from abc import ABC, abstractmethod
63
from collections.abc import Callable
74
from datetime import time
8-
9-
from Crypto.Cipher import AES
10-
from Crypto.Util.Padding import unpad
5+
from typing import Any
116

127
from roborock import DeviceData
138
from roborock.api import RoborockClient
@@ -33,6 +28,8 @@
3328
ZeoTemperature,
3429
)
3530
from roborock.containers import DyadProductInfo, DyadSndState, RoborockCategory
31+
from roborock.exceptions import RoborockException
32+
from roborock.protocols.a01_protocol import decode_rpc_response
3633
from roborock.roborock_message import (
3734
RoborockDyadDataProtocol,
3835
RoborockMessage,
@@ -43,111 +40,120 @@
4340
_LOGGER = logging.getLogger(__name__)
4441

4542

46-
@dataclasses.dataclass
47-
class A01ProtocolCacheEntry:
48-
post_process_fn: Callable
49-
value: typing.Any | None = None
50-
51-
5243
# Right now this cache is not active, it was too much complexity for the initial addition of dyad.
53-
protocol_entries = {
54-
RoborockDyadDataProtocol.STATUS: A01ProtocolCacheEntry(lambda val: RoborockDyadStateCode(val).name),
55-
RoborockDyadDataProtocol.SELF_CLEAN_MODE: A01ProtocolCacheEntry(lambda val: DyadSelfCleanMode(val).name),
56-
RoborockDyadDataProtocol.SELF_CLEAN_LEVEL: A01ProtocolCacheEntry(lambda val: DyadSelfCleanLevel(val).name),
57-
RoborockDyadDataProtocol.WARM_LEVEL: A01ProtocolCacheEntry(lambda val: DyadWarmLevel(val).name),
58-
RoborockDyadDataProtocol.CLEAN_MODE: A01ProtocolCacheEntry(lambda val: DyadCleanMode(val).name),
59-
RoborockDyadDataProtocol.SUCTION: A01ProtocolCacheEntry(lambda val: DyadSuction(val).name),
60-
RoborockDyadDataProtocol.WATER_LEVEL: A01ProtocolCacheEntry(lambda val: DyadWaterLevel(val).name),
61-
RoborockDyadDataProtocol.BRUSH_SPEED: A01ProtocolCacheEntry(lambda val: DyadBrushSpeed(val).name),
62-
RoborockDyadDataProtocol.POWER: A01ProtocolCacheEntry(lambda val: int(val)),
63-
RoborockDyadDataProtocol.AUTO_DRY: A01ProtocolCacheEntry(lambda val: bool(val)),
64-
RoborockDyadDataProtocol.MESH_LEFT: A01ProtocolCacheEntry(lambda val: int(360000 - val * 60)),
65-
RoborockDyadDataProtocol.BRUSH_LEFT: A01ProtocolCacheEntry(lambda val: int(360000 - val * 60)),
66-
RoborockDyadDataProtocol.ERROR: A01ProtocolCacheEntry(lambda val: DyadError(val).name),
67-
RoborockDyadDataProtocol.VOLUME_SET: A01ProtocolCacheEntry(lambda val: int(val)),
68-
RoborockDyadDataProtocol.STAND_LOCK_AUTO_RUN: A01ProtocolCacheEntry(lambda val: bool(val)),
69-
RoborockDyadDataProtocol.AUTO_DRY_MODE: A01ProtocolCacheEntry(lambda val: bool(val)),
70-
RoborockDyadDataProtocol.SILENT_DRY_DURATION: A01ProtocolCacheEntry(lambda val: int(val)), # in minutes
71-
RoborockDyadDataProtocol.SILENT_MODE: A01ProtocolCacheEntry(lambda val: bool(val)),
72-
RoborockDyadDataProtocol.SILENT_MODE_START_TIME: A01ProtocolCacheEntry(
73-
lambda val: time(hour=int(val / 60), minute=val % 60)
44+
DYAD_PROTOCOL_ENTRIES: dict[RoborockDyadDataProtocol, Callable] = {
45+
RoborockDyadDataProtocol.STATUS: lambda val: RoborockDyadStateCode(val).name,
46+
RoborockDyadDataProtocol.SELF_CLEAN_MODE: lambda val: DyadSelfCleanMode(val).name,
47+
RoborockDyadDataProtocol.SELF_CLEAN_LEVEL: lambda val: DyadSelfCleanLevel(val).name,
48+
RoborockDyadDataProtocol.WARM_LEVEL: lambda val: DyadWarmLevel(val).name,
49+
RoborockDyadDataProtocol.CLEAN_MODE: lambda val: DyadCleanMode(val).name,
50+
RoborockDyadDataProtocol.SUCTION: lambda val: DyadSuction(val).name,
51+
RoborockDyadDataProtocol.WATER_LEVEL: lambda val: DyadWaterLevel(val).name,
52+
RoborockDyadDataProtocol.BRUSH_SPEED: lambda val: DyadBrushSpeed(val).name,
53+
RoborockDyadDataProtocol.POWER: lambda val: int(val),
54+
RoborockDyadDataProtocol.AUTO_DRY: lambda val: bool(val),
55+
RoborockDyadDataProtocol.MESH_LEFT: lambda val: int(360000 - val * 60),
56+
RoborockDyadDataProtocol.BRUSH_LEFT: lambda val: int(360000 - val * 60),
57+
RoborockDyadDataProtocol.ERROR: lambda val: DyadError(val).name,
58+
RoborockDyadDataProtocol.VOLUME_SET: lambda val: int(val),
59+
RoborockDyadDataProtocol.STAND_LOCK_AUTO_RUN: lambda val: bool(val),
60+
RoborockDyadDataProtocol.AUTO_DRY_MODE: lambda val: bool(val),
61+
RoborockDyadDataProtocol.SILENT_DRY_DURATION: lambda val: int(val), # in minutes
62+
RoborockDyadDataProtocol.SILENT_MODE: lambda val: bool(val),
63+
RoborockDyadDataProtocol.SILENT_MODE_START_TIME: lambda val: time(
64+
hour=int(val / 60), minute=val % 60
7465
), # in minutes since 00:00
75-
RoborockDyadDataProtocol.SILENT_MODE_END_TIME: A01ProtocolCacheEntry(
76-
lambda val: time(hour=int(val / 60), minute=val % 60)
66+
RoborockDyadDataProtocol.SILENT_MODE_END_TIME: lambda val: time(
67+
hour=int(val / 60), minute=val % 60
7768
), # in minutes since 00:00
78-
RoborockDyadDataProtocol.RECENT_RUN_TIME: A01ProtocolCacheEntry(
79-
lambda val: [int(v) for v in val.split(",")]
80-
), # minutes of cleaning in past few days.
81-
RoborockDyadDataProtocol.TOTAL_RUN_TIME: A01ProtocolCacheEntry(lambda val: int(val)),
82-
RoborockDyadDataProtocol.SND_STATE: A01ProtocolCacheEntry(lambda val: DyadSndState.from_dict(val)),
83-
RoborockDyadDataProtocol.PRODUCT_INFO: A01ProtocolCacheEntry(lambda val: DyadProductInfo.from_dict(val)),
69+
RoborockDyadDataProtocol.RECENT_RUN_TIME: lambda val: [
70+
int(v) for v in val.split(",")
71+
], # minutes of cleaning in past few days.
72+
RoborockDyadDataProtocol.TOTAL_RUN_TIME: lambda val: int(val),
73+
RoborockDyadDataProtocol.SND_STATE: lambda val: DyadSndState.from_dict(val),
74+
RoborockDyadDataProtocol.PRODUCT_INFO: lambda val: DyadProductInfo.from_dict(val),
8475
}
8576

86-
zeo_data_protocol_entries = {
77+
ZEO_PROTOCOL_ENTRIES: dict[RoborockZeoProtocol, Callable] = {
8778
# ro
88-
RoborockZeoProtocol.STATE: A01ProtocolCacheEntry(lambda val: ZeoState(val).name),
89-
RoborockZeoProtocol.COUNTDOWN: A01ProtocolCacheEntry(lambda val: int(val)),
90-
RoborockZeoProtocol.WASHING_LEFT: A01ProtocolCacheEntry(lambda val: int(val)),
91-
RoborockZeoProtocol.ERROR: A01ProtocolCacheEntry(lambda val: ZeoError(val).name),
92-
RoborockZeoProtocol.TIMES_AFTER_CLEAN: A01ProtocolCacheEntry(lambda val: int(val)),
93-
RoborockZeoProtocol.DETERGENT_EMPTY: A01ProtocolCacheEntry(lambda val: bool(val)),
94-
RoborockZeoProtocol.SOFTENER_EMPTY: A01ProtocolCacheEntry(lambda val: bool(val)),
79+
RoborockZeoProtocol.STATE: lambda val: ZeoState(val).name,
80+
RoborockZeoProtocol.COUNTDOWN: lambda val: int(val),
81+
RoborockZeoProtocol.WASHING_LEFT: lambda val: int(val),
82+
RoborockZeoProtocol.ERROR: lambda val: ZeoError(val).name,
83+
RoborockZeoProtocol.TIMES_AFTER_CLEAN: lambda val: int(val),
84+
RoborockZeoProtocol.DETERGENT_EMPTY: lambda val: bool(val),
85+
RoborockZeoProtocol.SOFTENER_EMPTY: lambda val: bool(val),
9586
# rw
96-
RoborockZeoProtocol.MODE: A01ProtocolCacheEntry(lambda val: ZeoMode(val).name),
97-
RoborockZeoProtocol.PROGRAM: A01ProtocolCacheEntry(lambda val: ZeoProgram(val).name),
98-
RoborockZeoProtocol.TEMP: A01ProtocolCacheEntry(lambda val: ZeoTemperature(val).name),
99-
RoborockZeoProtocol.RINSE_TIMES: A01ProtocolCacheEntry(lambda val: ZeoRinse(val).name),
100-
RoborockZeoProtocol.SPIN_LEVEL: A01ProtocolCacheEntry(lambda val: ZeoSpin(val).name),
101-
RoborockZeoProtocol.DRYING_MODE: A01ProtocolCacheEntry(lambda val: ZeoDryingMode(val).name),
102-
RoborockZeoProtocol.DETERGENT_TYPE: A01ProtocolCacheEntry(lambda val: ZeoDetergentType(val).name),
103-
RoborockZeoProtocol.SOFTENER_TYPE: A01ProtocolCacheEntry(lambda val: ZeoSoftenerType(val).name),
104-
RoborockZeoProtocol.SOUND_SET: A01ProtocolCacheEntry(lambda val: bool(val)),
87+
RoborockZeoProtocol.MODE: lambda val: ZeoMode(val).name,
88+
RoborockZeoProtocol.PROGRAM: lambda val: ZeoProgram(val).name,
89+
RoborockZeoProtocol.TEMP: lambda val: ZeoTemperature(val).name,
90+
RoborockZeoProtocol.RINSE_TIMES: lambda val: ZeoRinse(val).name,
91+
RoborockZeoProtocol.SPIN_LEVEL: lambda val: ZeoSpin(val).name,
92+
RoborockZeoProtocol.DRYING_MODE: lambda val: ZeoDryingMode(val).name,
93+
RoborockZeoProtocol.DETERGENT_TYPE: lambda val: ZeoDetergentType(val).name,
94+
RoborockZeoProtocol.SOFTENER_TYPE: lambda val: ZeoSoftenerType(val).name,
95+
RoborockZeoProtocol.SOUND_SET: lambda val: bool(val),
10596
}
10697

10798

99+
def convert_dyad_value(protocol: int, value: Any) -> Any:
100+
"""Convert a dyad protocol value to its corresponding type."""
101+
protocol_value = RoborockDyadDataProtocol(protocol)
102+
if (converter := DYAD_PROTOCOL_ENTRIES.get(protocol_value)) is not None:
103+
return converter(value)
104+
return None
105+
106+
107+
def convert_zeo_value(protocol: int, value: Any) -> Any:
108+
"""Convert a zeo protocol value to its corresponding type."""
109+
protocol_value = RoborockZeoProtocol(protocol)
110+
if (converter := ZEO_PROTOCOL_ENTRIES.get(protocol_value)) is not None:
111+
return converter(value)
112+
return None
113+
114+
108115
class RoborockClientA01(RoborockClient, ABC):
109116
"""Roborock client base class for A01 devices."""
110117

118+
value_converter: Callable[[int, Any], Any] | None = None
119+
111120
def __init__(self, device_info: DeviceData, category: RoborockCategory):
112121
"""Initialize the Roborock client."""
113122
super().__init__(device_info)
114-
self.category = category
123+
if category == RoborockCategory.WET_DRY_VAC:
124+
self.value_converter = convert_dyad_value
125+
elif category == RoborockCategory.WASHING_MACHINE:
126+
self.value_converter = convert_zeo_value
127+
else:
128+
_LOGGER.debug("Device category %s is not (yet) supported", category)
129+
self.value_converter = None
115130

116131
def on_message_received(self, messages: list[RoborockMessage]) -> None:
132+
if self.value_converter is None:
133+
return
117134
for message in messages:
118135
protocol = message.protocol
119136
if message.payload and protocol in [
120137
RoborockMessageProtocol.RPC_RESPONSE,
121138
RoborockMessageProtocol.GENERAL_REQUEST,
122139
]:
123-
payload = message.payload
124140
try:
125-
payload = unpad(payload, AES.block_size)
126-
except Exception as err:
127-
self._logger.debug("Failed to unpad payload: %s", err)
141+
data_points = decode_rpc_response(message)
142+
except RoborockException as err:
143+
self._logger.error("Failed to decode message %s: %s", message, err)
128144
continue
129-
payload_json = json.loads(payload.decode())
130-
for data_point_number, data_point in payload_json.get("dps").items():
131-
data_point_protocol: RoborockDyadDataProtocol | RoborockZeoProtocol
132-
self._logger.debug("received msg with dps, protocol: %s, %s", data_point_number, protocol)
133-
entries: dict
134-
if self.category == RoborockCategory.WET_DRY_VAC:
135-
data_point_protocol = RoborockDyadDataProtocol(int(data_point_number))
136-
entries = protocol_entries
137-
elif self.category == RoborockCategory.WASHING_MACHINE:
138-
data_point_protocol = RoborockZeoProtocol(int(data_point_number))
139-
entries = zeo_data_protocol_entries
140-
else:
141-
continue
142-
if data_point_protocol in entries:
143-
# Auto convert into data struct we want.
144-
converted_response = entries[data_point_protocol].post_process_fn(data_point)
145+
for data_point_number, data_point in data_points.items():
146+
if converted_response := self.value_converter(data_point_number, data_point):
145147
queue = self._waiting_queue.get(int(data_point_number))
146148
if queue and queue.protocol == protocol:
147149
queue.set_result(converted_response)
150+
else:
151+
self._logger.warning(
152+
"Received unknown data point %s for protocol %s, ignoring", data_point_number, protocol
153+
)
148154

149155
@abstractmethod
150156
async def update_values(
151157
self, dyad_data_protocols: list[RoborockDyadDataProtocol | RoborockZeoProtocol]
152-
) -> dict[RoborockDyadDataProtocol | RoborockZeoProtocol, typing.Any]:
158+
) -> dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any]:
153159
"""This should handle updating for each given protocol."""

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import typing
55

66
from Crypto.Cipher import AES
7-
from Crypto.Util.Padding import pad, unpad
7+
from Crypto.Util.Padding import unpad
88

99
from roborock.cloud_api import RoborockMqttClient
1010
from roborock.containers import DeviceData, RoborockCategory, UserData
1111
from roborock.exceptions import RoborockException
12+
from roborock.protocols.a01_protocol import encode_mqtt_payload
1213
from roborock.roborock_message import (
1314
RoborockDyadDataProtocol,
1415
RoborockMessage,
@@ -46,11 +47,13 @@ async def send_message(self, roborock_message: RoborockMessage):
4647
# self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
4748
payload = json.loads(unpad(roborock_message.payload, AES.block_size))
4849
futures = []
50+
self._logger.debug("Sending message: %s", payload)
4951
if "10000" in payload["dps"]:
5052
for dps in json.loads(payload["dps"]["10000"]):
5153
futures.append(self._async_response(dps, response_protocol))
5254
self._send_msg_raw(m)
5355
responses = await asyncio.gather(*futures, return_exceptions=True)
56+
self._logger.debug("Received responses: %s", responses)
5457
dps_responses: dict[int, typing.Any] = {}
5558
if "10000" in payload["dps"]:
5659
for i, dps in enumerate(json.loads(payload["dps"]["10000"])):
@@ -65,24 +68,14 @@ async def send_message(self, roborock_message: RoborockMessage):
6568
async def update_values(
6669
self, dyad_data_protocols: list[RoborockDyadDataProtocol | RoborockZeoProtocol]
6770
) -> dict[RoborockDyadDataProtocol | RoborockZeoProtocol, typing.Any]:
68-
payload = {"dps": {RoborockDyadDataProtocol.ID_QUERY: str([int(protocol) for protocol in dyad_data_protocols])}}
69-
return await self.send_message(
70-
RoborockMessage(
71-
protocol=RoborockMessageProtocol.RPC_REQUEST,
72-
version=b"A01",
73-
payload=pad(json.dumps(payload).encode("utf-8"), AES.block_size),
74-
)
71+
message = encode_mqtt_payload(
72+
{RoborockDyadDataProtocol.ID_QUERY: str([int(protocol) for protocol in dyad_data_protocols])}
7573
)
74+
return await self.send_message(message)
7675

7776
async def set_value(
7877
self, protocol: RoborockDyadDataProtocol | RoborockZeoProtocol, value: typing.Any
7978
) -> dict[int, typing.Any]:
8079
"""Set a value for a specific protocol on the A01 device."""
81-
payload = {"dps": {int(protocol): value}}
82-
return await self.send_message(
83-
RoborockMessage(
84-
protocol=RoborockMessageProtocol.RPC_REQUEST,
85-
version=b"A01",
86-
payload=pad(json.dumps(payload).encode("utf-8"), AES.block_size),
87-
)
88-
)
80+
message = encode_mqtt_payload({protocol: value})
81+
return await self.send_message(message)

tests/protocols/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for the protocols package."""

0 commit comments

Comments
 (0)