diff --git a/src/io/mod.rs b/src/io/mod.rs index 273acf1e..9a879ec6 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -31,6 +31,7 @@ use std::{ ErrorKind::{BrokenPipe, NotConnected, Other}, }, mem::replace, + net::SocketAddr, ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, @@ -357,9 +358,20 @@ impl Stream { keepalive: Option, ) -> io::Result { let tcp_stream = match addr { - HostPortOrUrl::HostPort(host, port) => { - TcpStream::connect((host.as_str(), *port)).await? - } + HostPortOrUrl::HostPort { + host, + port, + resolved_ips, + } => match resolved_ips { + Some(ips) => { + let addrs = ips + .iter() + .map(|ip| SocketAddr::new(*ip, *port)) + .collect::>(); + TcpStream::connect(&*addrs).await? + } + None => TcpStream::connect((host.as_str(), *port)).await?, + }, HostPortOrUrl::Url(url) => { let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?; TcpStream::connect(&*addrs).await? diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 1f62e136..e9044450 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -22,7 +22,7 @@ use url::{Host, Url}; use std::{ borrow::Cow, fmt, io, - net::{Ipv4Addr, Ipv6Addr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::{Path, PathBuf}, str::FromStr, sync::Arc, @@ -66,37 +66,61 @@ pub const DEFAULT_TTL_CHECK_INTERVAL: Duration = Duration::from_secs(30); /// into socket addresses using to_socket_addrs. #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) enum HostPortOrUrl { - HostPort(String, u16), + HostPort { + host: String, + port: u16, + /// The resolved IP addresses to use for the TCP connection. If empty, + /// DNS resolution of `host` will be performed. + resolved_ips: Option>, + }, Url(Url), } impl Default for HostPortOrUrl { fn default() -> Self { - HostPortOrUrl::HostPort("127.0.0.1".to_string(), DEFAULT_PORT) + HostPortOrUrl::HostPort { + host: "127.0.0.1".to_string(), + port: DEFAULT_PORT, + resolved_ips: None, + } } } impl HostPortOrUrl { pub fn get_ip_or_hostname(&self) -> &str { match self { - Self::HostPort(host, _) => host, + Self::HostPort { host, .. } => host, Self::Url(url) => url.host_str().unwrap_or("127.0.0.1"), } } pub fn get_tcp_port(&self) -> u16 { match self { - Self::HostPort(_, port) => *port, + Self::HostPort { port, .. } => *port, Self::Url(url) => url.port().unwrap_or(DEFAULT_PORT), } } + pub fn get_resolved_ips(&self) -> &Option> { + match self { + Self::HostPort { resolved_ips, .. } => resolved_ips, + Self::Url(_) => &None, + } + } + pub fn is_loopback(&self) -> bool { match self { - Self::HostPort(host, _) => { + Self::HostPort { + host, resolved_ips, .. + } => { let v4addr: Option = FromStr::from_str(host).ok(); let v6addr: Option = FromStr::from_str(host).ok(); - if let Some(addr) = v4addr { + if resolved_ips + .as_ref() + .is_some_and(|s| s.iter().any(|ip| ip.is_loopback())) + { + true + } else if let Some(addr) = v4addr { addr.is_loopback() } else if let Some(addr) = v6addr { addr.is_loopback() @@ -644,6 +668,11 @@ impl Opts { self.inner.address.get_tcp_port() } + /// The resolved IPs for the mysql server, if provided. + pub fn resolved_ips(&self) -> &Option> { + self.inner.address.get_resolved_ips() + } + /// User (defaults to `None`). /// /// # Connection URL @@ -1139,6 +1168,7 @@ pub struct OptsBuilder { opts: MysqlOpts, ip_or_hostname: String, tcp_port: u16, + resolved_ips: Option>, } impl Default for OptsBuilder { @@ -1148,6 +1178,7 @@ impl Default for OptsBuilder { opts: MysqlOpts::default(), ip_or_hostname: address.get_ip_or_hostname().into(), tcp_port: address.get_tcp_port(), + resolved_ips: None, } } } @@ -1168,6 +1199,7 @@ impl OptsBuilder { OptsBuilder { tcp_port: opts.inner.address.get_tcp_port(), ip_or_hostname: opts.inner.address.get_ip_or_hostname().to_string(), + resolved_ips: opts.inner.address.get_resolved_ips().clone(), opts: opts.inner.mysql_opts.clone(), } } @@ -1184,6 +1216,14 @@ impl OptsBuilder { self } + /// Defines already-resolved IPs to use for the connection. When provided + /// the connection will not perform DNS resolution and the hostname will be + /// used only for TLS identity verification purposes. + pub fn resolved_ips>>(mut self, ips: Option) -> Self { + self.resolved_ips = ips.map(Into::into); + self + } + /// Defines user name. See [`Opts::user`]. pub fn user>(mut self, user: Option) -> Self { self.opts.user = user.map(Into::into); @@ -1349,7 +1389,11 @@ impl OptsBuilder { impl From for Opts { fn from(builder: OptsBuilder) -> Opts { - let address = HostPortOrUrl::HostPort(builder.ip_or_hostname, builder.tcp_port); + let address = HostPortOrUrl::HostPort { + host: builder.ip_or_hostname, + port: builder.tcp_port, + resolved_ips: builder.resolved_ips, + }; let inner_opts = InnerOpts { mysql_opts: builder.opts, address, @@ -1838,7 +1882,7 @@ mod test { use super::{HostPortOrUrl, MysqlOpts, Opts, Url}; use crate::{error::UrlError::InvalidParamValue, SslOpts}; - use std::str::FromStr; + use std::{net::IpAddr, net::Ipv4Addr, net::Ipv6Addr, str::FromStr}; #[test] fn test_builder_eq_url() { @@ -2019,4 +2063,44 @@ mod test { let url_opts = super::Opts::from_str(url).unwrap(); assert_eq!(url_opts.db_name(), builder_opts.db_name()); } + + #[test] + fn test_builder_update_port_host_resolved_ips() { + let builder = super::OptsBuilder::default() + .ip_or_hostname("foo") + .tcp_port(33306); + + let resolved = vec![ + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 7)), + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff)), + ]; + let builder2 = builder + .clone() + .tcp_port(55223) + .resolved_ips(Some(resolved.clone())); + + let builder_opts = Opts::from(builder); + assert_eq!(builder_opts.ip_or_hostname(), "foo"); + assert_eq!(builder_opts.tcp_port(), 33306); + assert_eq!( + builder_opts.hostport_or_url(), + &HostPortOrUrl::HostPort { + host: "foo".to_string(), + port: 33306, + resolved_ips: None + } + ); + + let builder_opts2 = Opts::from(builder2); + assert_eq!(builder_opts2.ip_or_hostname(), "foo"); + assert_eq!(builder_opts2.tcp_port(), 55223); + assert_eq!( + builder_opts2.hostport_or_url(), + &HostPortOrUrl::HostPort { + host: "foo".to_string(), + port: 55223, + resolved_ips: Some(resolved), + } + ); + } }