Skip to content

Commit cc1b895

Browse files
committed
retry quorum
Summary: - we currently don't retry quorum requests from the manager to lighthouse - if lighthouse crashes, this can result in all replicas crashing - so add retries configurable through env var
1 parent ab11bce commit cc1b895

File tree

4 files changed

+96
-16
lines changed

4 files changed

+96
-16
lines changed

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ fn num_threads() -> usize {
7171
/// world_size (int): The world size of the replica group.
7272
/// heartbeat_interval (timedelta): The interval at which heartbeats are sent.
7373
/// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
74+
/// quorum_retries (int): The number of retries for quorum requests to lighthouse server.
7475
#[pyclass]
7576
struct ManagerServer {
7677
handle: JoinHandle<Result<()>>,
@@ -91,6 +92,7 @@ impl ManagerServer {
9192
world_size: u64,
9293
heartbeat_interval: Duration,
9394
connect_timeout: Duration,
95+
quorum_retries: i64,
9496
) -> PyResult<Self> {
9597
py.allow_threads(move || {
9698
let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -108,6 +110,7 @@ impl ManagerServer {
108110
world_size,
109111
heartbeat_interval,
110112
connect_timeout,
113+
quorum_retries,
111114
))
112115
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
113116
let handle = runtime.spawn(manager.clone().run());

src/manager.rs

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::collections::HashSet;
1010
use std::sync::Arc;
1111
use std::time::Duration;
1212

13+
use crate::torchftpb::LighthouseQuorumResponse;
1314
use anyhow::Result;
1415
use tokio::sync::broadcast;
1516
use tokio::sync::Mutex;
@@ -60,6 +61,8 @@ struct ManagerState {
6061
should_commit_channel: broadcast::Sender<bool>,
6162
should_commit_failures: HashSet<i64>,
6263
should_commit_count: HashSet<i64>,
64+
65+
lighthouse_client: LighthouseServiceClient<Channel>,
6366
}
6467

6568
pub struct Manager {
@@ -71,7 +74,9 @@ pub struct Manager {
7174
listener: Mutex<Option<tokio::net::TcpListener>>,
7275
local_addr: SocketAddr,
7376
heartbeat_interval: Duration,
74-
lighthouse_client: LighthouseServiceClient<Channel>,
77+
lighthouse_addr: String,
78+
connect_timeout: Duration,
79+
quorum_retries: i64,
7580
}
7681

7782
pub async fn manager_client_new(
@@ -108,6 +113,7 @@ impl Manager {
108113
world_size: u64,
109114
heartbeat_interval: Duration,
110115
connect_timeout: Duration,
116+
quorum_retries: i64,
111117
) -> Result<Arc<Self>> {
112118
let listener = tokio::net::TcpListener::bind(&bind).await?;
113119
let local_addr = listener.local_addr()?;
@@ -119,7 +125,8 @@ impl Manager {
119125

120126
Ok(Arc::new(Self {
121127
replica_id: replica_id,
122-
lighthouse_client: client,
128+
lighthouse_addr,
129+
connect_timeout,
123130
hostname: hostname,
124131
store_address: store_addr,
125132
world_size: world_size,
@@ -132,9 +139,12 @@ impl Manager {
132139
should_commit_channel: should_commit_tx,
133140
should_commit_count: HashSet::new(),
134141
should_commit_failures: HashSet::new(),
142+
143+
lighthouse_client: client,
135144
}),
136145
local_addr: local_addr,
137146
listener: Mutex::new(Some(listener)),
147+
quorum_retries,
138148
}))
139149
}
140150

@@ -170,8 +180,12 @@ impl Manager {
170180
}
171181

172182
async fn _run_heartbeat(self: Arc<Self>) -> Result<()> {
173-
let mut client = self.lighthouse_client.clone();
174183
loop {
184+
let mut client = {
185+
let state = self.state.lock().await;
186+
state.lighthouse_client.clone()
187+
};
188+
175189
let request = tonic::Request::new(LighthouseHeartbeatRequest {
176190
replica_id: self.replica_id.clone(),
177191
});
@@ -197,21 +211,13 @@ impl Manager {
197211

198212
// TODO: don't hold the lock during quorum
199213

200-
let mut client = self.lighthouse_client.clone();
201-
202-
let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
214+
let lighthouse_request = LighthouseQuorumRequest {
203215
requester: Some(requester),
204-
});
205-
lighthouse_request.set_timeout(timeout);
216+
};
206217

207-
let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request))
208-
.await
209-
.unwrap_or_else(|e| {
210-
Err(Status::cancelled(format!(
211-
"lighthouse quorum timed out: {}",
212-
e.to_string()
213-
)))
214-
})?;
218+
let response = self
219+
._quorum_with_retries(state, timeout, lighthouse_request)
220+
.await?;
215221
let resp = response.into_inner();
216222

217223
info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp);
@@ -226,6 +232,66 @@ impl Manager {
226232

227233
Ok(())
228234
}
235+
236+
async fn _quorum_with_retries(
237+
&self,
238+
state: &mut ManagerState,
239+
timeout: Duration,
240+
lighthouse_request: LighthouseQuorumRequest,
241+
) -> Result<tonic::Response<LighthouseQuorumResponse>, Status> {
242+
let mut client = state.lighthouse_client.clone();
243+
244+
let mut retry_count = 0;
245+
loop {
246+
let mut request = tonic::Request::new(lighthouse_request.clone());
247+
request.set_timeout(timeout);
248+
249+
let result = tokio::time::timeout(timeout, client.quorum(request)).await;
250+
251+
match result {
252+
Ok(response) => {
253+
return response;
254+
}
255+
Err(e) => {
256+
info_with_replica!(
257+
self.replica_id,
258+
"lighthouse quorum failed. error: {}",
259+
e.to_string()
260+
);
261+
262+
if retry_count == self.quorum_retries {
263+
return Err(Status::internal(format!(
264+
"lighthouse quorum failed after {} retries. error: {}",
265+
retry_count,
266+
e.to_string(),
267+
)));
268+
}
269+
270+
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
271+
272+
// Reset the client since the lighthouse server might have failed
273+
// If this also fails, consider increasing `connect_timeout`.
274+
let lighthouse_client =
275+
lighthouse_client_new(self.lighthouse_addr.clone(), self.connect_timeout)
276+
.await;
277+
278+
match lighthouse_client {
279+
Ok(client) => {
280+
state.lighthouse_client = client;
281+
}
282+
Err(e) => {
283+
return Err(Status::internal(format!(
284+
"Failed to connect to lighthouse. error: {}",
285+
e.to_string(),
286+
)));
287+
}
288+
}
289+
290+
retry_count += 1;
291+
}
292+
}
293+
}
294+
}
229295
}
230296

231297
#[tonic::async_trait]
@@ -563,6 +629,7 @@ mod tests {
563629
2, // world size
564630
Duration::from_millis(100), // heartbeat interval
565631
Duration::from_secs(10), // connect timeout
632+
0, // quorum retries
566633
)
567634
.await?;
568635
let manager_fut = tokio::spawn(manager._run_grpc());
@@ -610,6 +677,7 @@ mod tests {
610677
1, // world size
611678
Duration::from_millis(100), // heartbeat interval
612679
Duration::from_secs(10), // connect timeout
680+
0, // quorum retries
613681
)
614682
.await?;
615683
let manager_fut = tokio::spawn(manager.clone().run());
@@ -671,6 +739,7 @@ mod tests {
671739
1, // world size
672740
Duration::from_millis(100), // heartbeat interval
673741
Duration::from_secs(10), // connect timeout
742+
0, // quorum retries
674743
)
675744
.await?;
676745
let manager_fut = tokio::spawn(manager.clone().run());
@@ -737,6 +806,7 @@ mod tests {
737806
1, // world size
738807
Duration::from_millis(100), // heartbeat interval
739808
Duration::from_secs(10), // connect timeout
809+
0, // quorum retries
740810
)
741811
.await?;
742812
let manager_fut = tokio::spawn(manager.clone().run());

torchft/_torchft.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class ManagerServer:
4848
world_size: int,
4949
heartbeat_interval: timedelta,
5050
connect_timeout: timedelta,
51+
quorum_retries: int,
5152
) -> None: ...
5253
def address(self) -> str: ...
5354
def shutdown(self) -> None: ...

torchft/manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@
6767
QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
6868
CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"
6969

70+
# Environment variable for the number of retries to use for the quorum.
71+
# We need to retry quorum in case lighthouse fails. Otherwise, if we
72+
# crash if call to quorum fails, all replicas will crash.
73+
QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
74+
7075
T = TypeVar("T")
7176

7277

@@ -277,6 +282,7 @@ def __init__(
277282
world_size=group_world_size,
278283
heartbeat_interval=heartbeat_interval,
279284
connect_timeout=connect_timeout,
285+
quorum_retries=int(os.environ.get(QUORUM_RETRIES_ENV, "0")),
280286
)
281287

282288
self._store.set(MANAGER_ADDR_KEY, self._manager.address())

0 commit comments

Comments
 (0)