diff --git a/src/client.rs b/src/client.rs
index 7d8b9bc3c..24f46e5b6 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -7,7 +7,10 @@ pub mod options;
 pub mod session;
 
 use std::{
-    sync::{atomic::AtomicBool, Mutex as SyncMutex},
+    sync::{
+        atomic::{AtomicBool, Ordering},
+        Mutex as SyncMutex,
+    },
     time::{Duration, Instant},
 };
 
@@ -26,13 +29,18 @@ use crate::trace::{
     COMMAND_TRACING_EVENT_TARGET,
 };
 use crate::{
+    bson::doc,
     concern::{ReadConcern, WriteConcern},
     db::Database,
     error::{Error, ErrorKind, Result},
     event::command::CommandEvent,
     id_set::IdSet,
     options::{ClientOptions, DatabaseOptions, ReadPreference, SelectionCriteria, ServerAddress},
-    sdam::{server_selection, SelectedServer, Topology},
+    sdam::{
+        server_selection::{self, attempt_to_select_server},
+        SelectedServer,
+        Topology,
+    },
     tracking_arc::TrackingArc,
     BoxFuture,
     ClientSession,
@@ -123,6 +131,7 @@ struct ClientInner {
     options: ClientOptions,
     session_pool: ServerSessionPool,
     shutdown: Shutdown,
+    dropped: AtomicBool,
     #[cfg(feature = "in-use-encryption")]
     csfle: tokio::sync::RwLock<Option<csfle::ClientState>>,
     #[cfg(test)]
@@ -159,6 +168,7 @@ impl Client {
                 pending_drops: SyncMutex::new(IdSet::new()),
                 executed: AtomicBool::new(false),
             },
+            dropped: AtomicBool::new(false),
             #[cfg(feature = "in-use-encryption")]
             csfle: Default::default(),
             #[cfg(test)]
@@ -591,6 +601,40 @@ impl Client {
     pub(crate) fn options(&self) -> &ClientOptions {
         &self.inner.options
     }
+
+    /// Ends all sessions contained in this client's session pool on the server.
+    pub(crate) async fn end_all_sessions(&self) {
+        // The maximum number of session IDs that should be sent in a single endSessions command.
+        const MAX_END_SESSIONS_BATCH_SIZE: usize = 10_000;
+
+        let mut watcher = self.inner.topology.watch();
+        let selection_criteria =
+            SelectionCriteria::from(ReadPreference::PrimaryPreferred { options: None });
+
+        let session_ids = self.inner.session_pool.get_session_ids().await;
+        for chunk in session_ids.chunks(MAX_END_SESSIONS_BATCH_SIZE) {
+            let state = watcher.observe_latest();
+            let Ok(Some(_)) = attempt_to_select_server(
+                &selection_criteria,
+                &state.description,
+                &state.servers(),
+                None,
+            ) else {
+                // If a suitable server is not available, do not proceed with the operation to
+                // avoid spinning for server_selection_timeout.
+                return;
+            };
+
+            let end_sessions = doc! {
+                "endSessions": chunk,
+            };
+            let _ = self
+                .database("admin")
+                .run_command(end_sessions)
+                .selection_criteria(selection_criteria.clone())
+                .await;
+        }
+    }
 }
 
 #[derive(Clone, Debug)]
@@ -625,3 +669,24 @@ impl AsyncDropToken {
         Self { tx: self.tx.take() }
     }
 }
+
+impl Drop for Client {
+    fn drop(&mut self) {
+        if !self.inner.shutdown.executed.load(Ordering::SeqCst)
+            && !self.inner.dropped.load(Ordering::SeqCst)
+            && TrackingArc::strong_count(&self.inner) == 1
+        {
+            // We need an owned copy of the client to move into the spawned future. However, if
+            // this call to drop completes before the spawned future completes, the number of
+            // strong references to the inner client will again be 1 when the cloned client
+            // drops, and thus end_all_sessions will be called continuously until the runtime
+            // shuts down. Storing a flag indicating whether end_all_sessions has already been
+            // called breaks this cycle.
+            self.inner.dropped.store(true, Ordering::SeqCst);
+            let client = self.clone();
+            crate::runtime::spawn(async move {
+                client.end_all_sessions().await;
+            });
+        }
+    }
+}
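The `Drop` impl above is the subtle part of this change: cleanup has to run on an async runtime, so drop spawns a task, but that task itself owns a clone of the client and would observe `strong_count == 1` again when it finishes, re-triggering the spawn. The `dropped` flag breaks the cycle. A minimal sketch of the same guard pattern in isolation (assuming a running Tokio runtime; `Inner`, `Handle`, and `cleanup` are illustrative stand-ins, not driver types):

```rust
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

struct Inner {
    // Set once the cleanup task has been spawned.
    dropped: AtomicBool,
}

#[derive(Clone)]
struct Handle {
    inner: Arc<Inner>,
}

impl Handle {
    // Stand-in for Client::end_all_sessions.
    async fn cleanup(&self) {}
}

impl Drop for Handle {
    fn drop(&mut self) {
        // Only the last handle spawns cleanup, and only once: the clone moved
        // into the task re-enters this drop with strong_count == 1, but by
        // then `dropped` is already set, so no second task is spawned.
        if !self.inner.dropped.load(Ordering::SeqCst) && Arc::strong_count(&self.inner) == 1 {
            self.inner.dropped.store(true, Ordering::SeqCst);
            let handle = self.clone();
            tokio::spawn(async move { handle.cleanup().await });
        }
    }
}
```

In the driver itself, the additional `shutdown.executed` check skips this path entirely when the user has already called `shutdown`.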
diff --git a/src/client/action/shutdown.rs b/src/client/action/shutdown.rs
index a672b26c8..7944342ac 100644
--- a/src/client/action/shutdown.rs
+++ b/src/client/action/shutdown.rs
@@ -23,6 +23,11 @@ impl Action for crate::action::Shutdown {
                 .extract();
             join_all(pending).await;
         }
+        // If shutdown has already been called on a different copy of the client, don't call
+        // end_all_sessions again.
+        if !self.client.inner.shutdown.executed.load(Ordering::SeqCst) {
+            self.client.end_all_sessions().await;
+        }
         self.client.inner.topology.shutdown().await;
         // This has to happen last to allow pending cleanup to execute commands.
         self.client
diff --git a/src/client/session.rs b/src/client/session.rs
index 1a9da856c..9dc3a3daf 100644
--- a/src/client/session.rs
+++ b/src/client/session.rs
@@ -401,7 +401,7 @@ impl Drop for ClientSession {
 #[derive(Clone, Debug)]
 pub(crate) struct ServerSession {
     /// The id of the server session to which this corresponds.
-    id: Document,
+    pub(crate) id: Document,
 
     /// The last time an operation was executed with this session.
     last_use: std::time::Instant,
diff --git a/src/client/session/pool.rs b/src/client/session/pool.rs
index 34c9990b2..3980d214e 100644
--- a/src/client/session/pool.rs
+++ b/src/client/session/pool.rs
@@ -3,7 +3,6 @@ use std::{collections::VecDeque, time::Duration};
 use tokio::sync::Mutex;
 
 use super::ServerSession;
-#[cfg(test)]
 use crate::bson::Document;
 
 #[derive(Debug)]
@@ -68,4 +67,10 @@ impl ServerSessionPool {
     pub(crate) async fn contains(&self, id: &Document) -> bool {
         self.pool.lock().await.iter().any(|s| &s.id == id)
     }
+
+    /// Returns a list of the IDs of the sessions contained in the pool.
+    pub(crate) async fn get_session_ids(&self) -> Vec<Document> {
+        let sessions = self.pool.lock().await;
+        sessions.iter().map(|session| session.id.clone()).collect()
+    }
 }
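`get_session_ids` exists so that `end_all_sessions` can snapshot the pooled IDs and batch them into `endSessions` commands of at most `MAX_END_SESSIONS_BATCH_SIZE` entries each. A sketch of just the batching step, assuming only the `bson` crate (`build_end_sessions_commands` is an illustrative helper, not driver code):

```rust
use bson::{doc, Document};

/// Split session IDs into endSessions commands holding at most `batch_size`
/// IDs each, mirroring the chunking done by end_all_sessions above.
fn build_end_sessions_commands(session_ids: &[Document], batch_size: usize) -> Vec<Document> {
    session_ids
        .chunks(batch_size)
        .map(|chunk| doc! { "endSessions": chunk })
        .collect()
}
```

Calling `build_end_sessions_commands(&ids, 10_000)` yields one command per 10,000 IDs, matching the constant used by `end_all_sessions`.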
diff --git a/src/gridfs/upload.rs b/src/gridfs/upload.rs
index 404bf641d..8bd14f565 100644
--- a/src/gridfs/upload.rs
+++ b/src/gridfs/upload.rs
@@ -261,7 +261,6 @@ impl GridFsUploadStream {
 }
 
 impl Drop for GridFsUploadStream {
-    // TODO RUST-1493: pre-create this task
     fn drop(&mut self) {
         if !matches!(self.state, State::Closed) {
             let chunks = self.bucket.chunks().clone();
diff --git a/src/test/client.rs b/src/test/client.rs
index eb62cf752..67f6269fc 100644
--- a/src/test/client.rs
+++ b/src/test/client.rs
@@ -15,7 +15,7 @@ use crate::{
     get_client_options,
     log_uncaptured,
     util::{
-        event_buffer::EventBuffer,
+        event_buffer::{EventBuffer, EventStream},
         fail_point::{FailPoint, FailPointMode},
         TestClient,
     },
@@ -930,3 +930,55 @@ async fn warm_connection_pool() {
     // Validate that a command executes.
     client.list_database_names().await.unwrap();
 }
+
+async fn get_end_session_event_count<'a>(event_stream: &mut EventStream<'a, Event>) -> usize {
+    // Use collect_successful_command_execution to assert that the call to endSessions succeeded.
+    event_stream
+        .collect_successful_command_execution(Duration::from_millis(500), "endSessions")
+        .await
+        .len()
+}
+
+#[tokio::test]
+async fn end_sessions_on_drop() {
+    let client1 = Client::for_test().monitor_events().await;
+    let client2 = client1.clone();
+    let events = client1.events.clone();
+    let mut event_stream = events.stream();
+
+    // Run an operation to populate the session pool.
+    client1
+        .database("db")
+        .collection::<Document>("coll")
+        .find(doc! {})
+        .await
+        .unwrap();
+
+    drop(client1);
+    assert_eq!(get_end_session_event_count(&mut event_stream).await, 0);
+
+    drop(client2);
+    assert_eq!(get_end_session_event_count(&mut event_stream).await, 1);
+}
+
+#[tokio::test]
+async fn end_sessions_on_shutdown() {
+    let client1 = Client::for_test().monitor_events().await;
+    let client2 = client1.clone();
+    let events = client1.events.clone();
+    let mut event_stream = events.stream();
+
+    // Run an operation to populate the session pool.
+    client1
+        .database("db")
+        .collection::<Document>("coll")
+        .find(doc! {})
+        .await
+        .unwrap();
+
+    client1.into_client().shutdown().await;
+    assert_eq!(get_end_session_event_count(&mut event_stream).await, 1);
+
+    client2.into_client().shutdown().await;
+    assert_eq!(get_end_session_event_count(&mut event_stream).await, 0);
+}
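The two tests above exercise both cleanup paths: `end_sessions_on_drop` relies on the new `Drop` impl firing only when the last clone goes away, while `end_sessions_on_shutdown` verifies that explicit shutdown ends sessions exactly once (the second call finds `shutdown.executed` set and skips `end_all_sessions`). For reference, a sketch of how the public API drives the explicit path, assuming a reachable deployment at the given URI (error handling elided; `run` is illustrative):

```rust
use mongodb::Client;

async fn run() -> mongodb::error::Result<()> {
    let client = Client::with_uri_str("mongodb://localhost:27017").await?;
    client.list_database_names().await?;

    // Graceful shutdown: waits for pending resource cleanup (and, with this
    // change, ends pooled server sessions) before tearing the topology down.
    client.shutdown().await;

    // Or, to skip waiting on pending resource cleanup:
    // client.shutdown().immediate(true).await;
    Ok(())
}
```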
diff --git a/src/test/spec/json/connection-monitoring-and-pooling/README.rst b/src/test/spec/json/connection-monitoring-and-pooling/README.rst
deleted file mode 100644
index ae4af543f..000000000
--- a/src/test/spec/json/connection-monitoring-and-pooling/README.rst
+++ /dev/null
@@ -1,36 +0,0 @@
-.. role:: javascript(code)
-   :language: javascript
-
-========================================
-Connection Monitoring and Pooling (CMAP)
-========================================
-
-.. contents::
-
---------
-
-Introduction
-============
-Drivers MUST implement all of the following types of CMAP tests:
-
-* Pool unit and integration tests as described in `cmap-format/README.rst <./cmap-format/README.rst>`__
-* Pool prose tests as described below in `Prose Tests`_
-* Logging tests as described below in `Logging Tests`_
-
-Prose Tests
-===========
-
-The following tests have not yet been automated, but MUST still be tested:
-
-#. All ConnectionPoolOptions MUST be specified at the MongoClient level
-#. All ConnectionPoolOptions MUST be the same for all pools created by a MongoClient
-#. A user MUST be able to specify all ConnectionPoolOptions via a URI string
-#. A user MUST be able to subscribe to Connection Monitoring Events in a manner idiomatic to their language and driver
-#. When a check out attempt fails because connection set up throws an error,
-   assert that a ConnectionCheckOutFailedEvent with reason="connectionError" is emitted.
-
-Logging Tests
-=============
-
-Tests for connection pool logging can be found in the `/logging <./logging>`__ subdirectory and are written in the
-`Unified Test Format <../../unified-test-format/unified-test-format.rst>`__.
\ No newline at end of file
diff --git a/src/test/spec/json/connection-monitoring-and-pooling/logging/connection-logging.json b/src/test/spec/json/connection-monitoring-and-pooling/logging/connection-logging.json
index 2f8e28307..72103b3ca 100644
--- a/src/test/spec/json/connection-monitoring-and-pooling/logging/connection-logging.json
+++ b/src/test/spec/json/connection-monitoring-and-pooling/logging/connection-logging.json
@@ -201,6 +201,73 @@
             }
           }
         },
+        {
+          "level": "debug",
+          "component": "connection",
+          "data": {
+            "message": "Connection checkout started",
+            "serverHost": {
+              "$$type": "string"
+            },
+            "serverPort": {
+              "$$type": [
+                "int",
+                "long"
+              ]
+            }
+          }
+        },
+        {
+          "level": "debug",
+          "component": "connection",
+          "data": {
+            "message": "Connection checked out",
+            "driverConnectionId": {
+              "$$type": [
+                "int",
+                "long"
+              ]
+            },
+            "serverHost": {
+              "$$type": "string"
+            },
+            "serverPort": {
+              "$$type": [
+                "int",
+                "long"
+              ]
+            },
+            "durationMS": {
+              "$$type": [
+                "double",
+                "int",
+                "long"
+              ]
+            }
+          }
+        },
+        {
+          "level": "debug",
+          "component": "connection",
+          "data": {
+            "message": "Connection checked in",
+            "driverConnectionId": {
+              "$$type": [
+                "int",
+                "long"
+              ]
+            },
+            "serverHost": {
+              "$$type": "string"
+            },
+            "serverPort": {
+              "$$type": [
+                "int",
+                "long"
+              ]
+            }
+          }
+        },
         {
           "level": "debug",
           "component": "connection",
diff --git a/src/test/spec/json/connection-monitoring-and-pooling/logging/connection-logging.yml b/src/test/spec/json/connection-monitoring-and-pooling/logging/connection-logging.yml
index 15cf0d6b1..49868a062 100644
--- a/src/test/spec/json/connection-monitoring-and-pooling/logging/connection-logging.yml
+++ b/src/test/spec/json/connection-monitoring-and-pooling/logging/connection-logging.yml
@@ -85,6 +85,31 @@ tests:
                 serverHost: { $$type: string }
                 serverPort: { $$type: [int, long] }
 
+            # The next three expected logs are for ending a session.
+            - level: debug
+              component: connection
+              data:
+                message: "Connection checkout started"
+                serverHost: { $$type: string }
+                serverPort: { $$type: [int, long] }
+
+            - level: debug
+              component: connection
+              data:
+                message: "Connection checked out"
+                driverConnectionId: { $$type: [int, long] }
+                serverHost: { $$type: string }
+                serverPort: { $$type: [int, long] }
+                durationMS: { $$type: [double, int, long] }
+
+            - level: debug
+              component: connection
+              data:
+                message: "Connection checked in"
+                driverConnectionId: { $$type: [int, long] }
+                serverHost: { $$type: string }
+                serverPort: { $$type: [int, long] }
+
             - level: debug
               component: connection
               data:
diff --git a/src/test/spec/unified_runner/operation.rs b/src/test/spec/unified_runner/operation.rs
index 0292d3dfa..3687bbc03 100644
--- a/src/test/spec/unified_runner/operation.rs
+++ b/src/test/spec/unified_runner/operation.rs
@@ -2182,7 +2182,13 @@ impl TestOperation for Close {
             Entity::Client(_) => {
                 let client = entities.get_mut(id).unwrap().as_mut_client();
                 let closed_client_topology_id = client.topology_id;
-                client.client = None;
+                client
+                    .client
+                    .take()
+                    .unwrap()
+                    .shutdown()
+                    .immediate(true)
+                    .await;
 
                 let mut entities_to_remove = vec![];
                 for (key, value) in entities.iter() {
diff --git a/src/test/util/event.rs b/src/test/util/event.rs
index 00c2f8641..679172b7a 100644
--- a/src/test/util/event.rs
+++ b/src/test/util/event.rs
@@ -174,7 +174,6 @@ impl IntoFuture for EventClientBuilder {
 }
 
 impl EventClient {
-    #[allow(dead_code)]
     pub(crate) fn into_client(self) -> crate::Client {
         self.client.into_client()
     }
diff --git a/src/test/util/event_buffer.rs b/src/test/util/event_buffer.rs
index 2db230413..1e67b8825 100644
--- a/src/test/util/event_buffer.rs
+++ b/src/test/util/event_buffer.rs
@@ -420,4 +420,23 @@ impl<'a> EventStream<'a, Event> {
             .await
             .ok()
     }
+
+    pub(crate) async fn collect_successful_command_execution(
+        &mut self,
+        timeout: Duration,
+        command_name: impl AsRef<str>,
+    ) -> Vec<(CommandStartedEvent, CommandSucceededEvent)> {
+        let mut event_pairs = Vec::new();
+        let command_name = command_name.as_ref();
+        let _ = runtime::timeout(timeout, async {
+            while let Some(next_pair) = self
+                .next_successful_command_execution(timeout, command_name)
+                .await
+            {
+                event_pairs.push(next_pair);
+            }
+        })
+        .await;
+        event_pairs
+    }
 }
diff --git a/src/tracking_arc.rs b/src/tracking_arc.rs
index 785f78d36..9382b2573 100644
--- a/src/tracking_arc.rs
+++ b/src/tracking_arc.rs
@@ -61,6 +61,10 @@ impl<T> TrackingArc<T> {
         Arc::ptr_eq(&this.inner, &other.inner)
     }
 
+    pub(crate) fn strong_count(this: &Self) -> usize {
+        Arc::strong_count(&this.inner)
+    }
+
     #[cfg(all(test, mongodb_internal_tracking_arc))]
     #[allow(unused)]
     pub(crate) fn print_live(tracked: &Self) {
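`collect_successful_command_execution` is a drain-until-deadline helper: it keeps pulling matching (started, succeeded) pairs until the timeout fires, then hands back whatever arrived. A sketch of the same pattern in isolation, with a Tokio channel standing in for the event stream (`collect_within` is an illustrative name, not driver API):

```rust
use std::time::Duration;

use tokio::{sync::mpsc::UnboundedReceiver, time::timeout};

/// Collect everything that arrives on `rx` before `window` elapses.
async fn collect_within<T>(rx: &mut UnboundedReceiver<T>, window: Duration) -> Vec<T> {
    let mut items = Vec::new();
    // If the timer fires first, the inner future is dropped mid-loop and the
    // items pushed so far are returned unchanged.
    let _ = timeout(window, async {
        while let Some(item) = rx.recv().await {
            items.push(item);
        }
    })
    .await;
    items
}
```

The `let _ =` mirrors the helper above: an elapsed timeout is the expected way out of the loop, not a failure.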