Skip to content

chore: Create module for v1 request encoding #413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions roborock/protocols/v1_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Roborock V1 Protocol Encoder."""

from __future__ import annotations

import json
import math
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any

from roborock.roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
from roborock.roborock_typing import RoborockCommand
from roborock.util import get_next_int

CommandType = RoborockCommand | str
ParamsType = list | dict | int | None


@dataclass(frozen=True, kw_only=True)
class SecurityData:
"""Security data included in the request for some V1 commands."""

endpoint: str
nonce: bytes

def to_dict(self) -> dict[str, Any]:
"""Convert security data to a dictionary for sending in the payload."""
return {"security": {"endpoint": self.endpoint, "nonce": self.nonce.hex().lower()}}


@dataclass
class RequestMessage:
"""Data structure for v1 RoborockMessage payloads."""

method: RoborockCommand | str
params: ParamsType
timestamp: int = field(default_factory=lambda: math.floor(time.time()))
request_id: int = field(default_factory=lambda: get_next_int(10000, 32767))

def as_payload(self, security_data: SecurityData | None) -> bytes:
"""Convert the request arguments to a dictionary."""
inner = {
"id": self.request_id,
"method": self.method,
"params": self.params or [],
**(security_data.to_dict() if security_data else {}),
}
return bytes(
json.dumps(
{
"dps": {"101": json.dumps(inner, separators=(",", ":"))},
"t": self.timestamp,
},
separators=(",", ":"),
).encode()
)


def create_mqtt_payload_encoder(security_data: SecurityData) -> Callable[[CommandType, ParamsType], RoborockMessage]:
"""Create a payload encoder for V1 commands over MQTT."""

def _get_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
"""Build the payload for a V1 command."""
request = RequestMessage(method=method, params=params)
payload = request.as_payload(security_data) # always secure
return RoborockMessage(
timestamp=request.timestamp,
protocol=RoborockMessageProtocol.RPC_REQUEST,
payload=payload,
)

return _get_payload


def encode_local_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
"""Encode payload for V1 commands over local connection."""

request = RequestMessage(method=method, params=params)
payload = request.as_payload(security_data=None)

message_retry: MessageRetry | None = None
if method == RoborockCommand.RETRY_REQUEST and isinstance(params, dict):
message_retry = MessageRetry(method=method, retry_id=params["retry_id"])

return RoborockMessage(
timestamp=request.timestamp,
protocol=RoborockMessageProtocol.GENERAL_REQUEST,
payload=payload,
message_retry=message_retry,
)
36 changes: 3 additions & 33 deletions roborock/version_1_apis/roborock_client_v1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import dataclasses
import json
import math
import struct
import time
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -54,15 +53,15 @@
RoborockMessage,
RoborockMessageProtocol,
)
from roborock.util import RepeatableTask, get_next_int, unpack_list
from roborock.util import RepeatableTask, unpack_list

CUSTOM_COMMANDS = {RoborockCommand.GET_MAP_CALIBRATION}

COMMANDS_SECURED = {
RoborockCommand.GET_MAP_V1,
RoborockCommand.GET_MULTI_MAP,
}

CUSTOM_COMMANDS = {RoborockCommand.GET_MAP_CALIBRATION}

CLOUD_REQUIRED = COMMANDS_SECURED.union(CUSTOM_COMMANDS)

WASH_N_FILL_DOCK = [
Expand Down Expand Up @@ -340,35 +339,6 @@ async def load_multi_map(self, map_flag: int) -> None:
"""Load the map into the vacuum's memory."""
await self.send_command(RoborockCommand.LOAD_MULTI_MAP, [map_flag])

def _get_payload(
self,
method: RoborockCommand | str,
params: list | dict | int | None = None,
secured=False,
):
timestamp = math.floor(time.time())
request_id = get_next_int(10000, 32767)
inner = {
"id": request_id,
"method": method,
"params": params or [],
}
if secured:
inner["security"] = {
"endpoint": self._endpoint,
"nonce": self._nonce.hex().lower(),
}
payload = bytes(
json.dumps(
{
"dps": {"101": json.dumps(inner, separators=(",", ":"))},
"t": timestamp,
},
separators=(",", ":"),
).encode()
)
return request_id, timestamp, payload

@abstractmethod
async def _send_command(
self,
Expand Down
25 changes: 8 additions & 17 deletions roborock/version_1_apis/roborock_local_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
from ..exceptions import VacuumError
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
from ..protocols.v1_protocol import encode_local_payload
from ..roborock_message import RoborockMessage, RoborockMessageProtocol
from ..util import RoborockLoggerAdapter
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1

_LOGGER = logging.getLogger(__name__)

Expand All @@ -21,26 +22,16 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
self.queue_timeout = queue_timeout
self._logger = RoborockLoggerAdapter(device_data.device.name, _LOGGER)

def build_roborock_message(
self, method: RoborockCommand | str, params: list | dict | int | None = None
) -> RoborockMessage:
secured = True if method in COMMANDS_SECURED else False
request_id, timestamp, payload = self._get_payload(method, params, secured)
self._logger.debug("Building message id %s for method %s", request_id, method)
request_protocol = RoborockMessageProtocol.GENERAL_REQUEST
message_retry: MessageRetry | None = None
if method == RoborockCommand.RETRY_REQUEST and isinstance(params, dict):
message_retry = MessageRetry(method=params["method"], retry_id=params["retry_id"])
return RoborockMessage(
timestamp=timestamp, protocol=request_protocol, payload=payload, message_retry=message_retry
)

async def _send_command(
self,
method: RoborockCommand | str,
params: list | dict | int | None = None,
):
roborock_message = self.build_roborock_message(method, params)
if method in CLOUD_REQUIRED:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now a breaking change to just hard fail when trying to do this.

raise RoborockException(f"Method {method} is not supported over local connection")

roborock_message = encode_local_payload(method, params)
self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id(), method)
return await self.send_message(roborock_message)

async def send_message(self, roborock_message: RoborockMessage):
Expand Down
11 changes: 7 additions & 4 deletions roborock/version_1_apis/roborock_mqtt_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..containers import DeviceData, UserData
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
from ..protocol import Utils
from ..protocols.v1_protocol import SecurityData, create_mqtt_payload_encoder
from ..roborock_message import (
RoborockMessage,
RoborockMessageProtocol,
Expand All @@ -36,6 +37,9 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
RoborockClientV1.__init__(self, device_info, endpoint)
self.queue_timeout = queue_timeout
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
self._payload_encoder = create_mqtt_payload_encoder(
SecurityData(endpoint=self._endpoint, nonce=self._nonce),
)

async def send_message(self, roborock_message: RoborockMessage):
await self.validate_connection()
Expand Down Expand Up @@ -78,10 +82,9 @@ async def _send_command(
if method in CUSTOM_COMMANDS:
# When we have more custom commands do something more complicated here
return await self._get_calibration_points()
request_id, timestamp, payload = self._get_payload(method, params, True)
self._logger.debug("Building message id %s for method %s", request_id, method)
request_protocol = RoborockMessageProtocol.RPC_REQUEST
roborock_message = RoborockMessage(timestamp=timestamp, protocol=request_protocol, payload=payload)

roborock_message = self._payload_encoder(method, params)
self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id, method)
return await self.send_message(roborock_message)

async def _get_calibration_points(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def test_get_room_mapping(
)
response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message))

with patch("roborock.version_1_apis.roborock_client_v1.get_next_int", return_value=test_request_id):
with patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id):
room_mapping = await connected_mqtt_client.get_room_mapping()

assert room_mapping == [
Expand Down
2 changes: 1 addition & 1 deletion tests/test_local_api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async def test_get_room_mapping(
)
response_queue.put(message)

with patch("roborock.version_1_apis.roborock_client_v1.get_next_int", return_value=test_request_id):
with patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id):
room_mapping = await connected_local_client.get_room_mapping()

assert room_mapping == [
Expand Down
Loading