155 changes: 155 additions & 0 deletions library/alloc/src/collections/btree/map.rs
@@ -1240,6 +1240,161 @@ impl<K, V, A: Allocator + Clone> BTreeMap<K, V, A> {
)
}

/// Moves all elements from `other` into `self`, leaving `other` empty.
///
/// If a key from `other` is already present in `self`, the `conflict`
/// closure decides which value `self` keeps. It receives a borrow of `self`'s key,
/// `self`'s value, and `other`'s value, in that order, and returns the value
/// to store in `self`.
///
/// One reason to use this method instead of [`append`] is to combine `self`'s
/// value with `other`'s value when their keys conflict ([`append`] would simply
/// replace `self`'s value with `other`'s).
///
/// As with [`insert`], the existing key is not overwritten, which matters for
/// types that can be `==` without being identical.
///
/// [`insert`]: BTreeMap::insert
/// [`append`]: BTreeMap::append
///
/// # Examples
///
/// ```
/// #![feature(btree_merge)]
/// use std::collections::BTreeMap;
///
/// let mut a = BTreeMap::new();
/// a.insert(1, String::from("a"));
/// a.insert(2, String::from("b"));
/// a.insert(3, String::from("c")); // Note: Key (3) also present in b.
///
/// let mut b = BTreeMap::new();
/// b.insert(3, String::from("d")); // Note: Key (3) also present in a.
/// b.insert(4, String::from("e"));
/// b.insert(5, String::from("f"));
///
/// // concatenate a's value and b's value
/// a.merge(b, |_, a_val, b_val| {
/// format!("{a_val}{b_val}")
/// });
///
/// assert_eq!(a.len(), 5); // all of b's keys in a
///
/// assert_eq!(a[&1], "a");
/// assert_eq!(a[&2], "b");
/// assert_eq!(a[&3], "cd"); // Note: "c" has been combined with "d".
/// assert_eq!(a[&4], "e");
/// assert_eq!(a[&5], "f");
/// ```
#[unstable(feature = "btree_merge", issue = "152152")]
pub fn merge(&mut self, mut other: Self, mut conflict: impl FnMut(&K, V, V) -> V)
where
K: Ord,
A: Clone,
{
// Do we have to append anything at all?
if other.is_empty() {
return;
}

// We can just swap `self` and `other` if `self` is empty.
if self.is_empty() {
mem::swap(self, &mut other);
return;
}

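// `other` is non-empty (checked above), so the first `next()` below cannot return `None`.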
let mut other_iter = other.into_iter();
let (first_other_key, first_other_val) = other_iter.next().unwrap();

// Position the cursor at the gap before the smallest key in `self` that is
// greater than or equal to the first key from `other`.
let mut self_cursor = self.lower_bound_mut(Bound::Included(&first_other_key));

if let Some((self_key, self_val)) = self_cursor.peek_next() {
match K::cmp(&first_other_key, self_key) {
Ordering::Equal => {
// SAFETY: We read the value out of `self_val` and hand it to the `conflict`
// closure, whose return value is written back into `self_val`.
unsafe {
let val = ptr::read(self_val);
let next_val = (conflict)(self_key, val, first_other_val);
ptr::write(self_val, next_val);
}
Member

This should really be a method of `Cursor*`: you need a handler for when `conflict` panics that removes the entry without dropping whatever you `ptr::read` from.
Something like:

impl<'a, K, V> CursorMut<'a, K, V> {
    /// call `f` with the next entry's key and value, replacing the next entry's value with the returned value. if `f` unwinds, the next entry is removed.
    /// equivalent to a more efficient version of:
    /// ```rust
    /// if let Some((k, v)) = self.remove_next() {
    ///     let v = f(&k, v);
    ///     // Safety: key is unmodified
    ///     unsafe { self.insert_after_unchecked(k, v) };
    /// }
    /// ```
    pub(super) fn with_next(&mut self, f: impl FnOnce(&K, V) -> V) {
        struct RemoveNextOnDrop<'a, 'b, K, V> {
            cursor: &'a mut CursorMut<'b, K, V>,
            forget_next: bool,
        }
        impl<K, V> Drop for RemoveNextOnDrop<'_, '_, K, V> {
            fn drop(&mut self) {
                if self.forget_next {
                    // call an equivalent to CursorMut::remove_next()
                    // except that instead of returning `V`, it never moves or drops it.
                    self.cursor.forget_next_value();
                }
            }
        }
        let mut remove_next_on_drop = RemoveNextOnDrop {
            cursor: self,
            forget_next: false, // we don't know that we have a next value yet
        };
        if let Some((k, v_mut)) = remove_next_on_drop.cursor.peek_next() {
            remove_next_on_drop.forget_next = true;
            // Safety: we move the V out of the next entry,
            // we marked the entry's value to be forgotten
            // when remove_next_on_drop is dropped that
            // way we avoid returning to the caller leaving
            // a moved-out invalid value if `f` unwinds.
            let v = unsafe { std::ptr::read(v_mut) };
            let v = f(k, v);
            // Safety: move the V back into the next entry
            unsafe { std::ptr::write(v_mut, v) };
            remove_next_on_drop.forget_next = false;
        }
    }
}

The equivalent `CursorMutKey` method should instead take `f: impl FnOnce(K, V) -> (K, V)` and needs to forget both the key and the value (not just the value), since both were `ptr::read`.
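
A rough illustrative sketch of that `CursorMutKey` variant (not part of this PR: it assumes a hypothetical `forget_next_entry` helper that unlinks the next entry without dropping its moved-out key and value, and it omits the allocator parameter for brevity):

impl<'a, K, V> CursorMutKey<'a, K, V> {
    /// Calls `f` with the next entry's key and value, replacing both with the returned pair.
    /// If `f` unwinds, the next entry is removed without dropping the moved-out key or value.
    pub(super) fn with_next(&mut self, f: impl FnOnce(K, V) -> (K, V)) {
        struct ForgetNextOnDrop<'a, 'b, K, V> {
            cursor: &'a mut CursorMutKey<'b, K, V>,
            forget_next: bool,
        }
        impl<K, V> Drop for ForgetNextOnDrop<'_, '_, K, V> {
            fn drop(&mut self) {
                if self.forget_next {
                    // Hypothetical helper: unlink the next entry without dropping its key or value.
                    self.cursor.forget_next_entry();
                }
            }
        }
        let mut guard = ForgetNextOnDrop { cursor: self, forget_next: false };
        if let Some((k_mut, v_mut)) = guard.cursor.peek_next() {
            guard.forget_next = true;
            // Safety: both the key and the value are moved out; if `f` unwinds,
            // the guard forgets the entry instead of dropping the now-invalid slots.
            let (k, v) = unsafe { (core::ptr::read(k_mut), core::ptr::read(v_mut)) };
            let (k, v) = f(k, v);
            // Safety: the returned key must compare equal to the one read out,
            // so writing both back leaves the tree order unchanged.
            unsafe {
                core::ptr::write(k_mut, k);
                core::ptr::write(v_mut, v);
            }
            guard.forget_next = false;
        }
    }
}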

Contributor Author

Got you 👍

}
Ordering::Less =>
// SAFETY: first_other_key is strictly less than self_key,
// so inserting it before the cursor preserves sorted order
unsafe {
self_cursor.insert_before_unchecked(first_other_key, first_other_val);
},
Ordering::Greater => {
unreachable!("cursor's peek_next would return None in this case");
}
}
} else {
// SAFETY: the cursor is at the end of the `self` BTreeMap, meaning
// first_other_key is greater than every key in `self`, so inserting it
// after the cursor (at the very end) preserves sorted order
unsafe {
self_cursor.insert_after_unchecked(first_other_key, first_other_val);
}
}

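// Insert the remaining elements from `other`. Their keys arrive in strictly
// ascending order, so the cursor only ever needs to move forward.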
for (other_key, other_val) in other_iter {
if self_cursor.peek_next().is_some() {
loop {
let self_entry = self_cursor.peek_next();
if let Some((self_key, self_val)) = self_entry {
match K::cmp(&other_key, self_key) {
Ordering::Equal => {
// SAFETY: We read the value out of `self_val` and hand it to the `conflict`
// closure, whose return value is written back into `self_val`.
unsafe {
let val = ptr::read(self_val);
let next_val = (conflict)(self_key, val, other_val);
ptr::write(self_val, next_val);
}
break;
}
Ordering::Less => {
// SAFETY: other_key is strictly less than self_key,
// so inserting it before the cursor preserves sorted order
unsafe {
self_cursor.insert_before_unchecked(other_key, other_val);
}
break;
}
Ordering::Greater => {
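// other_key is greater than self_key; advance the cursor to find its position.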
self_cursor.next();
}
}
} else {
// SAFETY: the cursor is at the end of the `self` BTreeMap, meaning
// other_key is greater than every remaining key in `self`, so inserting it
// after the cursor (at the very end) preserves sorted order
unsafe {
self_cursor.insert_after_unchecked(other_key, other_val);
}
self_cursor.next();
break;
}
}
} else {
// SAFETY: the cursor is at the end of the `self` BTreeMap, so this and
// every remaining key from `other` is greater than all keys in `self`;
// insert them one by one at the end of the cursor
unsafe {
self_cursor.insert_after_unchecked(other_key, other_val);
}
self_cursor.next();
}
}
}

/// Constructs a double-ended iterator over a sub-range of elements in the map.
/// The simplest way is to use the range syntax `min..max`, thus `range(min..max)` will
/// yield elements from min (inclusive) to max (exclusive).
88 changes: 87 additions & 1 deletion library/alloc/src/collections/btree/map/tests.rs
@@ -1,9 +1,9 @@
use core::assert_matches;
use std::iter;
use std::ops::Bound::{Excluded, Included, Unbounded};
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::SeqCst;
use std::{cmp, iter};

use super::*;
use crate::boxed::Box;
@@ -2128,6 +2128,76 @@ create_append_test!(test_append_239, 239);
#[cfg(not(miri))] // Miri is too slow
create_append_test!(test_append_1700, 1700);

macro_rules! create_merge_test {
($name:ident, $len:expr) => {
#[test]
fn $name() {
let mut a = BTreeMap::new();
for i in 0..8 {
a.insert(i, i);
}

let mut b = BTreeMap::new();
for i in 5..$len {
b.insert(i, 2 * i);
}

a.merge(b, |_, a_val, b_val| a_val + b_val);

assert_eq!(a.len(), cmp::max($len, 8));

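// Expected contents: keys 0..5 exist only in `a` (value i); keys 5..min($len, 8)
// were present in both maps and got summed (i + 2 * i); any remaining keys come
// from whichever map extended further (value i if only `a` had them, 2 * i if only `b`).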
for i in 0..cmp::max($len, 8) {
if i < 5 {
assert_eq!(a[&i], i);
} else if i < cmp::min($len, 8) {
assert_eq!(a[&i], i + 2 * i);
} else if i >= $len {
assert_eq!(a[&i], i);
} else {
assert_eq!(a[&i], 2 * i);
}
}

a.check();
assert_eq!(
a.remove(&($len - 1)),
if $len >= 5 && $len < 8 {
Some(($len - 1) + 2 * ($len - 1))
} else {
Some(2 * ($len - 1))
}
);
assert_eq!(a.insert($len - 1, 20), None);
a.check();
}
};
}

// These are mostly for testing the algorithm that "fixes" the right edge after insertion.
// Single node, merge conflicting key values.
create_merge_test!(test_merge_7, 7);
// Single node.
create_merge_test!(test_merge_9, 9);
// Two leaves that don't need fixing.
create_merge_test!(test_merge_17, 17);
// Two leaves where the second one ends up underfull and needs stealing at the end.
create_merge_test!(test_merge_14, 14);
// Two leaves where the second one ends up empty because the insertion finished at the root.
create_merge_test!(test_merge_12, 12);
// Three levels; insertion finished at the root.
create_merge_test!(test_merge_144, 144);
// Three levels; insertion finished at leaf while there is an empty node on the second level.
create_merge_test!(test_merge_145, 145);
// Tests for several randomly chosen sizes.
create_merge_test!(test_merge_170, 170);
create_merge_test!(test_merge_181, 181);
#[cfg(not(miri))] // Miri is too slow
create_merge_test!(test_merge_239, 239);
#[cfg(not(miri))] // Miri is too slow
create_merge_test!(test_merge_1700, 1700);

#[test]
#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")]
fn test_append_drop_leak() {
@@ -2615,3 +2685,19 @@ fn test_id_based_append() {

assert_eq!(lhs.pop_first().unwrap().0.name, "lhs_k".to_string());
}

#[test]
fn test_id_based_merge() {
let mut lhs = BTreeMap::new();
let mut rhs = BTreeMap::new();

lhs.insert(IdBased { id: 0, name: "lhs_k".to_string() }, "1".to_string());
rhs.insert(IdBased { id: 0, name: "rhs_k".to_string() }, "2".to_string());

lhs.merge(rhs, |_, mut lhs_val, rhs_val| {
lhs_val.push_str(&rhs_val);
lhs_val
});

assert_eq!(lhs.pop_first().unwrap().0.name, "lhs_k".to_string());
}
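
As an aside, the contrast with `append` that motivates `merge` in its doc comment could be exercised by a test along these lines (illustrative sketch only, not part of this diff; the test name is made up):

#[test]
fn test_merge_vs_append_conflict() {
    let mut appended = BTreeMap::new();
    appended.insert(3, "c".to_string());
    let mut merged = appended.clone();

    let mut other_for_append = BTreeMap::new();
    other_for_append.insert(3, "d".to_string());
    let other_for_merge = other_for_append.clone();

    // `append` resolves the key conflict by keeping `other`'s value...
    appended.append(&mut other_for_append);
    assert_eq!(appended[&3], "d");

    // ...while `merge` lets the caller combine both values.
    merged.merge(other_for_merge, |_, mut a_val, b_val| {
        a_val.push_str(&b_val);
        a_val
    });
    assert_eq!(merged[&3], "cd");
}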
1 change: 1 addition & 0 deletions library/alloctests/tests/lib.rs
@@ -1,5 +1,6 @@
#![feature(allocator_api)]
#![feature(binary_heap_pop_if)]
#![feature(btree_merge)]
#![feature(const_heap)]
#![feature(deque_extend_front)]
#![feature(iter_array_chunks)]