Skip to content

Commit 1fb135b

Browse files
authored
feat: Add a local channel, similar to the MQTT channel (#410)
* feat: Add a local channel, similar to the MQTT channel * feat: Log a warning when transport is already closed
1 parent 509ff6a commit 1fb135b

File tree

2 files changed

+472
-0
lines changed

2 files changed

+472
-0
lines changed

roborock/devices/local_channel.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Module for communicating with Roborock devices over a local network."""
2+
3+
import asyncio
4+
import logging
5+
from collections.abc import Callable
6+
from dataclasses import dataclass
7+
from json import JSONDecodeError
8+
9+
from roborock.exceptions import RoborockConnectionException, RoborockException
10+
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
11+
from roborock.roborock_message import RoborockMessage
12+
13+
_LOGGER = logging.getLogger(__name__)
14+
_PORT = 58867
15+
16+
17+
@dataclass
18+
class _LocalProtocol(asyncio.Protocol):
19+
"""Callbacks for the Roborock local client transport."""
20+
21+
messages_cb: Callable[[bytes], None]
22+
connection_lost_cb: Callable[[Exception | None], None]
23+
24+
def data_received(self, data: bytes) -> None:
25+
"""Called when data is received from the transport."""
26+
self.messages_cb(data)
27+
28+
def connection_lost(self, exc: Exception | None) -> None:
29+
"""Called when the transport connection is lost."""
30+
self.connection_lost_cb(exc)
31+
32+
33+
class LocalChannel:
34+
"""Simple RPC-style channel for communicating with a device over a local network.
35+
36+
Handles request/response correlation and timeouts, but leaves message
37+
format most parsing to higher-level components.
38+
"""
39+
40+
def __init__(self, host: str, local_key: str):
41+
self._host = host
42+
self._transport: asyncio.Transport | None = None
43+
self._protocol: _LocalProtocol | None = None
44+
self._subscribers: list[Callable[[RoborockMessage], None]] = []
45+
self._is_connected = False
46+
47+
# RPC support
48+
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
49+
self._decoder: Decoder = create_local_decoder(local_key)
50+
self._encoder: Encoder = create_local_encoder(local_key)
51+
self._queue_lock = asyncio.Lock()
52+
53+
async def connect(self) -> None:
54+
"""Connect to the device."""
55+
if self._is_connected:
56+
_LOGGER.warning("Already connected")
57+
return
58+
_LOGGER.debug("Connecting to %s:%s", self._host, _PORT)
59+
loop = asyncio.get_running_loop()
60+
protocol = _LocalProtocol(self._data_received, self._connection_lost)
61+
try:
62+
self._transport, self._protocol = await loop.create_connection(lambda: protocol, self._host, _PORT)
63+
self._is_connected = True
64+
except OSError as e:
65+
raise RoborockConnectionException(f"Failed to connect to {self._host}:{_PORT}") from e
66+
67+
async def close(self) -> None:
68+
"""Disconnect from the device."""
69+
if self._transport:
70+
self._transport.close()
71+
else:
72+
_LOGGER.warning("Close called but transport is already None")
73+
self._transport = None
74+
self._is_connected = False
75+
76+
def _data_received(self, data: bytes) -> None:
77+
"""Handle incoming data from the transport."""
78+
if not (messages := self._decoder(data)):
79+
_LOGGER.warning("Failed to decode local message: %s", data)
80+
return
81+
for message in messages:
82+
_LOGGER.debug("Received message: %s", message)
83+
asyncio.create_task(self._resolve_future_with_lock(message))
84+
for callback in self._subscribers:
85+
try:
86+
callback(message)
87+
except Exception as e:
88+
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
89+
90+
def _connection_lost(self, exc: Exception | None) -> None:
91+
"""Handle connection loss."""
92+
_LOGGER.warning("Connection lost to %s", self._host, exc_info=exc)
93+
self._transport = None
94+
self._is_connected = False
95+
96+
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
97+
"""Subscribe to all messages from the device."""
98+
self._subscribers.append(callback)
99+
100+
def unsubscribe() -> None:
101+
self._subscribers.remove(callback)
102+
103+
return unsubscribe
104+
105+
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
106+
"""Resolve waiting future with proper locking."""
107+
if (request_id := message.get_request_id()) is None:
108+
_LOGGER.debug("Received message with no request_id")
109+
return
110+
async with self._queue_lock:
111+
if (future := self._waiting_queue.pop(request_id, None)) is not None:
112+
future.set_result(message)
113+
else:
114+
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
115+
116+
async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
117+
"""Send a command message and wait for the response message."""
118+
if not self._transport or not self._is_connected:
119+
raise RoborockConnectionException("Not connected to device")
120+
121+
try:
122+
if (request_id := message.get_request_id()) is None:
123+
raise RoborockException("Message must have a request_id for RPC calls")
124+
except (ValueError, JSONDecodeError) as err:
125+
_LOGGER.exception("Error getting request_id from message: %s", err)
126+
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
127+
128+
future: asyncio.Future[RoborockMessage] = asyncio.Future()
129+
async with self._queue_lock:
130+
if request_id in self._waiting_queue:
131+
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
132+
self._waiting_queue[request_id] = future
133+
134+
try:
135+
encoded_msg = self._encoder(message)
136+
self._transport.write(encoded_msg)
137+
return await asyncio.wait_for(future, timeout=timeout)
138+
except asyncio.TimeoutError as ex:
139+
async with self._queue_lock:
140+
self._waiting_queue.pop(request_id, None)
141+
raise RoborockException(f"Command timed out after {timeout}s") from ex
142+
except Exception:
143+
logging.exception("Uncaught error sending command")
144+
async with self._queue_lock:
145+
self._waiting_queue.pop(request_id, None)
146+
raise

0 commit comments

Comments
 (0)