diff --git a/src/client/auth.rs b/src/client/auth.rs index 9624f9c56..d4c747b5c 100644 --- a/src/client/auth.rs +++ b/src/client/auth.rs @@ -16,7 +16,7 @@ mod x509; use std::{borrow::Cow, fmt::Debug, str::FromStr}; -use crate::{bson::RawDocumentBuf, bson_compat::cstr}; +use crate::{bson::RawDocumentBuf, bson_compat::cstr, options::ClientOptions}; use derive_where::derive_where; use hmac::{digest::KeyInit, Mac}; use rand::Rng; @@ -287,12 +287,11 @@ impl AuthMechanism { &self, stream: &mut Connection, credential: &Credential, - server_api: Option<&ServerApi>, - #[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient, - #[cfg(feature = "gssapi-auth")] resolver_config: Option<&ResolverConfig>, + opts: &AuthOptions, ) -> Result<()> { self.validate_credential(credential)?; + let server_api = opts.server_api.as_ref(); match self { AuthMechanism::ScramSha1 => { ScramVersion::Sha1 @@ -309,14 +308,20 @@ impl AuthMechanism { } #[cfg(feature = "gssapi-auth")] AuthMechanism::Gssapi => { - gssapi::authenticate_stream(stream, credential, server_api, resolver_config).await + gssapi::authenticate_stream( + stream, + credential, + server_api, + opts.resolver_config.as_ref(), + ) + .await } AuthMechanism::Plain => { plain::authenticate_stream(stream, credential, server_api).await } #[cfg(feature = "aws-auth")] AuthMechanism::MongoDbAws => { - aws::authenticate_stream(stream, credential, server_api, http_client).await + aws::authenticate_stream(stream, credential, server_api, &opts.http_client).await } AuthMechanism::MongoDbCr => Err(ErrorKind::Authentication { message: "MONGODB-CR is deprecated and not supported by this driver. Use SCRAM \ @@ -409,6 +414,28 @@ impl FromStr for AuthMechanism { } } +#[derive(Clone, Debug, Default)] +// Auxiliary information needed by authentication mechanisms. +pub(crate) struct AuthOptions { + server_api: Option, + #[cfg(feature = "aws-auth")] + http_client: crate::runtime::HttpClient, + #[cfg(feature = "gssapi-auth")] + resolver_config: Option, +} + +impl From<&ClientOptions> for AuthOptions { + fn from(opts: &ClientOptions) -> Self { + Self { + server_api: opts.server_api.clone(), + #[cfg(feature = "aws-auth")] + http_client: crate::runtime::HttpClient::default(), + #[cfg(feature = "gssapi-auth")] + resolver_config: opts.resolver_config.clone(), + } + } +} + /// A struct containing authentication information. /// /// Some fields (mechanism and source) may be omitted and will either be negotiated or assigned a @@ -495,10 +522,8 @@ impl Credential { pub(crate) async fn authenticate_stream( &self, conn: &mut Connection, - server_api: Option<&ServerApi>, first_round: Option, - #[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient, - #[cfg(feature = "gssapi-auth")] resolver_config: Option<&ResolverConfig>, + opts: &AuthOptions, ) -> Result<()> { let stream_description = conn.stream_description()?; @@ -510,6 +535,7 @@ impl Credential { // If speculative authentication returned a response, then short-circuit the authentication // logic and use the first round from the handshake. if let Some(first_round) = first_round { + let server_api = opts.server_api.as_ref(); return match first_round { FirstRound::Scram(version, first_round) => { version @@ -530,17 +556,7 @@ impl Credential { Some(ref m) => Cow::Borrowed(m), }; // Authenticate according to the chosen mechanism. - mechanism - .authenticate_stream( - conn, - self, - server_api, - #[cfg(feature = "aws-auth")] - http_client, - #[cfg(feature = "gssapi-auth")] - resolver_config, - ) - .await + mechanism.authenticate_stream(conn, self, opts).await } #[cfg(test)] diff --git a/src/cmap/establish.rs b/src/cmap/establish.rs index 9520ff13c..c8c1eee06 100644 --- a/src/cmap/establish.rs +++ b/src/cmap/establish.rs @@ -48,23 +48,10 @@ pub(crate) struct EstablisherOptions { pub(crate) test_patch_reply: Option)>, } -impl EstablisherOptions { - pub(crate) fn from_client_options(opts: &ClientOptions) -> Self { +impl From<&ClientOptions> for EstablisherOptions { + fn from(opts: &ClientOptions) -> Self { Self { - handshake_options: HandshakerOptions { - app_name: opts.app_name.clone(), - #[cfg(any( - feature = "zstd-compression", - feature = "zlib-compression", - feature = "snappy-compression" - ))] - compressors: opts.compressors.clone(), - driver_info: opts.driver_info.clone(), - server_api: opts.server_api.clone(), - load_balanced: opts.load_balanced.unwrap_or(false), - #[cfg(feature = "gssapi-auth")] - resolver_config: opts.resolver_config.clone(), - }, + handshake_options: HandshakerOptions::from(opts), tls_options: opts.tls_options(), connect_timeout: opts.connect_timeout, #[cfg(test)] diff --git a/src/cmap/establish/handshake.rs b/src/cmap/establish/handshake.rs index 53f118710..56e21a2c3 100644 --- a/src/cmap/establish/handshake.rs +++ b/src/cmap/establish/handshake.rs @@ -6,6 +6,7 @@ use std::env; use crate::{ bson::{rawdoc, RawBson, RawDocumentBuf}, bson_compat::cstr, + options::{AuthOptions, ClientOptions}, }; use once_cell::sync::Lazy; use tokio::sync::broadcast; @@ -16,8 +17,6 @@ use tokio::sync::broadcast; feature = "snappy-compression" ))] use crate::options::Compressor; -#[cfg(feature = "gssapi-auth")] -use crate::options::ResolverConfig; use crate::{ client::auth::ClientFirst, cmap::{Command, Connection, StreamDescription}, @@ -338,15 +337,9 @@ pub(crate) struct Handshaker { ))] compressors: Option>, - server_api: Option, - metadata: ClientMetadata, - #[cfg(feature = "aws-auth")] - http_client: crate::runtime::HttpClient, - - #[cfg(feature = "gssapi-auth")] - resolver_config: Option, + auth_options: AuthOptions, } #[cfg(test)] @@ -412,12 +405,8 @@ impl Handshaker { feature = "snappy-compression" ))] compressors: options.compressors, - server_api: options.server_api, metadata, - #[cfg(feature = "aws-auth")] - http_client: crate::runtime::HttpClient::default(), - #[cfg(feature = "gssapi-auth")] - resolver_config: options.resolver_config, + auth_options: options.auth_options, }) } @@ -499,15 +488,7 @@ impl Handshaker { if let Some(credential) = credential { credential - .authenticate_stream( - conn, - self.server_api.as_ref(), - first_round, - #[cfg(feature = "aws-auth")] - &self.http_client, - #[cfg(feature = "gssapi-auth")] - self.resolver_config.as_ref(), - ) + .authenticate_stream(conn, first_round, &self.auth_options) .await? } @@ -542,9 +523,26 @@ pub(crate) struct HandshakerOptions { /// Whether or not the client is connecting to a MongoDB cluster through a load balancer. pub(crate) load_balanced: bool, - /// Configuration of the DNS resolver used for hostname canonicalization for GSSAPI. - #[cfg(feature = "gssapi-auth")] - pub(crate) resolver_config: Option, + /// Auxiliary data for authentication mechanisms. + pub(crate) auth_options: AuthOptions, +} + +impl From<&ClientOptions> for HandshakerOptions { + fn from(opts: &ClientOptions) -> Self { + Self { + app_name: opts.app_name.clone(), + #[cfg(any( + feature = "zstd-compression", + feature = "zlib-compression", + feature = "snappy-compression" + ))] + compressors: opts.compressors.clone(), + driver_info: opts.driver_info.clone(), + server_api: opts.server_api.clone(), + load_balanced: opts.load_balanced.unwrap_or(false), + auth_options: AuthOptions::from(opts), + } + } } /// Updates the handshake command document with the speculative authentication info. diff --git a/src/cmap/establish/handshake/test.rs b/src/cmap/establish/handshake/test.rs index 846f38590..9816d0d90 100644 --- a/src/cmap/establish/handshake/test.rs +++ b/src/cmap/establish/handshake/test.rs @@ -1,6 +1,6 @@ use std::ops::Deref; -use crate::bson::rawdoc; +use crate::{bson::rawdoc, options::AuthOptions}; use super::Handshaker; use crate::{cmap::establish::handshake::HandshakerOptions, options::DriverInfo}; @@ -18,8 +18,7 @@ async fn metadata_no_options() { driver_info: None, server_api: None, load_balanced: false, - #[cfg(feature = "gssapi-auth")] - resolver_config: None, + auth_options: AuthOptions::default(), }) .unwrap(); @@ -68,8 +67,7 @@ async fn metadata_with_options() { compressors: None, server_api: None, load_balanced: false, - #[cfg(feature = "gssapi-auth")] - resolver_config: None, + auth_options: AuthOptions::default(), }; let handshaker = Handshaker::new(options).unwrap(); diff --git a/src/cmap/test.rs b/src/cmap/test.rs index 94d575c4d..7f50c6d23 100644 --- a/src/cmap/test.rs +++ b/src/cmap/test.rs @@ -162,10 +162,8 @@ impl Executor { let pool = ConnectionPool::new( get_client_options().await.hosts[0].clone(), - ConnectionEstablisher::new(EstablisherOptions::from_client_options( - get_client_options().await, - )) - .unwrap(), + ConnectionEstablisher::new(EstablisherOptions::from(get_client_options().await)) + .unwrap(), updater, crate::bson::oid::ObjectId::new(), Some(self.pool_options), diff --git a/src/cmap/test/integration.rs b/src/cmap/test/integration.rs index c4faeab65..94392305f 100644 --- a/src/cmap/test/integration.rs +++ b/src/cmap/test/integration.rs @@ -49,8 +49,7 @@ async fn acquire_connection_and_send_command() { let pool = ConnectionPool::new( client_options.hosts[0].clone(), - ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options)) - .unwrap(), + ConnectionEstablisher::new(EstablisherOptions::from(&client_options)).unwrap(), TopologyUpdater::channel().0, crate::bson::oid::ObjectId::new(), Some(pool_options), @@ -124,8 +123,7 @@ async fn concurrent_connections() { let pool = ConnectionPool::new( get_client_options().await.hosts[0].clone(), - ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options)) - .unwrap(), + ConnectionEstablisher::new(EstablisherOptions::from(&client_options)).unwrap(), TopologyUpdater::channel().0, crate::bson::oid::ObjectId::new(), Some(options), @@ -209,8 +207,7 @@ async fn connection_error_during_establishment() { options.cmap_event_handler = Some(buffer.handler()); let pool = ConnectionPool::new( client_options.hosts[0].clone(), - ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options)) - .unwrap(), + ConnectionEstablisher::new(EstablisherOptions::from(&client_options)).unwrap(), TopologyUpdater::channel().0, crate::bson::oid::ObjectId::new(), Some(options), diff --git a/src/sdam/topology.rs b/src/sdam/topology.rs index 7270d1daa..4de8cd61a 100644 --- a/src/sdam/topology.rs +++ b/src/sdam/topology.rs @@ -108,7 +108,7 @@ impl Topology { let (watcher, publisher) = TopologyWatcher::channel(state); let connection_establisher = - ConnectionEstablisher::new(EstablisherOptions::from_client_options(&options))?; + ConnectionEstablisher::new(EstablisherOptions::from(&options))?; let worker = TopologyWorker { id, diff --git a/src/test/spec/handshake.rs b/src/test/spec/handshake.rs index 11a071371..89d563f0b 100644 --- a/src/test/spec/handshake.rs +++ b/src/test/spec/handshake.rs @@ -15,7 +15,7 @@ use crate::{ #[tokio::test] async fn arbitrary_auth_mechanism() { let client_options = get_client_options().await; - let mut options = EstablisherOptions::from_client_options(client_options); + let mut options = EstablisherOptions::from(client_options); options.test_patch_reply = Some(|reply| { reply .as_mut()