diff --git a/src/lib.rs b/src/lib.rs
index 32a7a37e..fc8f8eb5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -71,6 +71,7 @@ fn num_threads() -> usize {
 /// world_size (int): The world size of the replica group.
 /// heartbeat_interval (timedelta): The interval at which heartbeats are sent.
 /// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
+/// quorum_retries (int): The number of retries for quorum requests to the lighthouse server.
 #[pyclass]
 struct ManagerServer {
     handle: JoinHandle<Result<()>>,
@@ -91,6 +92,7 @@ impl ManagerServer {
         world_size: u64,
         heartbeat_interval: Duration,
         connect_timeout: Duration,
+        quorum_retries: i64,
     ) -> PyResult<Self> {
         py.allow_threads(move || {
             let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -108,6 +110,7 @@ impl ManagerServer {
                 world_size,
                 heartbeat_interval,
                 connect_timeout,
+                quorum_retries,
             ))
             .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
             let handle = runtime.spawn(manager.clone().run());
diff --git a/src/manager.rs b/src/manager.rs
index e28cbeb5..acb4b929 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -24,6 +24,7 @@ use crate::net::connect;
 use crate::timeout::try_parse_grpc_timeout;
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
+use crate::torchftpb::LighthouseQuorumResponse;
 use crate::torchftpb::{
     manager_service_server::{ManagerService, ManagerServiceServer},
     CheckpointMetadataRequest, CheckpointMetadataResponse, KillRequest, KillResponse,
@@ -60,6 +61,8 @@ struct ManagerState {
     should_commit_channel: broadcast::Sender<bool>,
     should_commit_failures: HashSet<i64>,
     should_commit_count: HashSet<i64>,
+
+    lighthouse_client: LighthouseServiceClient<Channel>,
 }
 
 pub struct Manager {
@@ -71,7 +74,9 @@ pub struct Manager {
     listener: Mutex<Option<tokio::net::TcpListener>>,
     local_addr: SocketAddr,
     heartbeat_interval: Duration,
-    lighthouse_client: LighthouseServiceClient<Channel>,
+    lighthouse_addr: String,
+    connect_timeout: Duration,
+    quorum_retries: i64,
 }
 
 pub async fn manager_client_new(
@@ -108,6 +113,7 @@ impl Manager {
         world_size: u64,
         heartbeat_interval: Duration,
         connect_timeout: Duration,
+        quorum_retries: i64,
     ) -> Result<Arc<Self>> {
         let listener = tokio::net::TcpListener::bind(&bind).await?;
         let local_addr = listener.local_addr()?;
@@ -119,7 +125,8 @@ impl Manager {
 
         Ok(Arc::new(Self {
             replica_id: replica_id,
-            lighthouse_client: client,
+            lighthouse_addr,
+            connect_timeout,
             hostname: hostname,
             store_address: store_addr,
             world_size: world_size,
@@ -132,9 +139,12 @@ impl Manager {
                 should_commit_channel: should_commit_tx,
                 should_commit_count: HashSet::new(),
                 should_commit_failures: HashSet::new(),
+
+                lighthouse_client: client,
             }),
             local_addr: local_addr,
             listener: Mutex::new(Some(listener)),
+            quorum_retries,
         }))
     }
 
@@ -170,52 +180,50 @@ impl Manager {
     }
 
     async fn _run_heartbeat(self: Arc<Self>) -> Result<()> {
-        let mut client = self.lighthouse_client.clone();
         loop {
+            let mut client = {
+                let state = self.state.lock().await;
+                state.lighthouse_client.clone()
+            };
+
             let request = tonic::Request::new(LighthouseHeartbeatRequest {
                 replica_id: self.replica_id.clone(),
             });
 
-            let _response = client.heartbeat(request).await;
+            if let Err(e) = client.heartbeat(request).await {
+                info_with_replica!(
+                    self.replica_id,
+                    "Failed to send heartbeat to lighthouse: {}",
+                    e.to_string()
+                );
+                let _ = self.create_lighthouse_client().await;
+            }
 
             sleep(self.heartbeat_interval).await;
         }
     }
 
     async fn _run_quorum(
-        &self,
-        state: &mut ManagerState,
+        self: Arc<Self>,
         requester: QuorumMember,
         timeout: Duration,
     ) -> Result<(), Status> {
-        if (state.participants.len() as u64) < self.world_size {
-            return Ok(());
-        }
-
-        state.participants.clear();
         info_with_replica!(self.replica_id, "All workers joined - starting quorum");
 
-        // TODO: don't hold the lock during quorum
-
-        let mut client = self.lighthouse_client.clone();
-
-        let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
+        let lighthouse_request = LighthouseQuorumRequest {
             requester: Some(requester),
-        });
-        lighthouse_request.set_timeout(timeout);
+        };
+
+        let response = self
+            ._quorum_with_retries(timeout, lighthouse_request)
+            .await?;
 
-        let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request))
-            .await
-            .unwrap_or_else(|e| {
-                Err(Status::cancelled(format!(
-                    "lighthouse quorum timed out: {}",
-                    e.to_string()
-                )))
-            })?;
         let resp = response.into_inner();
 
         info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp);
 
+        let state = self.state.lock().await;
+
+        // TODO: We don't broadcast in cases when this method returns an error, resulting in a hang
         state
             .channel
             .send(
@@ -226,6 +234,75 @@ impl Manager {
         Ok(())
     }
+
+    async fn _quorum_with_retries(
+        &self,
+        timeout: Duration,
+        lighthouse_request: LighthouseQuorumRequest,
+    ) -> Result<tonic::Response<LighthouseQuorumResponse>, Status> {
+        let mut retry_count = 0;
+        loop {
+            let mut client = {
+                let state = self.state.lock().await;
+                state.lighthouse_client.clone()
+            };
+
+            let mut request = tonic::Request::new(lighthouse_request.clone());
+            request.set_timeout(timeout);
+
+            let result = tokio::time::timeout(timeout, client.quorum(request)).await;
+
+            match result {
+                Ok(response) => {
+                    return response;
+                }
+                Err(e) => {
+                    info_with_replica!(
+                        self.replica_id,
+                        "lighthouse quorum failed. error: {}",
+                        e.to_string()
+                    );
+
+                    if retry_count == self.quorum_retries {
+                        return Err(Status::internal(format!(
+                            "lighthouse quorum failed after {} retries. error: {}",
+                            retry_count,
+                            e.to_string(),
+                        )));
+                    }
+
+                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+                    // Reset the client since the lighthouse server might have failed.
+                    // If this also fails, consider increasing `connect_timeout`.
+                    let _ = self.create_lighthouse_client().await;
+
+                    retry_count += 1;
+                }
+            }
+        }
+    }
+
+    async fn create_lighthouse_client(&self) -> Result<(), Status> {
+        // Reset the client since the lighthouse server might have failed.
+        // If this also fails, consider increasing `connect_timeout`.
+        let lighthouse_client =
+            lighthouse_client_new(self.lighthouse_addr.clone(), self.connect_timeout).await;
+
+        match lighthouse_client {
+            Ok(client) => {
+                let mut state = self.state.lock().await;
+                state.lighthouse_client = client;
+                return Ok(());
+            }
+            Err(e) => {
+                return Err(Status::internal(format!(
+                    "Failed to connect to lighthouse. error: {}",
+                    e.to_string(),
+                )));
+            }
+        }
+    }
 }
 
 #[tonic::async_trait]
 impl ManagerService for Arc<Manager> {
@@ -275,7 +352,13 @@ impl ManagerService for Arc<Manager> {
             state.participants.insert(group_rank, member.clone());
             let rx = state.channel.subscribe();
 
-            self._run_quorum(&mut state, member, timeout).await?;
+            if (state.participants.len() as u64) == self.world_size {
+                state.participants.clear();
+                let self_clone = self.clone();
+                tokio::spawn(async move {
+                    let _ = self_clone._run_quorum(member, timeout).await;
+                });
+            }
 
             rx
         };
@@ -563,6 +646,7 @@ mod tests {
             2,                          // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10),    // connect timeout
+            0,                          // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager._run_grpc());
@@ -610,6 +694,7 @@ mod tests {
             1,                          // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10),    // connect timeout
+            0,                          // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
@@ -671,6 +756,7 @@ mod tests {
             1,                          // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10),    // connect timeout
+            0,                          // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
@@ -737,6 +823,7 @@ mod tests {
             1,                          // world size
             Duration::from_millis(100), // heartbeat interval
             Duration::from_secs(10),    // connect timeout
+            0,                          // quorum retries
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi
index 9614d1b0..ff175bf0 100644
--- a/torchft/_torchft.pyi
+++ b/torchft/_torchft.pyi
@@ -48,6 +48,7 @@ class ManagerServer:
         world_size: int,
         heartbeat_interval: timedelta,
         connect_timeout: timedelta,
+        quorum_retries: int,
     ) -> None: ...
     def address(self) -> str: ...
     def shutdown(self) -> None: ...
diff --git a/torchft/manager.py b/torchft/manager.py
index 07d37453..e01a965e 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -67,6 +67,11 @@
 QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
 CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"
 
+# Environment variable for the number of retries to use for the quorum.
+# We need to retry the quorum in case the lighthouse fails. Otherwise, if
+# a call to quorum fails, all replicas will crash.
+QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
+
 T = TypeVar("T")
 
@@ -150,6 +155,7 @@
         checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
         init_sync: bool = True,
         max_retries: Optional[int] = None,
+        quorum_retries: int = 0,
     ) -> None:
         """
         Args:
@@ -192,6 +198,7 @@
                 ``torch.set_seed`` you should set this to False.
             max_retries: the maximum number of consecutive should_commit failures to allow
                 before raising an exception. If None, will retry indefinitely.
+            quorum_retries: the number of times to retry the quorum before crashing.
         """
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}
@@ -217,6 +224,10 @@
         self._max_retries = max_retries
         self._commit_failures = 0
 
+        self._quorum_retries: int = int(
+            os.environ.get(QUORUM_RETRIES_ENV, str(quorum_retries))
+        )
+
         store_addr = store_addr or os.environ["MASTER_ADDR"]
         store_port = store_port or int(os.environ["MASTER_PORT"])
         self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
@@ -277,6 +288,7 @@
                 world_size=group_world_size,
                 heartbeat_interval=heartbeat_interval,
                 connect_timeout=connect_timeout,
+                quorum_retries=self._quorum_retries,
             )
 
             self._store.set(MANAGER_ADDR_KEY, self._manager.address())
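Usage note: `quorum_retries` flows from `Manager(quorum_retries=...)` in Python (or the `TORCHFT_QUORUM_RETRIES` environment variable) through `ManagerServer` into the Rust `Manager`, where `_quorum_with_retries` re-issues the lighthouse quorum request and rebuilds the lighthouse client between attempts. With `quorum_retries = N`, the request is attempted at most `N + 1` times, so the default of `0` keeps the previous fail-fast behavior (which is what the updated tests pass). Below is a minimal sketch of the environment-variable precedence added in `torchft/manager.py`; the helper name `resolve_quorum_retries` is illustrative only and not part of the patch.

```python
import os

# Mirrors QUORUM_RETRIES_ENV added in torchft/manager.py.
QUORUM_RETRIES_ENV = "TORCHFT_QUORUM_RETRIES"


def resolve_quorum_retries(quorum_retries: int = 0) -> int:
    # Same precedence as Manager.__init__: the environment variable,
    # when set, overrides the constructor argument.
    return int(os.environ.get(QUORUM_RETRIES_ENV, str(quorum_retries)))


# The constructor argument is used when the variable is unset.
os.environ.pop(QUORUM_RETRIES_ENV, None)
assert resolve_quorum_retries(2) == 2

# The environment variable takes precedence when present.
os.environ[QUORUM_RETRIES_ENV] = "3"
assert resolve_quorum_retries(2) == 3
```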