Skip to content

Commit 7507423

Browse files
authored
chore: Create module for v1 request encoding (#413)
* chore: Create module for v1 request encoding * chore: Delete tests/devices/test_v1_protocol.py * feat: Simplify local payload encoding by rejecting any cloud commands sent locally
1 parent ec780c9 commit 7507423

File tree

6 files changed

+111
-56
lines changed

6 files changed

+111
-56
lines changed

roborock/protocols/v1_protocol.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Roborock V1 Protocol Encoder."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
import math
7+
import time
8+
from collections.abc import Callable
9+
from dataclasses import dataclass, field
10+
from typing import Any
11+
12+
from roborock.roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
13+
from roborock.roborock_typing import RoborockCommand
14+
from roborock.util import get_next_int
15+
16+
CommandType = RoborockCommand | str
17+
ParamsType = list | dict | int | None
18+
19+
20+
@dataclass(frozen=True, kw_only=True)
21+
class SecurityData:
22+
"""Security data included in the request for some V1 commands."""
23+
24+
endpoint: str
25+
nonce: bytes
26+
27+
def to_dict(self) -> dict[str, Any]:
28+
"""Convert security data to a dictionary for sending in the payload."""
29+
return {"security": {"endpoint": self.endpoint, "nonce": self.nonce.hex().lower()}}
30+
31+
32+
@dataclass
33+
class RequestMessage:
34+
"""Data structure for v1 RoborockMessage payloads."""
35+
36+
method: RoborockCommand | str
37+
params: ParamsType
38+
timestamp: int = field(default_factory=lambda: math.floor(time.time()))
39+
request_id: int = field(default_factory=lambda: get_next_int(10000, 32767))
40+
41+
def as_payload(self, security_data: SecurityData | None) -> bytes:
42+
"""Convert the request arguments to a dictionary."""
43+
inner = {
44+
"id": self.request_id,
45+
"method": self.method,
46+
"params": self.params or [],
47+
**(security_data.to_dict() if security_data else {}),
48+
}
49+
return bytes(
50+
json.dumps(
51+
{
52+
"dps": {"101": json.dumps(inner, separators=(",", ":"))},
53+
"t": self.timestamp,
54+
},
55+
separators=(",", ":"),
56+
).encode()
57+
)
58+
59+
60+
def create_mqtt_payload_encoder(security_data: SecurityData) -> Callable[[CommandType, ParamsType], RoborockMessage]:
61+
"""Create a payload encoder for V1 commands over MQTT."""
62+
63+
def _get_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
64+
"""Build the payload for a V1 command."""
65+
request = RequestMessage(method=method, params=params)
66+
payload = request.as_payload(security_data) # always secure
67+
return RoborockMessage(
68+
timestamp=request.timestamp,
69+
protocol=RoborockMessageProtocol.RPC_REQUEST,
70+
payload=payload,
71+
)
72+
73+
return _get_payload
74+
75+
76+
def encode_local_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
77+
"""Encode payload for V1 commands over local connection."""
78+
79+
request = RequestMessage(method=method, params=params)
80+
payload = request.as_payload(security_data=None)
81+
82+
message_retry: MessageRetry | None = None
83+
if method == RoborockCommand.RETRY_REQUEST and isinstance(params, dict):
84+
message_retry = MessageRetry(method=method, retry_id=params["retry_id"])
85+
86+
return RoborockMessage(
87+
timestamp=request.timestamp,
88+
protocol=RoborockMessageProtocol.GENERAL_REQUEST,
89+
payload=payload,
90+
message_retry=message_retry,
91+
)

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import dataclasses
33
import json
4-
import math
54
import struct
65
import time
76
from abc import ABC, abstractmethod
@@ -54,15 +53,15 @@
5453
RoborockMessage,
5554
RoborockMessageProtocol,
5655
)
57-
from roborock.util import RepeatableTask, get_next_int, unpack_list
56+
from roborock.util import RepeatableTask, unpack_list
57+
58+
CUSTOM_COMMANDS = {RoborockCommand.GET_MAP_CALIBRATION}
5859

5960
COMMANDS_SECURED = {
6061
RoborockCommand.GET_MAP_V1,
6162
RoborockCommand.GET_MULTI_MAP,
6263
}
6364

64-
CUSTOM_COMMANDS = {RoborockCommand.GET_MAP_CALIBRATION}
65-
6665
CLOUD_REQUIRED = COMMANDS_SECURED.union(CUSTOM_COMMANDS)
6766

6867
WASH_N_FILL_DOCK = [
@@ -340,35 +339,6 @@ async def load_multi_map(self, map_flag: int) -> None:
340339
"""Load the map into the vacuum's memory."""
341340
await self.send_command(RoborockCommand.LOAD_MULTI_MAP, [map_flag])
342341

343-
def _get_payload(
344-
self,
345-
method: RoborockCommand | str,
346-
params: list | dict | int | None = None,
347-
secured=False,
348-
):
349-
timestamp = math.floor(time.time())
350-
request_id = get_next_int(10000, 32767)
351-
inner = {
352-
"id": request_id,
353-
"method": method,
354-
"params": params or [],
355-
}
356-
if secured:
357-
inner["security"] = {
358-
"endpoint": self._endpoint,
359-
"nonce": self._nonce.hex().lower(),
360-
}
361-
payload = bytes(
362-
json.dumps(
363-
{
364-
"dps": {"101": json.dumps(inner, separators=(",", ":"))},
365-
"t": timestamp,
366-
},
367-
separators=(",", ":"),
368-
).encode()
369-
)
370-
return request_id, timestamp, payload
371-
372342
@abstractmethod
373343
async def _send_command(
374344
self,

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
66
from ..exceptions import VacuumError
7-
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
7+
from ..protocols.v1_protocol import encode_local_payload
8+
from ..roborock_message import RoborockMessage, RoborockMessageProtocol
89
from ..util import RoborockLoggerAdapter
9-
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
10+
from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1
1011

1112
_LOGGER = logging.getLogger(__name__)
1213

@@ -21,26 +22,16 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
2122
self.queue_timeout = queue_timeout
2223
self._logger = RoborockLoggerAdapter(device_data.device.name, _LOGGER)
2324

24-
def build_roborock_message(
25-
self, method: RoborockCommand | str, params: list | dict | int | None = None
26-
) -> RoborockMessage:
27-
secured = True if method in COMMANDS_SECURED else False
28-
request_id, timestamp, payload = self._get_payload(method, params, secured)
29-
self._logger.debug("Building message id %s for method %s", request_id, method)
30-
request_protocol = RoborockMessageProtocol.GENERAL_REQUEST
31-
message_retry: MessageRetry | None = None
32-
if method == RoborockCommand.RETRY_REQUEST and isinstance(params, dict):
33-
message_retry = MessageRetry(method=params["method"], retry_id=params["retry_id"])
34-
return RoborockMessage(
35-
timestamp=timestamp, protocol=request_protocol, payload=payload, message_retry=message_retry
36-
)
37-
3825
async def _send_command(
3926
self,
4027
method: RoborockCommand | str,
4128
params: list | dict | int | None = None,
4229
):
43-
roborock_message = self.build_roborock_message(method, params)
30+
if method in CLOUD_REQUIRED:
31+
raise RoborockException(f"Method {method} is not supported over local connection")
32+
33+
roborock_message = encode_local_payload(method, params)
34+
self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id(), method)
4435
return await self.send_message(roborock_message)
4536

4637
async def send_message(self, roborock_message: RoborockMessage):

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ..containers import DeviceData, UserData
1212
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
1313
from ..protocol import Utils
14+
from ..protocols.v1_protocol import SecurityData, create_mqtt_payload_encoder
1415
from ..roborock_message import (
1516
RoborockMessage,
1617
RoborockMessageProtocol,
@@ -36,6 +37,9 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
3637
RoborockClientV1.__init__(self, device_info, endpoint)
3738
self.queue_timeout = queue_timeout
3839
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
40+
self._payload_encoder = create_mqtt_payload_encoder(
41+
SecurityData(endpoint=self._endpoint, nonce=self._nonce),
42+
)
3943

4044
async def send_message(self, roborock_message: RoborockMessage):
4145
await self.validate_connection()
@@ -78,10 +82,9 @@ async def _send_command(
7882
if method in CUSTOM_COMMANDS:
7983
# When we have more custom commands do something more complicated here
8084
return await self._get_calibration_points()
81-
request_id, timestamp, payload = self._get_payload(method, params, True)
82-
self._logger.debug("Building message id %s for method %s", request_id, method)
83-
request_protocol = RoborockMessageProtocol.RPC_REQUEST
84-
roborock_message = RoborockMessage(timestamp=timestamp, protocol=request_protocol, payload=payload)
85+
86+
roborock_message = self._payload_encoder(method, params)
87+
self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id, method)
8588
return await self.send_message(roborock_message)
8689

8790
async def _get_calibration_points(self):

tests/test_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ async def test_get_room_mapping(
282282
)
283283
response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message))
284284

285-
with patch("roborock.version_1_apis.roborock_client_v1.get_next_int", return_value=test_request_id):
285+
with patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id):
286286
room_mapping = await connected_mqtt_client.get_room_mapping()
287287

288288
assert room_mapping == [

tests/test_local_api_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ async def test_get_room_mapping(
8686
)
8787
response_queue.put(message)
8888

89-
with patch("roborock.version_1_apis.roborock_client_v1.get_next_int", return_value=test_request_id):
89+
with patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id):
9090
room_mapping = await connected_local_client.get_room_mapping()
9191

9292
assert room_mapping == [

0 commit comments

Comments
 (0)