Skip to content

chore: Minor refactoring creating functions for transforming bytes #397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions roborock/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .api import KEEPALIVE, RoborockClient
from .containers import DeviceData, UserData
from .exceptions import RoborockException, VacuumError
from .protocol import MessageParser, md5hex
from .protocol import Decoder, Encoder, create_mqtt_decoder, create_mqtt_encoder, md5hex
from .roborock_future import RoborockFuture

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -74,6 +74,8 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
self._waiting_queue: dict[int, RoborockFuture] = {}
self._mutex = Lock()
self._decoder: Decoder = create_mqtt_decoder(device_info.device.local_key)
self._encoder: Encoder = create_mqtt_encoder(device_info.device.local_key)

def _mqtt_on_connect(self, *args, **kwargs):
_, __, ___, rc, ____ = args
Expand Down Expand Up @@ -102,7 +104,7 @@ def _mqtt_on_connect(self, *args, **kwargs):
def _mqtt_on_message(self, *args, **kwargs):
client, __, msg = args
try:
messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key)
messages = self._decoder(msg.payload)
super().on_message_received(messages)
except Exception as ex:
self._logger.exception(ex)
Expand Down
12 changes: 5 additions & 7 deletions roborock/local_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import DeviceData
from .api import RoborockClient
from .exceptions import RoborockConnectionException, RoborockException
from .protocol import MessageParser
from .protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
from .roborock_message import RoborockMessage, RoborockMessageProtocol

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,20 +44,18 @@ def __init__(self, device_data: DeviceData):
self.host = device_data.host
self._batch_structs: list[RoborockMessage] = []
self._executing = False
self.remaining = b""
self.transport: Transport | None = None
self._mutex = Lock()
self.keep_alive_task: TimerHandle | None = None
RoborockClient.__init__(self, device_data)
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
self._decoder: Decoder = create_local_decoder(device_data.device.local_key)

def _data_received(self, message):
"""Called when data is received from the transport."""
if self.remaining:
message = self.remaining + message
self.remaining = b""
parser_msg, self.remaining = MessageParser.parse(message, local_key=self.device_info.device.local_key)
self.on_message_received(parser_msg)
parsed_msg = self._decoder(message)
self.on_message_received(parsed_msg)

def _connection_lost(self, exc: Exception | None):
"""Called when the transport connection is lost."""
Expand Down
53 changes: 53 additions & 0 deletions roborock/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,56 @@ def build(

MessageParser: _Parser = _Parser(_Messages, True)
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)


Decoder = Callable[[bytes], list[RoborockMessage]]
Encoder = Callable[[RoborockMessage], bytes]


def create_mqtt_decoder(local_key: str) -> Decoder:
"""Create a decoder for MQTT messages."""

def decode(data: bytes) -> list[RoborockMessage]:
"""Parse the given data into Roborock messages."""
messages, _ = MessageParser.parse(data, local_key)
return messages

return decode


def create_mqtt_encoder(local_key: str) -> Encoder:
"""Create an encoder for MQTT messages."""

def encode(messages: RoborockMessage) -> bytes:
"""Build the given Roborock messages into a byte string."""
return MessageParser.build(messages, local_key, prefixed=False)

return encode


def create_local_decoder(local_key: str) -> Decoder:
"""Create a decoder for local API messages."""

# This buffer is used to accumulate bytes until a complete message can be parsed.
# It is defined outside the decode function to maintain state across calls.
buffer: bytes = b""

def decode(bytes: bytes) -> list[RoborockMessage]:
"""Parse the given data into Roborock messages."""
nonlocal buffer
buffer += bytes
parsed_messages, remaining = MessageParser.parse(buffer, local_key=local_key)
buffer = remaining
return parsed_messages

return decode


def create_local_encoder(local_key: str) -> Encoder:
"""Create an encoder for local API messages."""

def encode(message: RoborockMessage) -> bytes:
"""Called when data is sent to the transport."""
return MessageParser.build(message, local_key=local_key)

return encode
4 changes: 1 addition & 3 deletions roborock/version_1_apis/roborock_local_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
from ..exceptions import VacuumError
from ..protocol import MessageParser
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
from ..util import RoborockLoggerAdapter
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
Expand Down Expand Up @@ -57,8 +56,7 @@ async def send_message(self, roborock_message: RoborockMessage):
response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
if request_id is None:
raise RoborockException(f"Failed build message {roborock_message}")
local_key = self.device_info.device.local_key
msg = MessageParser.build(roborock_message, local_key=local_key)
msg = self._encoder(roborock_message)
if method:
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
# Send the command to the Roborock device
Expand Down
6 changes: 2 additions & 4 deletions roborock/version_1_apis/roborock_mqtt_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ..containers import DeviceData, UserData
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
from ..protocol import MessageParser, Utils
from ..protocol import Utils
from ..roborock_message import (
RoborockMessage,
RoborockMessageProtocol,
Expand Down Expand Up @@ -47,9 +47,7 @@ async def send_message(self, roborock_message: RoborockMessage):
response_protocol = (
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
)

local_key = self.device_info.device.local_key
msg = MessageParser.build(roborock_message, local_key, False)
msg = self._encoder(roborock_message)
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
async_response = self._async_response(request_id, response_protocol)
self._send_msg_raw(msg)
Expand Down
4 changes: 1 addition & 3 deletions roborock/version_a01_apis/roborock_mqtt_client_a01.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from roborock.cloud_api import RoborockMqttClient
from roborock.containers import DeviceData, RoborockCategory, UserData
from roborock.exceptions import RoborockException
from roborock.protocol import MessageParser
from roborock.roborock_message import (
RoborockDyadDataProtocol,
RoborockMessage,
Expand Down Expand Up @@ -43,8 +42,7 @@ async def send_message(self, roborock_message: RoborockMessage):
await self.validate_connection()
response_protocol = RoborockMessageProtocol.RPC_RESPONSE

local_key = self.device_info.device.local_key
m = MessageParser.build(roborock_message, local_key, prefixed=False)
m = self._encoder(roborock_message)
# self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
payload = json.loads(unpad(roborock_message.payload, AES.block_size))
futures = []
Expand Down