Skip to content

Commit b19dbaa

Browse files
authored
chore: Minor refactoring creating functions for transforming bytes (#397)
This is in preparation for sharin with new device/mqtt code.
1 parent 9e0ddf8 commit b19dbaa

File tree

6 files changed

+66
-19
lines changed

6 files changed

+66
-19
lines changed

roborock/cloud_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .api import KEEPALIVE, RoborockClient
1414
from .containers import DeviceData, UserData
1515
from .exceptions import RoborockException, VacuumError
16-
from .protocol import MessageParser, md5hex
16+
from .protocol import Decoder, Encoder, create_mqtt_decoder, create_mqtt_encoder, md5hex
1717
from .roborock_future import RoborockFuture
1818

1919
_LOGGER = logging.getLogger(__name__)
@@ -74,6 +74,8 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
7474
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
7575
self._waiting_queue: dict[int, RoborockFuture] = {}
7676
self._mutex = Lock()
77+
self._decoder: Decoder = create_mqtt_decoder(device_info.device.local_key)
78+
self._encoder: Encoder = create_mqtt_encoder(device_info.device.local_key)
7779

7880
def _mqtt_on_connect(self, *args, **kwargs):
7981
_, __, ___, rc, ____ = args
@@ -102,7 +104,7 @@ def _mqtt_on_connect(self, *args, **kwargs):
102104
def _mqtt_on_message(self, *args, **kwargs):
103105
client, __, msg = args
104106
try:
105-
messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key)
107+
messages = self._decoder(msg.payload)
106108
super().on_message_received(messages)
107109
except Exception as ex:
108110
self._logger.exception(ex)

roborock/local_api.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from . import DeviceData
1313
from .api import RoborockClient
1414
from .exceptions import RoborockConnectionException, RoborockException
15-
from .protocol import MessageParser
15+
from .protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
1616
from .roborock_message import RoborockMessage, RoborockMessageProtocol
1717

1818
_LOGGER = logging.getLogger(__name__)
@@ -44,20 +44,18 @@ def __init__(self, device_data: DeviceData):
4444
self.host = device_data.host
4545
self._batch_structs: list[RoborockMessage] = []
4646
self._executing = False
47-
self.remaining = b""
4847
self.transport: Transport | None = None
4948
self._mutex = Lock()
5049
self.keep_alive_task: TimerHandle | None = None
5150
RoborockClient.__init__(self, device_data)
5251
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
52+
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
53+
self._decoder: Decoder = create_local_decoder(device_data.device.local_key)
5354

5455
def _data_received(self, message):
5556
"""Called when data is received from the transport."""
56-
if self.remaining:
57-
message = self.remaining + message
58-
self.remaining = b""
59-
parser_msg, self.remaining = MessageParser.parse(message, local_key=self.device_info.device.local_key)
60-
self.on_message_received(parser_msg)
57+
parsed_msg = self._decoder(message)
58+
self.on_message_received(parsed_msg)
6159

6260
def _connection_lost(self, exc: Exception | None):
6361
"""Called when the transport connection is lost."""

roborock/protocol.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,3 +359,56 @@ def build(
359359

360360
MessageParser: _Parser = _Parser(_Messages, True)
361361
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)
362+
363+
364+
Decoder = Callable[[bytes], list[RoborockMessage]]
365+
Encoder = Callable[[RoborockMessage], bytes]
366+
367+
368+
def create_mqtt_decoder(local_key: str) -> Decoder:
369+
"""Create a decoder for MQTT messages."""
370+
371+
def decode(data: bytes) -> list[RoborockMessage]:
372+
"""Parse the given data into Roborock messages."""
373+
messages, _ = MessageParser.parse(data, local_key)
374+
return messages
375+
376+
return decode
377+
378+
379+
def create_mqtt_encoder(local_key: str) -> Encoder:
380+
"""Create an encoder for MQTT messages."""
381+
382+
def encode(messages: RoborockMessage) -> bytes:
383+
"""Build the given Roborock messages into a byte string."""
384+
return MessageParser.build(messages, local_key, prefixed=False)
385+
386+
return encode
387+
388+
389+
def create_local_decoder(local_key: str) -> Decoder:
390+
"""Create a decoder for local API messages."""
391+
392+
# This buffer is used to accumulate bytes until a complete message can be parsed.
393+
# It is defined outside the decode function to maintain state across calls.
394+
buffer: bytes = b""
395+
396+
def decode(bytes: bytes) -> list[RoborockMessage]:
397+
"""Parse the given data into Roborock messages."""
398+
nonlocal buffer
399+
buffer += bytes
400+
parsed_messages, remaining = MessageParser.parse(buffer, local_key=local_key)
401+
buffer = remaining
402+
return parsed_messages
403+
404+
return decode
405+
406+
407+
def create_local_encoder(local_key: str) -> Encoder:
408+
"""Create an encoder for local API messages."""
409+
410+
def encode(message: RoborockMessage) -> bytes:
411+
"""Called when data is sent to the transport."""
412+
return MessageParser.build(message, local_key=local_key)
413+
414+
return encode

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
66
from ..exceptions import VacuumError
7-
from ..protocol import MessageParser
87
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
98
from ..util import RoborockLoggerAdapter
109
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
@@ -57,8 +56,7 @@ async def send_message(self, roborock_message: RoborockMessage):
5756
response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
5857
if request_id is None:
5958
raise RoborockException(f"Failed build message {roborock_message}")
60-
local_key = self.device_info.device.local_key
61-
msg = MessageParser.build(roborock_message, local_key=local_key)
59+
msg = self._encoder(roborock_message)
6260
if method:
6361
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
6462
# Send the command to the Roborock device

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ..containers import DeviceData, UserData
1212
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
13-
from ..protocol import MessageParser, Utils
13+
from ..protocol import Utils
1414
from ..roborock_message import (
1515
RoborockMessage,
1616
RoborockMessageProtocol,
@@ -47,9 +47,7 @@ async def send_message(self, roborock_message: RoborockMessage):
4747
response_protocol = (
4848
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
4949
)
50-
51-
local_key = self.device_info.device.local_key
52-
msg = MessageParser.build(roborock_message, local_key, False)
50+
msg = self._encoder(roborock_message)
5351
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
5452
async_response = self._async_response(request_id, response_protocol)
5553
self._send_msg_raw(msg)

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from roborock.cloud_api import RoborockMqttClient
1010
from roborock.containers import DeviceData, RoborockCategory, UserData
1111
from roborock.exceptions import RoborockException
12-
from roborock.protocol import MessageParser
1312
from roborock.roborock_message import (
1413
RoborockDyadDataProtocol,
1514
RoborockMessage,
@@ -43,8 +42,7 @@ async def send_message(self, roborock_message: RoborockMessage):
4342
await self.validate_connection()
4443
response_protocol = RoborockMessageProtocol.RPC_RESPONSE
4544

46-
local_key = self.device_info.device.local_key
47-
m = MessageParser.build(roborock_message, local_key, prefixed=False)
45+
m = self._encoder(roborock_message)
4846
# self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
4947
payload = json.loads(unpad(roborock_message.payload, AES.block_size))
5048
futures = []

0 commit comments

Comments
 (0)