Skip to content

Add support to specify pre-resolved IP addresses and avoid additional DNS lookup #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use std::{
ErrorKind::{BrokenPipe, NotConnected, Other},
},
mem::replace,
net::SocketAddr,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
Expand Down Expand Up @@ -357,9 +358,20 @@ impl Stream {
keepalive: Option<Duration>,
) -> io::Result<Stream> {
let tcp_stream = match addr {
HostPortOrUrl::HostPort(host, port) => {
TcpStream::connect((host.as_str(), *port)).await?
}
HostPortOrUrl::HostPort {
host,
port,
resolved_ips,
} => match resolved_ips {
Some(ips) => {
let addrs = ips
.iter()
.map(|ip| SocketAddr::new(*ip, *port))
.collect::<Vec<_>>();
TcpStream::connect(&*addrs).await?
}
None => TcpStream::connect((host.as_str(), *port)).await?,
},
HostPortOrUrl::Url(url) => {
let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?;
TcpStream::connect(&*addrs).await?
Expand Down
102 changes: 93 additions & 9 deletions src/opts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use url::{Host, Url};
use std::{
borrow::Cow,
fmt, io,
net::{Ipv4Addr, Ipv6Addr},
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
Expand Down Expand Up @@ -66,37 +66,61 @@ pub const DEFAULT_TTL_CHECK_INTERVAL: Duration = Duration::from_secs(30);
/// into socket addresses using to_socket_addrs.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) enum HostPortOrUrl {
HostPort(String, u16),
HostPort {
host: String,
port: u16,
/// The resolved IP addresses to use for the TCP connection. If empty,
/// DNS resolution of `host` will be performed.
resolved_ips: Option<Vec<IpAddr>>,
},
Url(Url),
}

impl Default for HostPortOrUrl {
fn default() -> Self {
HostPortOrUrl::HostPort("127.0.0.1".to_string(), DEFAULT_PORT)
HostPortOrUrl::HostPort {
host: "127.0.0.1".to_string(),
port: DEFAULT_PORT,
resolved_ips: None,
}
}
}

impl HostPortOrUrl {
pub fn get_ip_or_hostname(&self) -> &str {
match self {
Self::HostPort(host, _) => host,
Self::HostPort { host, .. } => host,
Self::Url(url) => url.host_str().unwrap_or("127.0.0.1"),
}
}

pub fn get_tcp_port(&self) -> u16 {
match self {
Self::HostPort(_, port) => *port,
Self::HostPort { port, .. } => *port,
Self::Url(url) => url.port().unwrap_or(DEFAULT_PORT),
}
}

pub fn get_resolved_ips(&self) -> &Option<Vec<IpAddr>> {
match self {
Self::HostPort { resolved_ips, .. } => resolved_ips,
Self::Url(_) => &None,
}
}

pub fn is_loopback(&self) -> bool {
match self {
Self::HostPort(host, _) => {
Self::HostPort {
host, resolved_ips, ..
} => {
let v4addr: Option<Ipv4Addr> = FromStr::from_str(host).ok();
let v6addr: Option<Ipv6Addr> = FromStr::from_str(host).ok();
if let Some(addr) = v4addr {
if resolved_ips
.as_ref()
.is_some_and(|s| s.iter().any(|ip| ip.is_loopback()))
{
true
} else if let Some(addr) = v4addr {
addr.is_loopback()
} else if let Some(addr) = v6addr {
addr.is_loopback()
Expand Down Expand Up @@ -644,6 +668,11 @@ impl Opts {
self.inner.address.get_tcp_port()
}

/// The resolved IPs for the mysql server, if provided.
pub fn resolved_ips(&self) -> &Option<Vec<IpAddr>> {
self.inner.address.get_resolved_ips()
}

/// User (defaults to `None`).
///
/// # Connection URL
Expand Down Expand Up @@ -1139,6 +1168,7 @@ pub struct OptsBuilder {
opts: MysqlOpts,
ip_or_hostname: String,
tcp_port: u16,
resolved_ips: Option<Vec<IpAddr>>,
}

impl Default for OptsBuilder {
Expand All @@ -1148,6 +1178,7 @@ impl Default for OptsBuilder {
opts: MysqlOpts::default(),
ip_or_hostname: address.get_ip_or_hostname().into(),
tcp_port: address.get_tcp_port(),
resolved_ips: None,
}
}
}
Expand All @@ -1168,6 +1199,7 @@ impl OptsBuilder {
OptsBuilder {
tcp_port: opts.inner.address.get_tcp_port(),
ip_or_hostname: opts.inner.address.get_ip_or_hostname().to_string(),
resolved_ips: opts.inner.address.get_resolved_ips().clone(),
opts: opts.inner.mysql_opts.clone(),
}
}
Expand All @@ -1184,6 +1216,14 @@ impl OptsBuilder {
self
}

/// Defines already-resolved IPs to use for the connection. When provided
/// the connection will not perform DNS resolution and the hostname will be
/// used only for TLS identity verification purposes.
pub fn resolved_ips<T: Into<Vec<IpAddr>>>(mut self, ips: Option<T>) -> Self {
self.resolved_ips = ips.map(Into::into);
self
}

/// Defines user name. See [`Opts::user`].
pub fn user<T: Into<String>>(mut self, user: Option<T>) -> Self {
self.opts.user = user.map(Into::into);
Expand Down Expand Up @@ -1349,7 +1389,11 @@ impl OptsBuilder {

impl From<OptsBuilder> for Opts {
fn from(builder: OptsBuilder) -> Opts {
let address = HostPortOrUrl::HostPort(builder.ip_or_hostname, builder.tcp_port);
let address = HostPortOrUrl::HostPort {
host: builder.ip_or_hostname,
port: builder.tcp_port,
resolved_ips: builder.resolved_ips,
};
let inner_opts = InnerOpts {
mysql_opts: builder.opts,
address,
Expand Down Expand Up @@ -1838,7 +1882,7 @@ mod test {
use super::{HostPortOrUrl, MysqlOpts, Opts, Url};
use crate::{error::UrlError::InvalidParamValue, SslOpts};

use std::str::FromStr;
use std::{net::IpAddr, net::Ipv4Addr, net::Ipv6Addr, str::FromStr};

#[test]
fn test_builder_eq_url() {
Expand Down Expand Up @@ -2019,4 +2063,44 @@ mod test {
let url_opts = super::Opts::from_str(url).unwrap();
assert_eq!(url_opts.db_name(), builder_opts.db_name());
}

#[test]
fn test_builder_update_port_host_resolved_ips() {
let builder = super::OptsBuilder::default()
.ip_or_hostname("foo")
.tcp_port(33306);

let resolved = vec![
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 7)),
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff)),
];
let builder2 = builder
.clone()
.tcp_port(55223)
.resolved_ips(Some(resolved.clone()));

let builder_opts = Opts::from(builder);
assert_eq!(builder_opts.ip_or_hostname(), "foo");
assert_eq!(builder_opts.tcp_port(), 33306);
assert_eq!(
builder_opts.hostport_or_url(),
&HostPortOrUrl::HostPort {
host: "foo".to_string(),
port: 33306,
resolved_ips: None
}
);

let builder_opts2 = Opts::from(builder2);
assert_eq!(builder_opts2.ip_or_hostname(), "foo");
assert_eq!(builder_opts2.tcp_port(), 55223);
assert_eq!(
builder_opts2.hostport_or_url(),
&HostPortOrUrl::HostPort {
host: "foo".to_string(),
port: 55223,
resolved_ips: Some(resolved),
}
);
}
}