Skip to content

Commit 1990912

Browse files
committed
feat: Simplify local payload encoding by rejecting any cloud commands sent locally
1 parent b61ced3 commit 1990912

File tree

3 files changed

+24
-32
lines changed

3 files changed

+24
-32
lines changed

roborock/protocols/v1_protocol.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@
1313
from roborock.roborock_typing import RoborockCommand
1414
from roborock.util import get_next_int
1515

16-
# All mqtt commands are sent securely. Only local commands in this list are secured.
17-
COMMANDS_SECURED = {
18-
RoborockCommand.GET_MAP_V1,
19-
RoborockCommand.GET_MULTI_MAP,
20-
}
21-
2216
CommandType = RoborockCommand | str
2317
ParamsType = list | dict | int | None
2418

@@ -79,24 +73,19 @@ def _get_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
7973
return _get_payload
8074

8175

82-
def create_local_payload_encoder(security_data: SecurityData) -> Callable[[CommandType, ParamsType], RoborockMessage]:
83-
"""Create a payload encoder for V1 commands over local connection."""
84-
85-
def _get_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
86-
"""Build the payload for a V1 command."""
87-
request = RequestMessage(method=method, params=params)
88-
is_secured = request.method in COMMANDS_SECURED
89-
payload = request.as_payload(security_data if is_secured else None)
76+
def encode_local_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
77+
"""Encode payload for V1 commands over local connection."""
9078

91-
message_retry: MessageRetry | None = None
92-
if method == RoborockCommand.RETRY_REQUEST and isinstance(params, dict):
93-
message_retry = MessageRetry(method=method, retry_id=params["retry_id"])
79+
request = RequestMessage(method=method, params=params)
80+
payload = request.as_payload(security_data=None)
9481

95-
return RoborockMessage(
96-
timestamp=request.timestamp,
97-
protocol=RoborockMessageProtocol.GENERAL_REQUEST,
98-
payload=payload,
99-
message_retry=message_retry,
100-
)
82+
message_retry: MessageRetry | None = None
83+
if method == RoborockCommand.RETRY_REQUEST and isinstance(params, dict):
84+
message_retry = MessageRetry(method=method, retry_id=params["retry_id"])
10185

102-
return _get_payload
86+
return RoborockMessage(
87+
timestamp=request.timestamp,
88+
protocol=RoborockMessageProtocol.GENERAL_REQUEST,
89+
payload=payload,
90+
message_retry=message_retry,
91+
)

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,14 @@
5454
RoborockMessageProtocol,
5555
)
5656
from roborock.util import RepeatableTask, unpack_list
57-
from roborock.protocols.v1_protocol import COMMANDS_SECURED
58-
5957

6058
CUSTOM_COMMANDS = {RoborockCommand.GET_MAP_CALIBRATION}
6159

60+
COMMANDS_SECURED = {
61+
RoborockCommand.GET_MAP_V1,
62+
RoborockCommand.GET_MULTI_MAP,
63+
}
64+
6265
CLOUD_REQUIRED = COMMANDS_SECURED.union(CUSTOM_COMMANDS)
6366

6467
WASH_N_FILL_DOCK = [

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
66
from ..exceptions import VacuumError
7-
from ..protocols.v1_protocol import SecurityData, create_local_payload_encoder
7+
from ..protocols.v1_protocol import encode_local_payload
88
from ..roborock_message import RoborockMessage, RoborockMessageProtocol
99
from ..util import RoborockLoggerAdapter
10-
from .roborock_client_v1 import RoborockClientV1
10+
from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1
1111

1212
_LOGGER = logging.getLogger(__name__)
1313

@@ -21,16 +21,16 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
2121
RoborockClientV1.__init__(self, device_data, "abc")
2222
self.queue_timeout = queue_timeout
2323
self._logger = RoborockLoggerAdapter(device_data.device.name, _LOGGER)
24-
self._payload_encoder = create_local_payload_encoder(
25-
SecurityData(endpoint=self._endpoint, nonce=self._nonce),
26-
)
2724

2825
async def _send_command(
2926
self,
3027
method: RoborockCommand | str,
3128
params: list | dict | int | None = None,
3229
):
33-
roborock_message = self._payload_encoder(method, params)
30+
if method in CLOUD_REQUIRED:
31+
raise RoborockException(f"Method {method} is not supported over local connection")
32+
33+
roborock_message = encode_local_payload(method, params)
3434
self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id(), method)
3535
return await self.send_message(roborock_message)
3636

0 commit comments

Comments
 (0)