Commit 0931e7e
retry quorum

Summary:
- We currently don't retry quorum requests from the manager to the lighthouse; if the lighthouse crashes, this can result in all replicas crashing.
- So, add retries configurable through an environment variable.
- Stop holding the state lock while making network calls in the manager.
- The manager tries reconnecting to the lighthouse if a request to it fails, up to the configured number of retries.
- There are still some unhandled cases:
  - The manager doesn't broadcast the result to all ranks if there's a failure in `_run_quorum`, resulting in a hang.
  - If a rank gets an error from quorum, it still crashes (the handling will be more complicated if ranks are on multiple hosts and can reconnect independently).
1 parent 347fd32 commit 0931e7e
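In code form, the fix boils down to a bounded retry loop that rebuilds the lighthouse connection between attempts. Below is a minimal, self-contained sketch of that pattern, not the actual torchft code: `Client`, `QuorumError`, and `reconnect` are hypothetical stand-ins for the real `LighthouseServiceClient<Channel>` and `lighthouse_client_new`, while the 100 ms backoff matches the value used in the diff.

use std::time::Duration;

// Hypothetical stand-ins for the tonic lighthouse client and the
// lighthouse_client_new constructor used in the real code.
struct Client;

#[derive(Debug)]
struct QuorumError(String);

impl Client {
    // Placeholder for the gRPC quorum call, which can fail if the
    // lighthouse crashed or the connection went stale.
    async fn quorum(&mut self) -> Result<String, QuorumError> {
        Ok("quorum".to_string())
    }
}

// Placeholder for lighthouse_client_new(addr, connect_timeout).
async fn reconnect() -> Result<Client, QuorumError> {
    Ok(Client)
}

// Bounded retry loop mirroring _quorum_with_retries: on failure, back off
// briefly, rebuild the client (the lighthouse may have restarted), and try
// again until `retries` extra attempts are exhausted.
async fn quorum_with_retries(mut client: Client, retries: i64) -> Result<String, QuorumError> {
    let mut retry_count = 0;
    loop {
        match client.quorum().await {
            Ok(resp) => return Ok(resp),
            Err(e) if retry_count == retries => return Err(e),
            Err(_) => {
                tokio::time::sleep(Duration::from_millis(100)).await;
                // Best effort: keep the old client if reconnecting fails too.
                if let Ok(new_client) = reconnect().await {
                    client = new_client;
                }
                retry_count += 1;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    match quorum_with_retries(Client, 3).await {
        Ok(resp) => println!("quorum ok: {resp}"),
        Err(e) => eprintln!("quorum failed after retries: {e:?}"),
    }
}

The real `_quorum_with_retries` in the diff additionally bounds each attempt with `tokio::time::timeout` and a per-request gRPC deadline.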

File tree

4 files changed (+126 −27 lines):
- src/lib.rs
- src/manager.rs
- torchft/_torchft.pyi
- torchft/manager.py


src/lib.rs

Lines changed: 3 additions & 0 deletions

@@ -71,6 +71,7 @@ fn num_threads() -> usize {
 /// world_size (int): The world size of the replica group.
 /// heartbeat_interval (timedelta): The interval at which heartbeats are sent.
 /// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
+/// quorum_retries (int): The number of retries for quorum requests to lighthouse server.
 #[pyclass]
 struct ManagerServer {
     handle: JoinHandle<Result<()>>,
@@ -91,6 +92,7 @@ impl ManagerServer {
         world_size: u64,
         heartbeat_interval: Duration,
         connect_timeout: Duration,
+        quorum_retries: i64,
     ) -> PyResult<Self> {
         py.allow_threads(move || {
             let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -108,6 +110,7 @@ impl ManagerServer {
                 world_size,
                 heartbeat_interval,
                 connect_timeout,
+                quorum_retries,
             ))
             .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
             let handle = runtime.spawn(manager.clone().run());

src/manager.rs

Lines changed: 114 additions & 27 deletions

@@ -24,6 +24,7 @@ use crate::net::connect;
 use crate::timeout::try_parse_grpc_timeout;
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
+use crate::torchftpb::LighthouseQuorumResponse;
 use crate::torchftpb::{
     manager_service_server::{ManagerService, ManagerServiceServer},
     CheckpointMetadataRequest, CheckpointMetadataResponse, KillRequest, KillResponse,
@@ -60,6 +61,8 @@ struct ManagerState {
     should_commit_channel: broadcast::Sender<bool>,
     should_commit_failures: HashSet<i64>,
     should_commit_count: HashSet<i64>,
+
+    lighthouse_client: LighthouseServiceClient<Channel>,
 }
 
 pub struct Manager {
@@ -71,7 +74,9 @@ pub struct Manager {
     listener: Mutex<Option<tokio::net::TcpListener>>,
     local_addr: SocketAddr,
     heartbeat_interval: Duration,
-    lighthouse_client: LighthouseServiceClient<Channel>,
+    lighthouse_addr: String,
+    connect_timeout: Duration,
+    quorum_retries: i64,
 }
 
 pub async fn manager_client_new(
@@ -108,6 +113,7 @@ impl Manager {
         world_size: u64,
         heartbeat_interval: Duration,
         connect_timeout: Duration,
+        quorum_retries: i64,
     ) -> Result<Arc<Self>> {
         let listener = tokio::net::TcpListener::bind(&bind).await?;
         let local_addr = listener.local_addr()?;
@@ -119,7 +125,8 @@
 
         Ok(Arc::new(Self {
             replica_id: replica_id,
-            lighthouse_client: client,
+            lighthouse_addr,
+            connect_timeout,
             hostname: hostname,
             store_address: store_addr,
             world_size: world_size,
@@ -132,9 +139,12 @@
                 should_commit_channel: should_commit_tx,
                 should_commit_count: HashSet::new(),
                 should_commit_failures: HashSet::new(),
+
+                lighthouse_client: client,
             }),
             local_addr: local_addr,
             listener: Mutex::new(Some(listener)),
+            quorum_retries,
         }))
     }
 
@@ -170,52 +180,50 @@
     }
 
     async fn _run_heartbeat(self: Arc<Self>) -> Result<()> {
-        let mut client = self.lighthouse_client.clone();
         loop {
+            let mut client = {
+                let state = self.state.lock().await;
+                state.lighthouse_client.clone()
+            };
+
             let request = tonic::Request::new(LighthouseHeartbeatRequest {
                 replica_id: self.replica_id.clone(),
             });
 
-            let _response = client.heartbeat(request).await;
+            if let Err(e) = client.heartbeat(request).await {
+                info_with_replica!(
+                    self.replica_id,
+                    "Failed to send heartbeat to lighthouse: {}",
+                    e.to_string()
+                );
+                let _ = self.create_lighthouse_client().await;
+            }
 
             sleep(self.heartbeat_interval).await;
         }
     }
 
     async fn _run_quorum(
-        &self,
-        state: &mut ManagerState,
+        self: Arc<Self>,
         requester: QuorumMember,
         timeout: Duration,
     ) -> Result<(), Status> {
-        if (state.participants.len() as u64) < self.world_size {
-            return Ok(());
-        }
-
-        state.participants.clear();
         info_with_replica!(self.replica_id, "All workers joined - starting quorum");
 
-        // TODO: don't hold the lock during quorum
-
-        let mut client = self.lighthouse_client.clone();
-
-        let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
+        let lighthouse_request = LighthouseQuorumRequest {
             requester: Some(requester),
-        });
-        lighthouse_request.set_timeout(timeout);
+        };
+
+        let response = self
+            ._quorum_with_retries(timeout, lighthouse_request)
+            .await?;
 
-        let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request))
-            .await
-            .unwrap_or_else(|e| {
-                Err(Status::cancelled(format!(
-                    "lighthouse quorum timed out: {}",
-                    e.to_string()
-                )))
-            })?;
         let resp = response.into_inner();
 
         info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp);
 
+        let state = self.state.lock().await;
+        // TODO: We don't broadcast in cases when this method returns an error, resulting in a hang
         state
            .channel
            .send(
@@ -226,6 +234,75 @@
 
         Ok(())
     }
+
+    async fn _quorum_with_retries(
+        &self,
+        timeout: Duration,
+        lighthouse_request: LighthouseQuorumRequest,
+    ) -> Result<tonic::Response<LighthouseQuorumResponse>, Status> {
+        let mut client = {
+            let state = self.state.lock().await;
+            state.lighthouse_client.clone()
+        };
+
+        let mut retry_count = 0;
+        loop {
+            let mut request = tonic::Request::new(lighthouse_request.clone());
+            request.set_timeout(timeout);
+
+            let result = tokio::time::timeout(timeout, client.quorum(request)).await;
+
+            match result {
+                Ok(response) => {
+                    return response;
+                }
+                Err(e) => {
+                    info_with_replica!(
+                        self.replica_id,
+                        "lighthouse quorum failed. error: {}",
+                        e.to_string()
+                    );
+
+                    if retry_count == self.quorum_retries {
+                        return Err(Status::internal(format!(
+                            "lighthouse quorum failed after {} retries. error: {}",
+                            retry_count,
+                            e.to_string(),
+                        )));
+                    }
+
+                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+                    // Reset the client since the lighthouse server might have failed
+                    // If this also fails, consider increasing `connect_timeout`.
+                    let _ = self.create_lighthouse_client().await;
+
+                    retry_count += 1;
+                }
+            }
+        }
+    }
+
+    async fn create_lighthouse_client(&self) -> Result<(), Status> {
+        // Reset the client since the lighthouse server might have failed
+        // If this also fails, consider increasing `connect_timeout`.
+        let lighthouse_client =
+            lighthouse_client_new(self.lighthouse_addr.clone(), self.connect_timeout).await;
+
+        match lighthouse_client {
+            Ok(client) => {
+                let mut state = self.state.lock().await;
+                state.lighthouse_client = client;
+                return Ok(());
+            }
+            Err(e) => {
+                return Err(Status::internal(format!(
+                    "Failed to connect to lighthouse. error: {}",
+                    e.to_string(),
+                )));
+            }
+        }
+    }
 }
 
 #[tonic::async_trait]
@@ -275,7 +352,13 @@ impl ManagerService for Arc<Manager> {
             state.participants.insert(group_rank, member.clone());
             let rx = state.channel.subscribe();
 
-            self._run_quorum(&mut state, member, timeout).await?;
+            if (state.participants.len() as u64) == self.world_size {
+                state.participants.clear();
+                let self_clone = self.clone();
+                tokio::spawn(async move {
+                    let _ = self_clone._run_quorum(member, timeout).await;
+                });
+            }
 
             rx
         };
@@ -563,6 +646,7 @@ mod tests {
             2, // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10), // connect timeout
+            0, // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager._run_grpc());
@@ -610,6 +694,7 @@
             1, // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10), // connect timeout
+            0, // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
@@ -671,6 +756,7 @@
             1, // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10), // connect timeout
+            0, // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
@@ -737,6 +823,7 @@
             1, // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10), // connect timeout
+            0, // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
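Two design choices are visible in the hunks above: quorum is now kicked off on a spawned background task once the last participant joins, and both `_run_heartbeat` and `_quorum_with_retries` clone the client out of the shared state before awaiting any network call, so the state lock is never held across the wire. A minimal sketch of that lock-scoping pattern follows; the `State` struct is a hypothetical stand-in for `ManagerState`.

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;

// Hypothetical stand-in for ManagerState; the real struct also holds the
// participant map and the broadcast channel.
struct State {
    client: String, // stand-in for LighthouseServiceClient<Channel>
}

// Mirrors the lock scoping in _run_heartbeat / _quorum_with_retries: clone
// the client inside a short block so the MutexGuard is dropped before the
// (potentially slow) network await.
async fn call_lighthouse(state: Arc<Mutex<State>>) {
    let client = {
        let guard = state.lock().await;
        guard.client.clone()
    }; // guard dropped here; other tasks can take the lock again

    // Simulated network call; the real code awaits client.quorum(...) or
    // client.heartbeat(...) here, without holding the state lock.
    tokio::time::sleep(Duration::from_millis(10)).await;
    println!("called lighthouse via {client}");
}

#[tokio::main]
async fn main() {
    let state = Arc::new(Mutex::new(State {
        client: "http://lighthouse-addr".to_string(),
    }));
    call_lighthouse(state).await;
}

Without the inner scope, the guard would live until the end of the function and every other request handler would block behind the network round-trip, which is exactly what the old "TODO: don't hold the lock during quorum" warned about.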

torchft/_torchft.pyi

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ class ManagerServer:
         world_size: int,
         heartbeat_interval: timedelta,
         connect_timeout: timedelta,
+        quorum_retries: int,
     ) -> None: ...
     def address(self) -> str: ...
     def shutdown(self) -> None: ...

torchft/manager.py

Lines changed: 8 additions & 0 deletions

@@ -67,6 +67,11 @@
 QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
 CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"
 
+# Environment variable for the number of retries to use for the quorum.
+# We need to retry quorum in case the lighthouse fails. Otherwise, if we
+# crash when a call to quorum fails, all replicas will crash.
+QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
+
 T = TypeVar("T")
 
 
@@ -217,6 +222,8 @@ def __init__(
         self._max_retries = max_retries
         self._commit_failures = 0
 
+        self._quorum_retries: int = int(os.environ.get(QUORUM_RETRIES_ENV, "0"))
+
         store_addr = store_addr or os.environ["MASTER_ADDR"]
         store_port = store_port or int(os.environ["MASTER_PORT"])
         self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
@@ -277,6 +284,7 @@ def __init__(
             world_size=group_world_size,
             heartbeat_interval=heartbeat_interval,
             connect_timeout=connect_timeout,
+            quorum_retries=self._quorum_retries,
         )
 
         self._store.set(MANAGER_ADDR_KEY, self._manager.address())
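With the plumbing above, the retry count is wired end to end through the environment: launching a replica with, say, `TORCHFT_QUORUM_RETRIES=3` set makes its manager retry failed quorum requests three times, reconnecting to the lighthouse between attempts, while the default of `0` keeps the previous fail-fast behavior.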
