Skip to content

Change signature of get_many_mut APIs #562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 103 additions & 51 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1467,8 +1467,11 @@ where
/// Attempts to get mutable references to `N` values in the map at once.
///
/// Returns an array of length `N` with the results of each query. For soundness, at most one
/// mutable reference will be returned to any value. `None` will be returned if any of the
/// keys are duplicates or missing.
/// mutable reference will be returned to any value. `None` will be used if the key is missing.
///
/// # Panics
///
/// Panics if any keys are overlapping.
///
/// # Examples
///
Expand All @@ -1481,33 +1484,52 @@ where
/// libraries.insert("Herzogin-Anna-Amalia-Bibliothek".to_string(), 1691);
/// libraries.insert("Library of Congress".to_string(), 1800);
///
/// // Get Athenæum and Bodleian Library
/// let [Some(a), Some(b)] = libraries.get_many_mut([
/// "Athenæum",
/// "Bodleian Library",
/// ]) else { panic!() };
///
/// // Assert values of Athenæum and Library of Congress
/// let got = libraries.get_many_mut([
/// "Athenæum",
/// "Library of Congress",
/// ]);
/// assert_eq!(
/// got,
/// Some([
/// &mut 1807,
/// &mut 1800,
/// ]),
/// [
/// Some(&mut 1807),
/// Some(&mut 1800),
/// ],
/// );
///
/// // Missing keys result in None
/// let got = libraries.get_many_mut([
/// "Athenæum",
/// "New York Public Library",
/// ]);
/// assert_eq!(got, None);
/// assert_eq!(
/// got,
/// [
/// Some(&mut 1807),
/// None
/// ]
/// );
/// ```
///
/// ```should_panic
/// use hashbrown::HashMap;
///
/// // Duplicate keys result in None
/// let mut libraries = HashMap::new();
/// libraries.insert("Athenæum".to_string(), 1807);
///
/// // Duplicate keys panic!
/// let got = libraries.get_many_mut([
/// "Athenæum",
/// "Athenæum",
/// ]);
/// assert_eq!(got, None);
/// ```
pub fn get_many_mut<Q, const N: usize>(&mut self, ks: [&Q; N]) -> Option<[&'_ mut V; N]>
pub fn get_many_mut<Q, const N: usize>(&mut self, ks: [&Q; N]) -> [Option<&'_ mut V>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand All @@ -1517,8 +1539,8 @@ where
/// Attempts to get mutable references to `N` values in the map at once, without validating that
/// the values are unique.
///
/// Returns an array of length `N` with the results of each query. `None` will be returned if
/// any of the keys are missing.
/// Returns an array of length `N` with the results of each query. `None` will be used if
/// the key is missing.
///
/// For a safe alternative see [`get_many_mut`](`HashMap::get_many_mut`).
///
Expand All @@ -1540,29 +1562,37 @@ where
/// libraries.insert("Herzogin-Anna-Amalia-Bibliothek".to_string(), 1691);
/// libraries.insert("Library of Congress".to_string(), 1800);
///
/// let got = libraries.get_many_mut([
/// // SAFETY: The keys do not overlap.
/// let [Some(a), Some(b)] = (unsafe { libraries.get_many_unchecked_mut([
/// "Athenæum",
/// "Bodleian Library",
/// ]) }) else { panic!() };
///
/// // SAFETY: The keys do not overlap.
/// let got = unsafe { libraries.get_many_unchecked_mut([
/// "Athenæum",
/// "Library of Congress",
/// ]);
/// ]) };
/// assert_eq!(
/// got,
/// Some([
/// &mut 1807,
/// &mut 1800,
/// ]),
/// [
/// Some(&mut 1807),
/// Some(&mut 1800),
/// ],
/// );
///
/// // Missing keys result in None
/// let got = libraries.get_many_mut([
/// // SAFETY: The keys do not overlap.
/// let got = unsafe { libraries.get_many_unchecked_mut([
/// "Athenæum",
/// "New York Public Library",
/// ]);
/// assert_eq!(got, None);
/// ]) };
/// // Missing keys result in None
/// assert_eq!(got, [Some(&mut 1807), None]);
/// ```
pub unsafe fn get_many_unchecked_mut<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[&'_ mut V; N]>
) -> [Option<&'_ mut V>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand All @@ -1574,8 +1604,11 @@ where
/// references to the corresponding keys.
///
/// Returns an array of length `N` with the results of each query. For soundness, at most one
/// mutable reference will be returned to any value. `None` will be returned if any of the keys
/// are duplicates or missing.
/// mutable reference will be returned to any value. `None` will be used if the key is missing.
///
/// # Panics
///
/// Panics if any keys are overlapping.
///
/// # Examples
///
Expand All @@ -1594,30 +1627,37 @@ where
/// ]);
/// assert_eq!(
/// got,
/// Some([
/// (&"Bodleian Library".to_string(), &mut 1602),
/// (&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691),
/// ]),
/// [
/// Some((&"Bodleian Library".to_string(), &mut 1602)),
/// Some((&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691)),
/// ],
/// );
/// // Missing keys result in None
/// let got = libraries.get_many_key_value_mut([
/// "Bodleian Library",
/// "Gewandhaus",
/// ]);
/// assert_eq!(got, None);
/// assert_eq!(got, [Some((&"Bodleian Library".to_string(), &mut 1602)), None]);
/// ```
///
/// ```should_panic
/// use hashbrown::HashMap;
///
/// let mut libraries = HashMap::new();
/// libraries.insert("Bodleian Library".to_string(), 1602);
/// libraries.insert("Herzogin-Anna-Amalia-Bibliothek".to_string(), 1691);
///
/// // Duplicate keys result in None
/// // Duplicate keys result in panic!
/// let got = libraries.get_many_key_value_mut([
/// "Bodleian Library",
/// "Herzogin-Anna-Amalia-Bibliothek",
/// "Herzogin-Anna-Amalia-Bibliothek",
/// ]);
/// assert_eq!(got, None);
/// ```
pub fn get_many_key_value_mut<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[(&'_ K, &'_ mut V); N]>
) -> [Option<(&'_ K, &'_ mut V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand Down Expand Up @@ -1657,30 +1697,36 @@ where
/// ]);
/// assert_eq!(
/// got,
/// Some([
/// (&"Bodleian Library".to_string(), &mut 1602),
/// (&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691),
/// ]),
/// [
/// Some((&"Bodleian Library".to_string(), &mut 1602)),
/// Some((&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691)),
/// ],
/// );
/// // Missing keys result in None
/// let got = libraries.get_many_key_value_mut([
/// "Bodleian Library",
/// "Gewandhaus",
/// ]);
/// assert_eq!(got, None);
/// assert_eq!(
/// got,
/// [
/// Some((&"Bodleian Library".to_string(), &mut 1602)),
/// None,
/// ],
/// );
/// ```
pub unsafe fn get_many_key_value_unchecked_mut<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[(&'_ K, &'_ mut V); N]>
) -> [Option<(&'_ K, &'_ mut V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
self.get_many_unchecked_mut_inner(ks)
.map(|res| res.map(|(k, v)| (&*k, v)))
}

fn get_many_mut_inner<Q, const N: usize>(&mut self, ks: [&Q; N]) -> Option<[&'_ mut (K, V); N]>
fn get_many_mut_inner<Q, const N: usize>(&mut self, ks: [&Q; N]) -> [Option<&'_ mut (K, V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand All @@ -1692,7 +1738,7 @@ where
unsafe fn get_many_unchecked_mut_inner<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[&'_ mut (K, V); N]>
) -> [Option<&'_ mut (K, V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand Down Expand Up @@ -5937,33 +5983,39 @@ mod test_map {
}

#[test]
fn test_get_each_mut() {
fn test_get_many_mut() {
let mut map = HashMap::new();
map.insert("foo".to_owned(), 0);
map.insert("bar".to_owned(), 10);
map.insert("baz".to_owned(), 20);
map.insert("qux".to_owned(), 30);

let xs = map.get_many_mut(["foo", "qux"]);
assert_eq!(xs, Some([&mut 0, &mut 30]));
assert_eq!(xs, [Some(&mut 0), Some(&mut 30)]);

let xs = map.get_many_mut(["foo", "dud"]);
assert_eq!(xs, None);

let xs = map.get_many_mut(["foo", "foo"]);
assert_eq!(xs, None);
assert_eq!(xs, [Some(&mut 0), None]);

let ys = map.get_many_key_value_mut(["bar", "baz"]);
assert_eq!(
ys,
Some([(&"bar".to_owned(), &mut 10), (&"baz".to_owned(), &mut 20),]),
[
Some((&"bar".to_owned(), &mut 10)),
Some((&"baz".to_owned(), &mut 20))
],
);

let ys = map.get_many_key_value_mut(["bar", "dip"]);
assert_eq!(ys, None);
assert_eq!(ys, [Some((&"bar".to_string(), &mut 10)), None]);
}

#[test]
#[should_panic = "duplicate keys found"]
fn test_get_many_mut_duplicate() {
let mut map = HashMap::new();
map.insert("foo".to_owned(), 0);

let ys = map.get_many_key_value_mut(["baz", "baz"]);
assert_eq!(ys, None);
let _xs = map.get_many_mut(["foo", "foo"]);
}

#[test]
Expand Down
45 changes: 22 additions & 23 deletions src/raw/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::alloc::alloc::{handle_alloc_error, Layout};
use crate::scopeguard::{guard, ScopeGuard};
use crate::TryReserveError;
use core::array;
use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem;
use core::mem::MaybeUninit;
use core::ptr::NonNull;
use core::{hint, ptr};

Expand Down Expand Up @@ -484,6 +484,13 @@ impl<T> Bucket<T> {
}
}

/// Acquires the underlying non-null pointer `*mut T` to `data`.
#[inline]
fn as_non_null(&self) -> NonNull<T> {
// SAFETY: `self.ptr` is already a `NonNull`
unsafe { NonNull::new_unchecked(self.as_ptr()) }
}

/// Create a new [`Bucket`] that is offset from the `self` by the given
/// `offset`. The pointer calculation is performed by calculating the
/// offset from `self` pointer (convenience for `self.ptr.as_ptr().sub(offset)`).
Expand Down Expand Up @@ -1291,48 +1298,40 @@ impl<T, A: Allocator> RawTable<T, A> {
&mut self,
hashes: [u64; N],
eq: impl FnMut(usize, &T) -> bool,
) -> Option<[&'_ mut T; N]> {
) -> [Option<&'_ mut T>; N] {
unsafe {
let ptrs = self.get_many_mut_pointers(hashes, eq)?;
let ptrs = self.get_many_mut_pointers(hashes, eq);

for (i, &cur) in ptrs.iter().enumerate() {
if ptrs[..i].iter().any(|&prev| ptr::eq::<T>(prev, cur)) {
return None;
for (i, cur) in ptrs.iter().enumerate() {
if cur.is_some() && ptrs[..i].contains(cur) {
panic!("duplicate keys found");
}
}
// All bucket are distinct from all previous buckets so we're clear to return the result
// of the lookup.

// TODO use `MaybeUninit::array_assume_init` here instead once that's stable.
Some(mem::transmute_copy(&ptrs))
ptrs.map(|ptr| ptr.map(|mut ptr| ptr.as_mut()))
}
}

pub unsafe fn get_many_unchecked_mut<const N: usize>(
&mut self,
hashes: [u64; N],
eq: impl FnMut(usize, &T) -> bool,
) -> Option<[&'_ mut T; N]> {
let ptrs = self.get_many_mut_pointers(hashes, eq)?;
Some(mem::transmute_copy(&ptrs))
) -> [Option<&'_ mut T>; N] {
let ptrs = self.get_many_mut_pointers(hashes, eq);
ptrs.map(|ptr| ptr.map(|mut ptr| ptr.as_mut()))
}

unsafe fn get_many_mut_pointers<const N: usize>(
&mut self,
hashes: [u64; N],
mut eq: impl FnMut(usize, &T) -> bool,
) -> Option<[*mut T; N]> {
// TODO use `MaybeUninit::uninit_array` here instead once that's stable.
let mut outs: MaybeUninit<[*mut T; N]> = MaybeUninit::uninit();
let outs_ptr = outs.as_mut_ptr();

for (i, &hash) in hashes.iter().enumerate() {
let cur = self.find(hash, |k| eq(i, k))?;
*(*outs_ptr).get_unchecked_mut(i) = cur.as_mut();
}

// TODO use `MaybeUninit::array_assume_init` here instead once that's stable.
Some(outs.assume_init())
) -> [Option<NonNull<T>>; N] {
array::from_fn(|i| {
self.find(hashes[i], |k| eq(i, k))
.map(|cur| cur.as_non_null())
})
}

/// Returns the number of elements the map can hold without reallocating.
Expand Down
Loading
Loading