retry quorum #228

Open · wants to merge 1 commit into main
3 changes: 3 additions & 0 deletions src/lib.rs
@@ -71,6 +71,7 @@ fn num_threads() -> usize {
/// world_size (int): The world size of the replica group.
/// heartbeat_interval (timedelta): The interval at which heartbeats are sent.
/// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
/// quorum_retries (int): The number of retries for quorum requests to the lighthouse server.
#[pyclass]
struct ManagerServer {
handle: JoinHandle<Result<()>>,
@@ -91,6 +92,7 @@ impl ManagerServer {
world_size: u64,
heartbeat_interval: Duration,
connect_timeout: Duration,
quorum_retries: i64,
) -> PyResult<Self> {
py.allow_threads(move || {
let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -108,6 +110,7 @@ impl ManagerServer {
world_size,
heartbeat_interval,
connect_timeout,
quorum_retries,
))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
let handle = runtime.spawn(manager.clone().run());
141 changes: 114 additions & 27 deletions src/manager.rs
@@ -24,6 +24,7 @@ use crate::net::connect;
use crate::timeout::try_parse_grpc_timeout;
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
use crate::torchftpb::manager_service_client::ManagerServiceClient;
use crate::torchftpb::LighthouseQuorumResponse;
use crate::torchftpb::{
manager_service_server::{ManagerService, ManagerServiceServer},
CheckpointMetadataRequest, CheckpointMetadataResponse, KillRequest, KillResponse,
@@ -60,6 +61,8 @@ struct ManagerState {
should_commit_channel: broadcast::Sender<bool>,
should_commit_failures: HashSet<i64>,
should_commit_count: HashSet<i64>,

lighthouse_client: LighthouseServiceClient<Channel>,
}

pub struct Manager {
@@ -71,7 +74,9 @@ pub struct Manager {
listener: Mutex<Option<tokio::net::TcpListener>>,
local_addr: SocketAddr,
heartbeat_interval: Duration,
lighthouse_client: LighthouseServiceClient<Channel>,
lighthouse_addr: String,
connect_timeout: Duration,
quorum_retries: i64,
}

pub async fn manager_client_new(
@@ -108,6 +113,7 @@ impl Manager {
world_size: u64,
heartbeat_interval: Duration,
connect_timeout: Duration,
quorum_retries: i64,
) -> Result<Arc<Self>> {
let listener = tokio::net::TcpListener::bind(&bind).await?;
let local_addr = listener.local_addr()?;
@@ -119,7 +125,8 @@

Ok(Arc::new(Self {
replica_id: replica_id,
lighthouse_client: client,
lighthouse_addr,
connect_timeout,
hostname: hostname,
store_address: store_addr,
world_size: world_size,
@@ -132,9 +139,12 @@
should_commit_channel: should_commit_tx,
should_commit_count: HashSet::new(),
should_commit_failures: HashSet::new(),

lighthouse_client: client,
}),
local_addr: local_addr,
listener: Mutex::new(Some(listener)),
quorum_retries,
}))
}

@@ -170,52 +180,50 @@ impl Manager {
}

async fn _run_heartbeat(self: Arc<Self>) -> Result<()> {
let mut client = self.lighthouse_client.clone();
loop {
let mut client = {
let state = self.state.lock().await;
state.lighthouse_client.clone()
};

let request = tonic::Request::new(LighthouseHeartbeatRequest {
replica_id: self.replica_id.clone(),
});

let _response = client.heartbeat(request).await;
if let Err(e) = client.heartbeat(request).await {
info_with_replica!(
self.replica_id,
"Failed to send heartbeat to lighthouse: {}",
e.to_string()
);
let _ = self.create_lighthouse_client().await;
Member:

Should we be checking status for this?

Contributor Author:

I think it's fine: if we were unable to replace the lighthouse client, the next attempt to send the heartbeat will also fail, and it will try to create a lighthouse client again.
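For illustration, a minimal, self-contained sketch of the pattern discussed in this thread, with placeholder names (`FlakyClient`, `reconnect`) rather than the PR's actual API: a failed heartbeat triggers a best-effort reconnect, and if that also fails the loop keeps the old client and retries on the next tick.

```rust
use std::time::Duration;

struct FlakyClient {
    healthy: bool,
}

impl FlakyClient {
    async fn heartbeat(&self) -> Result<(), String> {
        if self.healthy {
            Ok(())
        } else {
            Err("lighthouse unreachable".to_string())
        }
    }
}

// Stand-in for `lighthouse_client_new(addr, connect_timeout)` in the PR.
async fn reconnect() -> Result<FlakyClient, String> {
    Ok(FlakyClient { healthy: true })
}

async fn run_heartbeat(mut client: FlakyClient, interval: Duration) {
    loop {
        if let Err(e) = client.heartbeat().await {
            eprintln!("heartbeat failed: {e}");
            // Best-effort reconnect: if this fails we keep the old client
            // and simply try again on the next iteration.
            if let Ok(new_client) = reconnect().await {
                client = new_client;
            }
        }
        tokio::time::sleep(interval).await;
    }
}

#[tokio::main]
async fn main() {
    let client = FlakyClient { healthy: false };
    // Run the loop briefly for demonstration, then exit.
    let _ = tokio::time::timeout(
        Duration::from_millis(350),
        run_heartbeat(client, Duration::from_millis(100)),
    )
    .await;
}
```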

}

sleep(self.heartbeat_interval).await;
}
}

async fn _run_quorum(
&self,
state: &mut ManagerState,
self: Arc<Self>,
requester: QuorumMember,
timeout: Duration,
) -> Result<(), Status> {
if (state.participants.len() as u64) < self.world_size {
return Ok(());
}

state.participants.clear();
info_with_replica!(self.replica_id, "All workers joined - starting quorum");

// TODO: don't hold the lock during quorum

let mut client = self.lighthouse_client.clone();

let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
let lighthouse_request = LighthouseQuorumRequest {
requester: Some(requester),
});
lighthouse_request.set_timeout(timeout);
};

let response = self
._quorum_with_retries(timeout, lighthouse_request)
.await?;

let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request))
.await
.unwrap_or_else(|e| {
Err(Status::cancelled(format!(
"lighthouse quorum timed out: {}",
e.to_string()
)))
})?;
let resp = response.into_inner();

info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp);

let state = self.state.lock().await;
// TODO: We don't broadcast in cases when this method returns an error, resulting in a hang
state
.channel
.send(
@@ -226,6 +234,75 @@ impl Manager {

Ok(())
}

async fn _quorum_with_retries(
&self,
timeout: Duration,
lighthouse_request: LighthouseQuorumRequest,
) -> Result<tonic::Response<LighthouseQuorumResponse>, Status> {
let mut retry_count = 0;
loop {
let mut client = {
let state = self.state.lock().await;
state.lighthouse_client.clone()
};

let mut request = tonic::Request::new(lighthouse_request.clone());
request.set_timeout(timeout);

let result = tokio::time::timeout(timeout, client.quorum(request)).await;

match result {
Ok(response) => {
return response;
}
Err(e) => {
info_with_replica!(
self.replica_id,
"lighthouse quorum failed. error: {}",
e.to_string()
);

if retry_count == self.quorum_retries {
return Err(Status::internal(format!(
"lighthouse quorum failed after {} retries. error: {}",
retry_count,
e.to_string(),
)));
}

tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

// Reset the client since the lighthouse server might have failed
// If this also fails, consider increasing `connect_timeout`.
let _ = self.create_lighthouse_client().await;

retry_count += 1;
}
}
}
}

async fn create_lighthouse_client(&self) -> Result<(), Status> {
// Reset the client since the lighthouse server might have failed
// If this also fails, consider increasing `connect_timeout`.
let lighthouse_client =
lighthouse_client_new(self.lighthouse_addr.clone(), self.connect_timeout).await;

match lighthouse_client {
Ok(client) => {
let mut state = self.state.lock().await;
state.lighthouse_client = client;
return Ok(());
}
Err(e) => {
return Err(Status::internal(format!(
"Failed to connect to lighthouse. error: {}",
e.to_string(),
)));
}
}
}
}

#[tonic::async_trait]
@@ -275,7 +352,13 @@ impl ManagerService for Arc<Manager> {
state.participants.insert(group_rank, member.clone());
let rx = state.channel.subscribe();

self._run_quorum(&mut state, member, timeout).await?;
if (state.participants.len() as u64) == self.world_size {
state.participants.clear();
let self_clone = self.clone();
tokio::spawn(async move {
let _ = self_clone._run_quorum(member, timeout).await;
});
}

rx
};
@@ -563,6 +646,7 @@ mod tests {
2, // world size
Duration::from_millis(100), // heartbeat interval
Duration::from_secs(10), // connect timeout
0, // quorum retries
)
.await?;
let manager_fut = tokio::spawn(manager._run_grpc());
Expand Down Expand Up @@ -610,6 +694,7 @@ mod tests {
1, // world size
Duration::from_millis(100), // heartbeat interval
Duration::from_secs(10), // connect timeout
0, // quorum retries
)
.await?;
let manager_fut = tokio::spawn(manager.clone().run());
Expand Down Expand Up @@ -671,6 +756,7 @@ mod tests {
1, // world size
Duration::from_millis(100), // heartbeat interval
Duration::from_secs(10), // connect timeout
0, // quorum retries
)
.await?;
let manager_fut = tokio::spawn(manager.clone().run());
Expand Down Expand Up @@ -737,6 +823,7 @@ mod tests {
1, // world size
Duration::from_millis(100), // heartbeat interval
Duration::from_secs(10), // connect timeout
0, // quorum retries
)
.await?;
let manager_fut = tokio::spawn(manager.clone().run());
1 change: 1 addition & 0 deletions torchft/_torchft.pyi
@@ -48,6 +48,7 @@ class ManagerServer:
world_size: int,
heartbeat_interval: timedelta,
connect_timeout: timedelta,
quorum_retries: int,
) -> None: ...
def address(self) -> str: ...
def shutdown(self) -> None: ...
12 changes: 12 additions & 0 deletions torchft/manager.py
@@ -67,6 +67,11 @@
QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"

# Environment variable for the number of retries to use for the quorum.
# We need to retry the quorum in case the lighthouse fails; otherwise, if we
# crash whenever a call to quorum fails, all replicas will crash.
QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"

T = TypeVar("T")


@@ -150,6 +155,7 @@ def __init__(
checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
init_sync: bool = True,
max_retries: Optional[int] = None,
quorum_retries: int = 0,
) -> None:
"""
Args:
@@ -192,6 +198,7 @@ def __init__(
``torch.set_seed`` you should set this to False.
max_retries: the maximum number of consecutive should_commit failures to allow
before raising an exception. If None, will retry indefinitely.
quorum_retries: the number of times to retry the quorum before crashing
"""
self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
self._user_state_dicts: Dict[str, Callable[[], object]] = {}
@@ -217,6 +224,10 @@ def __init__(
self._max_retries = max_retries
self._commit_failures = 0

self._quorum_retries: int = int(
os.environ.get(QUORUM_RETRIES_ENV, str(quorum_retries))
)

store_addr = store_addr or os.environ["MASTER_ADDR"]
store_port = store_port or int(os.environ["MASTER_PORT"])
self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
@@ -277,6 +288,7 @@ def __init__(
world_size=group_world_size,
heartbeat_interval=heartbeat_interval,
connect_timeout=connect_timeout,
quorum_retries=self._quorum_retries,
)

self._store.set(MANAGER_ADDR_KEY, self._manager.address())