@@ -24,6 +24,7 @@ use crate::net::connect;
24
24
use crate :: timeout:: try_parse_grpc_timeout;
25
25
use crate :: torchftpb:: lighthouse_service_client:: LighthouseServiceClient ;
26
26
use crate :: torchftpb:: manager_service_client:: ManagerServiceClient ;
27
+ use crate :: torchftpb:: LighthouseQuorumResponse ;
27
28
use crate :: torchftpb:: {
28
29
manager_service_server:: { ManagerService , ManagerServiceServer } ,
29
30
CheckpointMetadataRequest , CheckpointMetadataResponse , KillRequest , KillResponse ,
@@ -60,6 +61,8 @@ struct ManagerState {
60
61
should_commit_channel : broadcast:: Sender < bool > ,
61
62
should_commit_failures : HashSet < i64 > ,
62
63
should_commit_count : HashSet < i64 > ,
64
+
65
+ lighthouse_client : LighthouseServiceClient < Channel > ,
63
66
}
64
67
65
68
pub struct Manager {
@@ -71,7 +74,9 @@ pub struct Manager {
71
74
listener : Mutex < Option < tokio:: net:: TcpListener > > ,
72
75
local_addr : SocketAddr ,
73
76
heartbeat_interval : Duration ,
74
- lighthouse_client : LighthouseServiceClient < Channel > ,
77
+ lighthouse_addr : String ,
78
+ connect_timeout : Duration ,
79
+ quorum_retries : i64 ,
75
80
}
76
81
77
82
pub async fn manager_client_new (
@@ -108,6 +113,7 @@ impl Manager {
108
113
world_size : u64 ,
109
114
heartbeat_interval : Duration ,
110
115
connect_timeout : Duration ,
116
+ quorum_retries : i64 ,
111
117
) -> Result < Arc < Self > > {
112
118
let listener = tokio:: net:: TcpListener :: bind ( & bind) . await ?;
113
119
let local_addr = listener. local_addr ( ) ?;
@@ -119,7 +125,8 @@ impl Manager {
119
125
120
126
Ok ( Arc :: new ( Self {
121
127
replica_id : replica_id,
122
- lighthouse_client : client,
128
+ lighthouse_addr,
129
+ connect_timeout,
123
130
hostname : hostname,
124
131
store_address : store_addr,
125
132
world_size : world_size,
@@ -132,9 +139,12 @@ impl Manager {
132
139
should_commit_channel : should_commit_tx,
133
140
should_commit_count : HashSet :: new ( ) ,
134
141
should_commit_failures : HashSet :: new ( ) ,
142
+
143
+ lighthouse_client : client,
135
144
} ) ,
136
145
local_addr : local_addr,
137
146
listener : Mutex :: new ( Some ( listener) ) ,
147
+ quorum_retries,
138
148
} ) )
139
149
}
140
150
@@ -170,52 +180,50 @@ impl Manager {
170
180
}
171
181
172
182
async fn _run_heartbeat ( self : Arc < Self > ) -> Result < ( ) > {
173
- let mut client = self . lighthouse_client . clone ( ) ;
174
183
loop {
184
+ let mut client = {
185
+ let state = self . state . lock ( ) . await ;
186
+ state. lighthouse_client . clone ( )
187
+ } ;
188
+
175
189
let request = tonic:: Request :: new ( LighthouseHeartbeatRequest {
176
190
replica_id : self . replica_id . clone ( ) ,
177
191
} ) ;
178
192
179
- let _response = client. heartbeat ( request) . await ;
193
+ if let Err ( e) = client. heartbeat ( request) . await {
194
+ info_with_replica ! (
195
+ self . replica_id,
196
+ "Failed to send heartbeat to lighthouse: {}" ,
197
+ e. to_string( )
198
+ ) ;
199
+ let _ = self . create_lighthouse_client ( ) . await ;
200
+ }
180
201
181
202
sleep ( self . heartbeat_interval ) . await ;
182
203
}
183
204
}
184
205
185
206
async fn _run_quorum (
186
- & self ,
187
- state : & mut ManagerState ,
207
+ self : Arc < Self > ,
188
208
requester : QuorumMember ,
189
209
timeout : Duration ,
190
210
) -> Result < ( ) , Status > {
191
- if ( state. participants . len ( ) as u64 ) < self . world_size {
192
- return Ok ( ( ) ) ;
193
- }
194
-
195
- state. participants . clear ( ) ;
196
211
info_with_replica ! ( self . replica_id, "All workers joined - starting quorum" ) ;
197
212
198
- // TODO: don't hold the lock during quorum
199
-
200
- let mut client = self . lighthouse_client . clone ( ) ;
201
-
202
- let mut lighthouse_request = tonic:: Request :: new ( LighthouseQuorumRequest {
213
+ let lighthouse_request = LighthouseQuorumRequest {
203
214
requester : Some ( requester) ,
204
- } ) ;
205
- lighthouse_request. set_timeout ( timeout) ;
215
+ } ;
216
+
217
+ let response = self
218
+ . _quorum_with_retries ( timeout, lighthouse_request)
219
+ . await ?;
206
220
207
- let response = tokio:: time:: timeout ( timeout, client. quorum ( lighthouse_request) )
208
- . await
209
- . unwrap_or_else ( |e| {
210
- Err ( Status :: cancelled ( format ! (
211
- "lighthouse quorum timed out: {}" ,
212
- e. to_string( )
213
- ) ) )
214
- } ) ?;
215
221
let resp = response. into_inner ( ) ;
216
222
217
223
info_with_replica ! ( self . replica_id, "got lighthouse quorum {:?}" , resp) ;
218
224
225
+ let state = self . state . lock ( ) . await ;
226
+ // TODO: We don't broadcast in cases when this method returns an error, resulting in a hang
219
227
state
220
228
. channel
221
229
. send (
@@ -226,6 +234,75 @@ impl Manager {
226
234
227
235
Ok ( ( ) )
228
236
}
237
+
238
+ async fn _quorum_with_retries (
239
+ & self ,
240
+ timeout : Duration ,
241
+ lighthouse_request : LighthouseQuorumRequest ,
242
+ ) -> Result < tonic:: Response < LighthouseQuorumResponse > , Status > {
243
+ let mut client = {
244
+ let state = self . state . lock ( ) . await ;
245
+ state. lighthouse_client . clone ( )
246
+ } ;
247
+
248
+ let mut retry_count = 0 ;
249
+ loop {
250
+ let mut request = tonic:: Request :: new ( lighthouse_request. clone ( ) ) ;
251
+ request. set_timeout ( timeout) ;
252
+
253
+ let result = tokio:: time:: timeout ( timeout, client. quorum ( request) ) . await ;
254
+
255
+ match result {
256
+ Ok ( response) => {
257
+ return response;
258
+ }
259
+ Err ( e) => {
260
+ info_with_replica ! (
261
+ self . replica_id,
262
+ "lighthouse quorum failed. error: {}" ,
263
+ e. to_string( )
264
+ ) ;
265
+
266
+ if retry_count == self . quorum_retries {
267
+ return Err ( Status :: internal ( format ! (
268
+ "lighthouse quorum failed after {} retries. error: {}" ,
269
+ retry_count,
270
+ e. to_string( ) ,
271
+ ) ) ) ;
272
+ }
273
+
274
+ tokio:: time:: sleep ( tokio:: time:: Duration :: from_millis ( 100 ) ) . await ;
275
+
276
+ // Reset the client since the lighthouse server might have failed
277
+ // If this also fails, consider increasing `connect_timeout`.
278
+ let _ = self . create_lighthouse_client ( ) . await ;
279
+
280
+ retry_count += 1 ;
281
+ }
282
+ }
283
+ }
284
+ }
285
+
286
+ async fn create_lighthouse_client ( & self ) -> Result < ( ) , Status > {
287
+ // Reset the client since the lighthouse server might have failed
288
+ // If this also fails, consider increasing `connect_timeout`.
289
+ let lighthouse_client =
290
+ lighthouse_client_new ( self . lighthouse_addr . clone ( ) , self . connect_timeout ) . await ;
291
+
292
+ match lighthouse_client {
293
+ Ok ( client) => {
294
+ let mut state = self . state . lock ( ) . await ;
295
+ state. lighthouse_client = client;
296
+ return Ok ( ( ) ) ;
297
+ }
298
+ Err ( e) => {
299
+ return Err ( Status :: internal ( format ! (
300
+ "Failed to connect to lighthouse. error: {}" ,
301
+ e. to_string( ) ,
302
+ ) ) ) ;
303
+ }
304
+ }
305
+ }
229
306
}
230
307
231
308
#[ tonic:: async_trait]
@@ -275,7 +352,13 @@ impl ManagerService for Arc<Manager> {
275
352
state. participants . insert ( group_rank, member. clone ( ) ) ;
276
353
let rx = state. channel . subscribe ( ) ;
277
354
278
- self . _run_quorum ( & mut state, member, timeout) . await ?;
355
+ if ( state. participants . len ( ) as u64 ) == self . world_size {
356
+ state. participants . clear ( ) ;
357
+ let self_clone = self . clone ( ) ;
358
+ tokio:: spawn ( async move {
359
+ let _ = self_clone. _run_quorum ( member, timeout) . await ;
360
+ } ) ;
361
+ }
279
362
280
363
rx
281
364
} ;
@@ -563,6 +646,7 @@ mod tests {
563
646
2 , // world size
564
647
Duration :: from_millis ( 100 ) , // heartbeat interval
565
648
Duration :: from_secs ( 10 ) , // connect timeout
649
+ 0 , // quorum retries
566
650
)
567
651
. await ?;
568
652
let manager_fut = tokio:: spawn ( manager. _run_grpc ( ) ) ;
@@ -610,6 +694,7 @@ mod tests {
610
694
1 , // world size
611
695
Duration :: from_millis ( 100 ) , // heartbeat interval
612
696
Duration :: from_secs ( 10 ) , // connect timeout
697
+ 0 , // quorum retries
613
698
)
614
699
. await ?;
615
700
let manager_fut = tokio:: spawn ( manager. clone ( ) . run ( ) ) ;
@@ -671,6 +756,7 @@ mod tests {
671
756
1 , // world size
672
757
Duration :: from_millis ( 100 ) , // heartbeat interval
673
758
Duration :: from_secs ( 10 ) , // connect timeout
759
+ 0 , // quorum retries
674
760
)
675
761
. await ?;
676
762
let manager_fut = tokio:: spawn ( manager. clone ( ) . run ( ) ) ;
@@ -737,6 +823,7 @@ mod tests {
737
823
1 , // world size
738
824
Duration :: from_millis ( 100 ) , // heartbeat interval
739
825
Duration :: from_secs ( 10 ) , // connect timeout
826
+ 0 , // quorum retries
740
827
)
741
828
. await ?;
742
829
let manager_fut = tokio:: spawn ( manager. clone ( ) . run ( ) ) ;
0 commit comments