Skip to content

RUST-2247 Bundle extra arguments to auth mechanisms #1428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 36 additions & 20 deletions src/client/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod x509;

use std::{borrow::Cow, fmt::Debug, str::FromStr};

use crate::{bson::RawDocumentBuf, bson_compat::cstr};
use crate::{bson::RawDocumentBuf, bson_compat::cstr, options::ClientOptions};
use derive_where::derive_where;
use hmac::{digest::KeyInit, Mac};
use rand::Rng;
Expand Down Expand Up @@ -287,12 +287,11 @@ impl AuthMechanism {
&self,
stream: &mut Connection,
credential: &Credential,
server_api: Option<&ServerApi>,
#[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
#[cfg(feature = "gssapi-auth")] resolver_config: Option<&ResolverConfig>,
opts: &AuthOptions,
) -> Result<()> {
self.validate_credential(credential)?;

let server_api = opts.server_api.as_ref();
match self {
AuthMechanism::ScramSha1 => {
ScramVersion::Sha1
Expand All @@ -309,14 +308,20 @@ impl AuthMechanism {
}
#[cfg(feature = "gssapi-auth")]
AuthMechanism::Gssapi => {
gssapi::authenticate_stream(stream, credential, server_api, resolver_config).await
gssapi::authenticate_stream(
stream,
credential,
server_api,
opts.resolver_config.as_ref(),
)
.await
}
AuthMechanism::Plain => {
plain::authenticate_stream(stream, credential, server_api).await
}
#[cfg(feature = "aws-auth")]
AuthMechanism::MongoDbAws => {
aws::authenticate_stream(stream, credential, server_api, http_client).await
aws::authenticate_stream(stream, credential, server_api, &opts.http_client).await
}
AuthMechanism::MongoDbCr => Err(ErrorKind::Authentication {
message: "MONGODB-CR is deprecated and not supported by this driver. Use SCRAM \
Expand Down Expand Up @@ -409,6 +414,28 @@ impl FromStr for AuthMechanism {
}
}

#[derive(Clone, Debug, Default)]
// Auxiliary information needed by authentication mechanisms.
pub(crate) struct AuthOptions {
server_api: Option<ServerApi>,
#[cfg(feature = "aws-auth")]
http_client: crate::runtime::HttpClient,
#[cfg(feature = "gssapi-auth")]
resolver_config: Option<ResolverConfig>,
}

impl From<&ClientOptions> for AuthOptions {
fn from(opts: &ClientOptions) -> Self {
Self {
server_api: opts.server_api.clone(),
#[cfg(feature = "aws-auth")]
http_client: crate::runtime::HttpClient::default(),
#[cfg(feature = "gssapi-auth")]
resolver_config: opts.resolver_config.clone(),
}
}
}

/// A struct containing authentication information.
///
/// Some fields (mechanism and source) may be omitted and will either be negotiated or assigned a
Expand Down Expand Up @@ -495,10 +522,8 @@ impl Credential {
pub(crate) async fn authenticate_stream(
&self,
conn: &mut Connection,
server_api: Option<&ServerApi>,
first_round: Option<FirstRound>,
#[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
#[cfg(feature = "gssapi-auth")] resolver_config: Option<&ResolverConfig>,
opts: &AuthOptions,
) -> Result<()> {
let stream_description = conn.stream_description()?;

Expand All @@ -510,6 +535,7 @@ impl Credential {
// If speculative authentication returned a response, then short-circuit the authentication
// logic and use the first round from the handshake.
if let Some(first_round) = first_round {
let server_api = opts.server_api.as_ref();
return match first_round {
FirstRound::Scram(version, first_round) => {
version
Expand All @@ -530,17 +556,7 @@ impl Credential {
Some(ref m) => Cow::Borrowed(m),
};
// Authenticate according to the chosen mechanism.
mechanism
.authenticate_stream(
conn,
self,
server_api,
#[cfg(feature = "aws-auth")]
http_client,
#[cfg(feature = "gssapi-auth")]
resolver_config,
)
.await
mechanism.authenticate_stream(conn, self, opts).await
}

#[cfg(test)]
Expand Down
19 changes: 3 additions & 16 deletions src/cmap/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,10 @@ pub(crate) struct EstablisherOptions {
pub(crate) test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
}

impl EstablisherOptions {
pub(crate) fn from_client_options(opts: &ClientOptions) -> Self {
impl From<&ClientOptions> for EstablisherOptions {
fn from(opts: &ClientOptions) -> Self {
Self {
handshake_options: HandshakerOptions {
app_name: opts.app_name.clone(),
#[cfg(any(
feature = "zstd-compression",
feature = "zlib-compression",
feature = "snappy-compression"
))]
compressors: opts.compressors.clone(),
driver_info: opts.driver_info.clone(),
server_api: opts.server_api.clone(),
load_balanced: opts.load_balanced.unwrap_or(false),
#[cfg(feature = "gssapi-auth")]
resolver_config: opts.resolver_config.clone(),
},
handshake_options: HandshakerOptions::from(opts),
tls_options: opts.tls_options(),
connect_timeout: opts.connect_timeout,
#[cfg(test)]
Expand Down
50 changes: 24 additions & 26 deletions src/cmap/establish/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::env;
use crate::{
bson::{rawdoc, RawBson, RawDocumentBuf},
bson_compat::cstr,
options::{AuthOptions, ClientOptions},
};
use once_cell::sync::Lazy;
use tokio::sync::broadcast;
Expand All @@ -16,8 +17,6 @@ use tokio::sync::broadcast;
feature = "snappy-compression"
))]
use crate::options::Compressor;
#[cfg(feature = "gssapi-auth")]
use crate::options::ResolverConfig;
use crate::{
client::auth::ClientFirst,
cmap::{Command, Connection, StreamDescription},
Expand Down Expand Up @@ -338,15 +337,9 @@ pub(crate) struct Handshaker {
))]
compressors: Option<Vec<Compressor>>,

server_api: Option<ServerApi>,

metadata: ClientMetadata,

#[cfg(feature = "aws-auth")]
http_client: crate::runtime::HttpClient,

#[cfg(feature = "gssapi-auth")]
resolver_config: Option<ResolverConfig>,
auth_options: AuthOptions,
}

#[cfg(test)]
Expand Down Expand Up @@ -412,12 +405,8 @@ impl Handshaker {
feature = "snappy-compression"
))]
compressors: options.compressors,
server_api: options.server_api,
metadata,
#[cfg(feature = "aws-auth")]
http_client: crate::runtime::HttpClient::default(),
#[cfg(feature = "gssapi-auth")]
resolver_config: options.resolver_config,
auth_options: options.auth_options,
})
}

Expand Down Expand Up @@ -499,15 +488,7 @@ impl Handshaker {

if let Some(credential) = credential {
credential
.authenticate_stream(
conn,
self.server_api.as_ref(),
first_round,
#[cfg(feature = "aws-auth")]
&self.http_client,
#[cfg(feature = "gssapi-auth")]
self.resolver_config.as_ref(),
)
.authenticate_stream(conn, first_round, &self.auth_options)
.await?
}

Expand Down Expand Up @@ -542,9 +523,26 @@ pub(crate) struct HandshakerOptions {
/// Whether or not the client is connecting to a MongoDB cluster through a load balancer.
pub(crate) load_balanced: bool,

/// Configuration of the DNS resolver used for hostname canonicalization for GSSAPI.
#[cfg(feature = "gssapi-auth")]
pub(crate) resolver_config: Option<ResolverConfig>,
/// Auxiliary data for authentication mechanisms.
pub(crate) auth_options: AuthOptions,
}

impl From<&ClientOptions> for HandshakerOptions {
fn from(opts: &ClientOptions) -> Self {
Self {
app_name: opts.app_name.clone(),
#[cfg(any(
feature = "zstd-compression",
feature = "zlib-compression",
feature = "snappy-compression"
))]
compressors: opts.compressors.clone(),
driver_info: opts.driver_info.clone(),
server_api: opts.server_api.clone(),
load_balanced: opts.load_balanced.unwrap_or(false),
auth_options: AuthOptions::from(opts),
}
}
}

/// Updates the handshake command document with the speculative authentication info.
Expand Down
8 changes: 3 additions & 5 deletions src/cmap/establish/handshake/test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::ops::Deref;

use crate::bson::rawdoc;
use crate::{bson::rawdoc, options::AuthOptions};

use super::Handshaker;
use crate::{cmap::establish::handshake::HandshakerOptions, options::DriverInfo};
Expand All @@ -18,8 +18,7 @@ async fn metadata_no_options() {
driver_info: None,
server_api: None,
load_balanced: false,
#[cfg(feature = "gssapi-auth")]
resolver_config: None,
auth_options: AuthOptions::default(),
})
.unwrap();

Expand Down Expand Up @@ -68,8 +67,7 @@ async fn metadata_with_options() {
compressors: None,
server_api: None,
load_balanced: false,
#[cfg(feature = "gssapi-auth")]
resolver_config: None,
auth_options: AuthOptions::default(),
};

let handshaker = Handshaker::new(options).unwrap();
Expand Down
6 changes: 2 additions & 4 deletions src/cmap/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,8 @@ impl Executor {

let pool = ConnectionPool::new(
get_client_options().await.hosts[0].clone(),
ConnectionEstablisher::new(EstablisherOptions::from_client_options(
get_client_options().await,
))
.unwrap(),
ConnectionEstablisher::new(EstablisherOptions::from(get_client_options().await))
.unwrap(),
updater,
crate::bson::oid::ObjectId::new(),
Some(self.pool_options),
Expand Down
9 changes: 3 additions & 6 deletions src/cmap/test/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ async fn acquire_connection_and_send_command() {

let pool = ConnectionPool::new(
client_options.hosts[0].clone(),
ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options))
.unwrap(),
ConnectionEstablisher::new(EstablisherOptions::from(&client_options)).unwrap(),
TopologyUpdater::channel().0,
crate::bson::oid::ObjectId::new(),
Some(pool_options),
Expand Down Expand Up @@ -124,8 +123,7 @@ async fn concurrent_connections() {

let pool = ConnectionPool::new(
get_client_options().await.hosts[0].clone(),
ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options))
.unwrap(),
ConnectionEstablisher::new(EstablisherOptions::from(&client_options)).unwrap(),
TopologyUpdater::channel().0,
crate::bson::oid::ObjectId::new(),
Some(options),
Expand Down Expand Up @@ -209,8 +207,7 @@ async fn connection_error_during_establishment() {
options.cmap_event_handler = Some(buffer.handler());
let pool = ConnectionPool::new(
client_options.hosts[0].clone(),
ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options))
.unwrap(),
ConnectionEstablisher::new(EstablisherOptions::from(&client_options)).unwrap(),
TopologyUpdater::channel().0,
crate::bson::oid::ObjectId::new(),
Some(options),
Expand Down
2 changes: 1 addition & 1 deletion src/sdam/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl Topology {
let (watcher, publisher) = TopologyWatcher::channel(state);

let connection_establisher =
ConnectionEstablisher::new(EstablisherOptions::from_client_options(&options))?;
ConnectionEstablisher::new(EstablisherOptions::from(&options))?;

let worker = TopologyWorker {
id,
Expand Down
2 changes: 1 addition & 1 deletion src/test/spec/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
#[tokio::test]
async fn arbitrary_auth_mechanism() {
let client_options = get_client_options().await;
let mut options = EstablisherOptions::from_client_options(client_options);
let mut options = EstablisherOptions::from(client_options);
options.test_patch_reply = Some(|reply| {
reply
.as_mut()
Expand Down