
Commit 9ae71da

retry quorum
Summary:
- We currently don't retry quorum requests from the manager to the lighthouse.
- If the lighthouse crashes, this can result in all replicas crashing.
- So add retries, configurable through an env var.
1 parent ab11bce commit 9ae71da

File tree

4 files changed: +70 −13 lines changed

src/lib.rs

Lines changed: 3 additions & 0 deletions
@@ -71,6 +71,7 @@ fn num_threads() -> usize {
 /// world_size (int): The world size of the replica group.
 /// heartbeat_interval (timedelta): The interval at which heartbeats are sent.
 /// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
+/// quorum_retries (int): The number of retries for quorum requests to the lighthouse server.
 #[pyclass]
 struct ManagerServer {
     handle: JoinHandle<Result<()>>,
@@ -91,6 +92,7 @@ impl ManagerServer {
         world_size: u64,
         heartbeat_interval: Duration,
         connect_timeout: Duration,
+        quorum_retries: i64,
     ) -> PyResult<Self> {
         py.allow_threads(move || {
             let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -108,6 +110,7 @@ impl ManagerServer {
                 world_size,
                 heartbeat_interval,
                 connect_timeout,
+                quorum_retries,
             ))
             .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
             let handle = runtime.spawn(manager.clone().run());
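The binding change above simply threads the new integer through the PyO3 constructor. For readers unfamiliar with that pattern, here is a minimal, self-contained sketch of exposing an integer option through a #[pyclass] constructor; the RetryConfig type and its getter are placeholders for illustration, not torchft code:

use pyo3::prelude::*;

// Minimal sketch of threading a new integer option through a PyO3
// constructor, as the diff above does with `quorum_retries`.
// `RetryConfig` is a placeholder type, not part of torchft.
#[pyclass]
struct RetryConfig {
    quorum_retries: i64,
}

#[pymethods]
impl RetryConfig {
    #[new]
    fn new(quorum_retries: i64) -> Self {
        Self { quorum_retries }
    }

    /// Number of quorum retries configured from Python.
    #[getter]
    fn quorum_retries(&self) -> i64 {
        self.quorum_retries
    }
}

From Python this would be constructed as RetryConfig(quorum_retries=3), mirroring how ManagerServer gains its new keyword argument.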

src/manager.rs

Lines changed: 60 additions & 13 deletions
@@ -10,6 +10,7 @@ use std::collections::HashSet;
 use std::sync::Arc;
 use std::time::Duration;

+use crate::torchftpb::LighthouseQuorumResponse;
 use anyhow::Result;
 use tokio::sync::broadcast;
 use tokio::sync::Mutex;
@@ -72,6 +73,8 @@ pub struct Manager {
     local_addr: SocketAddr,
     heartbeat_interval: Duration,
     lighthouse_client: LighthouseServiceClient<Channel>,
+    lighthouse_addr: String,
+    quorum_retries: i64,
 }

 pub async fn manager_client_new(
@@ -108,6 +111,7 @@ impl Manager {
         world_size: u64,
         heartbeat_interval: Duration,
         connect_timeout: Duration,
+        quorum_retries: i64,
     ) -> Result<Arc<Self>> {
         let listener = tokio::net::TcpListener::bind(&bind).await?;
         let local_addr = listener.local_addr()?;
@@ -120,6 +124,7 @@ impl Manager {
         Ok(Arc::new(Self {
             replica_id: replica_id,
             lighthouse_client: client,
+            lighthouse_addr,
             hostname: hostname,
             store_address: store_addr,
             world_size: world_size,
@@ -135,6 +140,7 @@ impl Manager {
             }),
             local_addr: local_addr,
             listener: Mutex::new(Some(listener)),
+            quorum_retries,
         }))
     }

@@ -197,21 +203,13 @@ impl Manager {

         // TODO: don't hold the lock during quorum

-        let mut client = self.lighthouse_client.clone();
-
-        let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
+        let lighthouse_request = LighthouseQuorumRequest {
             requester: Some(requester),
-        });
-        lighthouse_request.set_timeout(timeout);
+        };

-        let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request))
-            .await
-            .unwrap_or_else(|e| {
-                Err(Status::cancelled(format!(
-                    "lighthouse quorum timed out: {}",
-                    e.to_string()
-                )))
-            })?;
+        let response = self
+            ._quorum_with_retries(timeout, lighthouse_request)
+            .await?;
         let resp = response.into_inner();

         info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp);
@@ -226,6 +224,51 @@ impl Manager {

         Ok(())
     }
+
+    async fn _quorum_with_retries(
+        &self,
+        timeout: Duration,
+        lighthouse_request: LighthouseQuorumRequest,
+    ) -> Result<tonic::Response<LighthouseQuorumResponse>, Status> {
+        let mut client = self.lighthouse_client.clone();
+
+        let mut retry_count = 0;
+        loop {
+            let mut request = tonic::Request::new(lighthouse_request.clone());
+            request.set_timeout(timeout);
+
+            let result = tokio::time::timeout(timeout, client.quorum(request)).await;
+
+            match result {
+                Ok(response) => {
+                    return response;
+                }
+                Err(e) => {
+                    info_with_replica!(
+                        self.replica_id,
+                        "lighthouse quorum failed. error: {}",
+                        e.to_string()
+                    );
+
+                    if retry_count == self.quorum_retries {
+                        return Err(Status::internal(format!(
+                            "lighthouse quorum failed after {} retries. error: {}",
+                            retry_count,
+                            e.to_string(),
+                        )));
+                    }
+
+                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+                    // Reset the client since the lighthouse server might have failed;
+                    // rebuild it from the stored address rather than reusing the stale channel.
+                    client = lighthouse_client_new(self.lighthouse_addr.clone(), self.connect_timeout)
+                        .await
+                        .map_err(|e| Status::unavailable(e.to_string()))?;
+
+                    retry_count += 1;
+                }
+            }
+        }
+    }
 }

 #[tonic::async_trait]
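The new helper gives each attempt the full request timeout, logs the failure, rebuilds the lighthouse client (the server may have come back at the same address), and sleeps 100 ms before retrying until the quorum_retries budget is spent. Note that in the code above, tokio::time::timeout only yields an Err when the deadline elapses, so an RPC that fails fast is returned to the caller as-is. The retry skeleton in isolation, as a minimal runnable sketch; the with_retries helper and the example error string are illustrative, not torchft API, and a tokio runtime is assumed:

use std::future::Future;
use std::time::Duration;

// Retry a fallible async operation up to `retries` additional times,
// sleeping 100 ms between attempts; once the retry budget is exhausted,
// the last error is surfaced to the caller. Illustrative helper only.
async fn with_retries<T, E, F, Fut>(retries: i64, mut attempt: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    let mut retry_count = 0;
    loop {
        match attempt().await {
            Ok(value) => return Ok(value),
            Err(e) => {
                if retry_count == retries {
                    // Out of retries: return the final error.
                    return Err(e);
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
                retry_count += 1;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    // Succeeds on the third attempt when given a budget of 2 retries.
    let mut calls = 0;
    let result = with_retries(2, || {
        calls += 1;
        let this_call = calls;
        async move {
            if this_call < 3 {
                Err("lighthouse unavailable")
            } else {
                Ok(this_call)
            }
        }
    })
    .await;
    assert_eq!(result, Ok(3));
}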
@@ -563,6 +606,7 @@ mod tests {
             2, // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10), // connect timeout
+            0, // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager._run_grpc());
@@ -610,6 +654,7 @@ mod tests {
             1, // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10), // connect timeout
+            0, // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
@@ -671,6 +716,7 @@ mod tests {
             1, // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10), // connect timeout
+            0, // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
@@ -737,6 +783,7 @@ mod tests {
             1, // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10), // connect timeout
+            0, // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());

torchft/_torchft.pyi

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ class ManagerServer:
         world_size: int,
         heartbeat_interval: timedelta,
         connect_timeout: timedelta,
+        quorum_retries: int,
     ) -> None: ...
     def address(self) -> str: ...
     def shutdown(self) -> None: ...

torchft/manager.py

Lines changed: 6 additions & 0 deletions
@@ -67,6 +67,11 @@
 QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
 CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"

+# Environment variable for the number of retries to use for the quorum.
+# We need to retry the quorum in case the lighthouse fails; otherwise, if we
+# crash whenever a call to quorum fails, all replicas will crash.
+QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
+
 T = TypeVar("T")


@@ -277,6 +282,7 @@ def __init__(
             world_size=group_world_size,
             heartbeat_interval=heartbeat_interval,
             connect_timeout=connect_timeout,
+            quorum_retries=int(os.environ.get(QUORUM_RETRIES_ENV, "0")),
         )

         self._store.set(MANAGER_ADDR_KEY, self._manager.address())
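With this default, behavior is unchanged unless TORCHFT_QUORUM_RETRIES is set in the environment. A hypothetical Rust rendering of the same default handling, shown only to make the parsing explicit (torchft reads this in Python, as above):

use std::env;

// Read TORCHFT_QUORUM_RETRIES, defaulting to 0 when unset or unparsable.
// Mirrors `int(os.environ.get(QUORUM_RETRIES_ENV, "0"))` from the diff above,
// except that the Python version raises on an unparsable value.
fn quorum_retries_from_env() -> i64 {
    env::var("TORCHFT_QUORUM_RETRIES")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(0)
}

fn main() {
    println!("quorum retries: {}", quorum_retries_from_env());
}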
