Skip to content

Commit e56ca78

Browse files
authored
RUST-2247 Bundle extra arguments to auth mechanisms (#1428)
1 parent c205ff7 commit e56ca78

File tree

8 files changed

+73
-79
lines changed

8 files changed

+73
-79
lines changed

src/client/auth.rs

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ mod x509;
1616

1717
use std::{borrow::Cow, fmt::Debug, str::FromStr};
1818

19-
use crate::{bson::RawDocumentBuf, bson_compat::cstr};
19+
use crate::{bson::RawDocumentBuf, bson_compat::cstr, options::ClientOptions};
2020
use derive_where::derive_where;
2121
use hmac::{digest::KeyInit, Mac};
2222
use rand::Rng;
@@ -287,12 +287,11 @@ impl AuthMechanism {
287287
&self,
288288
stream: &mut Connection,
289289
credential: &Credential,
290-
server_api: Option<&ServerApi>,
291-
#[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
292-
#[cfg(feature = "gssapi-auth")] resolver_config: Option<&ResolverConfig>,
290+
opts: &AuthOptions,
293291
) -> Result<()> {
294292
self.validate_credential(credential)?;
295293

294+
let server_api = opts.server_api.as_ref();
296295
match self {
297296
AuthMechanism::ScramSha1 => {
298297
ScramVersion::Sha1
@@ -309,14 +308,20 @@ impl AuthMechanism {
309308
}
310309
#[cfg(feature = "gssapi-auth")]
311310
AuthMechanism::Gssapi => {
312-
gssapi::authenticate_stream(stream, credential, server_api, resolver_config).await
311+
gssapi::authenticate_stream(
312+
stream,
313+
credential,
314+
server_api,
315+
opts.resolver_config.as_ref(),
316+
)
317+
.await
313318
}
314319
AuthMechanism::Plain => {
315320
plain::authenticate_stream(stream, credential, server_api).await
316321
}
317322
#[cfg(feature = "aws-auth")]
318323
AuthMechanism::MongoDbAws => {
319-
aws::authenticate_stream(stream, credential, server_api, http_client).await
324+
aws::authenticate_stream(stream, credential, server_api, &opts.http_client).await
320325
}
321326
AuthMechanism::MongoDbCr => Err(ErrorKind::Authentication {
322327
message: "MONGODB-CR is deprecated and not supported by this driver. Use SCRAM \
@@ -409,6 +414,28 @@ impl FromStr for AuthMechanism {
409414
}
410415
}
411416

417+
#[derive(Clone, Debug, Default)]
418+
// Auxiliary information needed by authentication mechanisms.
419+
pub(crate) struct AuthOptions {
420+
server_api: Option<ServerApi>,
421+
#[cfg(feature = "aws-auth")]
422+
http_client: crate::runtime::HttpClient,
423+
#[cfg(feature = "gssapi-auth")]
424+
resolver_config: Option<ResolverConfig>,
425+
}
426+
427+
impl From<&ClientOptions> for AuthOptions {
428+
fn from(opts: &ClientOptions) -> Self {
429+
Self {
430+
server_api: opts.server_api.clone(),
431+
#[cfg(feature = "aws-auth")]
432+
http_client: crate::runtime::HttpClient::default(),
433+
#[cfg(feature = "gssapi-auth")]
434+
resolver_config: opts.resolver_config.clone(),
435+
}
436+
}
437+
}
438+
412439
/// A struct containing authentication information.
413440
///
414441
/// Some fields (mechanism and source) may be omitted and will either be negotiated or assigned a
@@ -495,10 +522,8 @@ impl Credential {
495522
pub(crate) async fn authenticate_stream(
496523
&self,
497524
conn: &mut Connection,
498-
server_api: Option<&ServerApi>,
499525
first_round: Option<FirstRound>,
500-
#[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
501-
#[cfg(feature = "gssapi-auth")] resolver_config: Option<&ResolverConfig>,
526+
opts: &AuthOptions,
502527
) -> Result<()> {
503528
let stream_description = conn.stream_description()?;
504529

@@ -510,6 +535,7 @@ impl Credential {
510535
// If speculative authentication returned a response, then short-circuit the authentication
511536
// logic and use the first round from the handshake.
512537
if let Some(first_round) = first_round {
538+
let server_api = opts.server_api.as_ref();
513539
return match first_round {
514540
FirstRound::Scram(version, first_round) => {
515541
version
@@ -530,17 +556,7 @@ impl Credential {
530556
Some(ref m) => Cow::Borrowed(m),
531557
};
532558
// Authenticate according to the chosen mechanism.
533-
mechanism
534-
.authenticate_stream(
535-
conn,
536-
self,
537-
server_api,
538-
#[cfg(feature = "aws-auth")]
539-
http_client,
540-
#[cfg(feature = "gssapi-auth")]
541-
resolver_config,
542-
)
543-
.await
559+
mechanism.authenticate_stream(conn, self, opts).await
544560
}
545561

546562
#[cfg(test)]

src/cmap/establish.rs

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,23 +48,10 @@ pub(crate) struct EstablisherOptions {
4848
pub(crate) test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
4949
}
5050

51-
impl EstablisherOptions {
52-
pub(crate) fn from_client_options(opts: &ClientOptions) -> Self {
51+
impl From<&ClientOptions> for EstablisherOptions {
52+
fn from(opts: &ClientOptions) -> Self {
5353
Self {
54-
handshake_options: HandshakerOptions {
55-
app_name: opts.app_name.clone(),
56-
#[cfg(any(
57-
feature = "zstd-compression",
58-
feature = "zlib-compression",
59-
feature = "snappy-compression"
60-
))]
61-
compressors: opts.compressors.clone(),
62-
driver_info: opts.driver_info.clone(),
63-
server_api: opts.server_api.clone(),
64-
load_balanced: opts.load_balanced.unwrap_or(false),
65-
#[cfg(feature = "gssapi-auth")]
66-
resolver_config: opts.resolver_config.clone(),
67-
},
54+
handshake_options: HandshakerOptions::from(opts),
6855
tls_options: opts.tls_options(),
6956
connect_timeout: opts.connect_timeout,
7057
#[cfg(test)]

src/cmap/establish/handshake.rs

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::env;
66
use crate::{
77
bson::{rawdoc, RawBson, RawDocumentBuf},
88
bson_compat::cstr,
9+
options::{AuthOptions, ClientOptions},
910
};
1011
use once_cell::sync::Lazy;
1112
use tokio::sync::broadcast;
@@ -16,8 +17,6 @@ use tokio::sync::broadcast;
1617
feature = "snappy-compression"
1718
))]
1819
use crate::options::Compressor;
19-
#[cfg(feature = "gssapi-auth")]
20-
use crate::options::ResolverConfig;
2120
use crate::{
2221
client::auth::ClientFirst,
2322
cmap::{Command, Connection, StreamDescription},
@@ -338,15 +337,9 @@ pub(crate) struct Handshaker {
338337
))]
339338
compressors: Option<Vec<Compressor>>,
340339

341-
server_api: Option<ServerApi>,
342-
343340
metadata: ClientMetadata,
344341

345-
#[cfg(feature = "aws-auth")]
346-
http_client: crate::runtime::HttpClient,
347-
348-
#[cfg(feature = "gssapi-auth")]
349-
resolver_config: Option<ResolverConfig>,
342+
auth_options: AuthOptions,
350343
}
351344

352345
#[cfg(test)]
@@ -412,12 +405,8 @@ impl Handshaker {
412405
feature = "snappy-compression"
413406
))]
414407
compressors: options.compressors,
415-
server_api: options.server_api,
416408
metadata,
417-
#[cfg(feature = "aws-auth")]
418-
http_client: crate::runtime::HttpClient::default(),
419-
#[cfg(feature = "gssapi-auth")]
420-
resolver_config: options.resolver_config,
409+
auth_options: options.auth_options,
421410
})
422411
}
423412

@@ -499,15 +488,7 @@ impl Handshaker {
499488

500489
if let Some(credential) = credential {
501490
credential
502-
.authenticate_stream(
503-
conn,
504-
self.server_api.as_ref(),
505-
first_round,
506-
#[cfg(feature = "aws-auth")]
507-
&self.http_client,
508-
#[cfg(feature = "gssapi-auth")]
509-
self.resolver_config.as_ref(),
510-
)
491+
.authenticate_stream(conn, first_round, &self.auth_options)
511492
.await?
512493
}
513494

@@ -542,9 +523,26 @@ pub(crate) struct HandshakerOptions {
542523
/// Whether or not the client is connecting to a MongoDB cluster through a load balancer.
543524
pub(crate) load_balanced: bool,
544525

545-
/// Configuration of the DNS resolver used for hostname canonicalization for GSSAPI.
546-
#[cfg(feature = "gssapi-auth")]
547-
pub(crate) resolver_config: Option<ResolverConfig>,
526+
/// Auxiliary data for authentication mechanisms.
527+
pub(crate) auth_options: AuthOptions,
528+
}
529+
530+
impl From<&ClientOptions> for HandshakerOptions {
531+
fn from(opts: &ClientOptions) -> Self {
532+
Self {
533+
app_name: opts.app_name.clone(),
534+
#[cfg(any(
535+
feature = "zstd-compression",
536+
feature = "zlib-compression",
537+
feature = "snappy-compression"
538+
))]
539+
compressors: opts.compressors.clone(),
540+
driver_info: opts.driver_info.clone(),
541+
server_api: opts.server_api.clone(),
542+
load_balanced: opts.load_balanced.unwrap_or(false),
543+
auth_options: AuthOptions::from(opts),
544+
}
545+
}
548546
}
549547

550548
/// Updates the handshake command document with the speculative authentication info.

src/cmap/establish/handshake/test.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::ops::Deref;
22

3-
use crate::bson::rawdoc;
3+
use crate::{bson::rawdoc, options::AuthOptions};
44

55
use super::Handshaker;
66
use crate::{cmap::establish::handshake::HandshakerOptions, options::DriverInfo};
@@ -18,8 +18,7 @@ async fn metadata_no_options() {
1818
driver_info: None,
1919
server_api: None,
2020
load_balanced: false,
21-
#[cfg(feature = "gssapi-auth")]
22-
resolver_config: None,
21+
auth_options: AuthOptions::default(),
2322
})
2423
.unwrap();
2524

@@ -68,8 +67,7 @@ async fn metadata_with_options() {
6867
compressors: None,
6968
server_api: None,
7069
load_balanced: false,
71-
#[cfg(feature = "gssapi-auth")]
72-
resolver_config: None,
70+
auth_options: AuthOptions::default(),
7371
};
7472

7573
let handshaker = Handshaker::new(options).unwrap();

src/cmap/test.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,8 @@ impl Executor {
162162

163163
let pool = ConnectionPool::new(
164164
get_client_options().await.hosts[0].clone(),
165-
ConnectionEstablisher::new(EstablisherOptions::from_client_options(
166-
get_client_options().await,
167-
))
168-
.unwrap(),
165+
ConnectionEstablisher::new(EstablisherOptions::from(get_client_options().await))
166+
.unwrap(),
169167
updater,
170168
crate::bson::oid::ObjectId::new(),
171169
Some(self.pool_options),

src/cmap/test/integration.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ async fn acquire_connection_and_send_command() {
4949

5050
let pool = ConnectionPool::new(
5151
client_options.hosts[0].clone(),
52-
ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options))
53-
.unwrap(),
52+
ConnectionEstablisher::new(EstablisherOptions::from(&client_options)).unwrap(),
5453
TopologyUpdater::channel().0,
5554
crate::bson::oid::ObjectId::new(),
5655
Some(pool_options),
@@ -124,8 +123,7 @@ async fn concurrent_connections() {
124123

125124
let pool = ConnectionPool::new(
126125
get_client_options().await.hosts[0].clone(),
127-
ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options))
128-
.unwrap(),
126+
ConnectionEstablisher::new(EstablisherOptions::from(&client_options)).unwrap(),
129127
TopologyUpdater::channel().0,
130128
crate::bson::oid::ObjectId::new(),
131129
Some(options),
@@ -209,8 +207,7 @@ async fn connection_error_during_establishment() {
209207
options.cmap_event_handler = Some(buffer.handler());
210208
let pool = ConnectionPool::new(
211209
client_options.hosts[0].clone(),
212-
ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options))
213-
.unwrap(),
210+
ConnectionEstablisher::new(EstablisherOptions::from(&client_options)).unwrap(),
214211
TopologyUpdater::channel().0,
215212
crate::bson::oid::ObjectId::new(),
216213
Some(options),

src/sdam/topology.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ impl Topology {
108108
let (watcher, publisher) = TopologyWatcher::channel(state);
109109

110110
let connection_establisher =
111-
ConnectionEstablisher::new(EstablisherOptions::from_client_options(&options))?;
111+
ConnectionEstablisher::new(EstablisherOptions::from(&options))?;
112112

113113
let worker = TopologyWorker {
114114
id,

src/test/spec/handshake.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::{
1515
#[tokio::test]
1616
async fn arbitrary_auth_mechanism() {
1717
let client_options = get_client_options().await;
18-
let mut options = EstablisherOptions::from_client_options(client_options);
18+
let mut options = EstablisherOptions::from(client_options);
1919
options.test_patch_reply = Some(|reply| {
2020
reply
2121
.as_mut()

0 commit comments

Comments
 (0)