Skip to content

Commit a7fa7dd

Browse files
committed
optimized version of BTreeMap::merge with CursorMut and unsafe code
1 parent 1b29082 commit a7fa7dd

File tree

2 files changed

+91
-71
lines changed

2 files changed

+91
-71
lines changed

library/alloc/src/collections/btree/append.rs

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -33,36 +33,6 @@ impl<K, V> Root<K, V> {
3333
self.bulk_push(iter, length, alloc)
3434
}
3535

36-
/// Merges all key-value pairs from the union of two ascending iterators,
37-
/// incrementing a `length` variable along the way. The latter makes it
38-
/// easier for the caller to avoid a leak when a drop handler panicks.
39-
///
40-
/// If both iterators produce the same key, this method constructs a pair using the
41-
/// key from the left iterator and calls on a closure `f` to return a value given
42-
/// the conflicting key and value from left and right iterators.
43-
///
44-
/// If you want the tree to end up in a strictly ascending order, like for
45-
/// a `BTreeMap`, both iterators should produce keys in strictly ascending
46-
/// order, each greater than all keys in the tree, including any keys
47-
/// already in the tree upon entry.
48-
pub(super) fn merge_from_sorted_iters_with<I, A: Allocator + Clone>(
49-
&mut self,
50-
left: I,
51-
right: I,
52-
length: &mut usize,
53-
alloc: A,
54-
f: impl FnMut(&K, V, V) -> V,
55-
) where
56-
K: Ord,
57-
I: Iterator<Item = (K, V)> + FusedIterator,
58-
{
59-
// We prepare to merge `left` and `right` into a sorted sequence in linear time.
60-
let iter = MergeIterWith { inner: MergeIterInner::new(left, right), f };
61-
62-
// Meanwhile, we build a tree from the sorted sequence in linear time.
63-
self.bulk_push(iter, length, alloc)
64-
}
65-
6636
/// Pushes all key-value pairs to the end of the tree, incrementing a
6737
/// `length` variable along the way. The latter makes it easier for the
6838
/// caller to avoid a leak when the iterator panicks.
@@ -145,33 +115,3 @@ where
145115
}
146116
}
147117
}
148-
149-
/// An iterator for merging two sorted sequences into one with
150-
/// a callback function to return a value on conflicting keys
151-
struct MergeIterWith<F, K, V, I: Iterator<Item = (K, V)>> {
152-
inner: MergeIterInner<I>,
153-
f: F,
154-
}
155-
156-
impl<F, K: Ord, V, I> Iterator for MergeIterWith<F, K, V, I>
157-
where
158-
F: FnMut(&K, V, V) -> V,
159-
I: Iterator<Item = (K, V)> + FusedIterator,
160-
{
161-
type Item = (K, V);
162-
163-
/// If two keys are equal, returns the key from the left and uses `f` to return
164-
/// a value given the conflicting key and values from left and right
165-
fn next(&mut self) -> Option<(K, V)> {
166-
let (a_next, b_next) = self.inner.nexts(|a: &(K, V), b: &(K, V)| K::cmp(&a.0, &b.0));
167-
match (a_next, b_next) {
168-
(Some((a_k, a_v)), Some((_, b_v))) => Some({
169-
let next_val = (self.f)(&a_k, a_v, b_v);
170-
(a_k, next_val)
171-
}),
172-
(Some(a), None) => Some(a),
173-
(None, Some(b)) => Some(b),
174-
(None, None) => None,
175-
}
176-
}
177-
}

library/alloc/src/collections/btree/map.rs

Lines changed: 91 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,7 @@ impl<K, V, A: Allocator + Clone> BTreeMap<K, V, A> {
12871287
/// assert_eq!(a[&5], "f");
12881288
/// ```
12891289
#[unstable(feature = "btree_merge", issue = "152152")]
1290-
pub fn merge(&mut self, mut other: Self, conflict: impl FnMut(&K, V, V) -> V)
1290+
pub fn merge(&mut self, mut other: Self, mut conflict: impl FnMut(&K, V, V) -> V)
12911291
where
12921292
K: Ord,
12931293
A: Clone,
@@ -1303,16 +1303,96 @@ impl<K, V, A: Allocator + Clone> BTreeMap<K, V, A> {
13031303
return;
13041304
}
13051305

1306-
let self_iter = mem::replace(self, Self::new_in((*self.alloc).clone())).into_iter();
1307-
let other_iter = mem::replace(&mut other, Self::new_in((*self.alloc).clone())).into_iter();
1308-
let root = self.root.get_or_insert_with(|| Root::new((*self.alloc).clone()));
1309-
root.merge_from_sorted_iters_with(
1310-
self_iter,
1311-
other_iter,
1312-
&mut self.length,
1313-
(*self.alloc).clone(),
1314-
conflict,
1315-
)
1306+
let mut other_iter = other.into_iter();
1307+
let (first_other_key, first_other_val) = other_iter.next().unwrap();
1308+
1309+
// find the first gap that has the smallest key greater than the first key from other
1310+
let mut self_cursor = self.lower_bound_mut(Bound::Included(&first_other_key));
1311+
1312+
if let Some((self_key, self_val)) = self_cursor.peek_next() {
1313+
match K::cmp(&first_other_key, self_key) {
1314+
Ordering::Equal => {
1315+
// SAFETY: We read in self_val's and hand it over to our conflict function
1316+
// which will always return a value that we can use to overwrite what's
1317+
// in self_val
1318+
unsafe {
1319+
let val = ptr::read(self_val);
1320+
let next_val = (conflict)(self_key, val, first_other_val);
1321+
ptr::write(self_val, next_val);
1322+
}
1323+
}
1324+
Ordering::Less =>
1325+
// SAFETY: we know our other_key's ordering is less than self_key,
1326+
// so inserting before will guarantee sorted order
1327+
unsafe {
1328+
self_cursor.insert_before_unchecked(first_other_key, first_other_val);
1329+
},
1330+
Ordering::Greater => {
1331+
unreachable!("cursor's peek_next would return None in this case");
1332+
}
1333+
}
1334+
} else {
1335+
// SAFETY: if we reach here, that means our cursor has reached
1336+
// the end of self BTreeMap, (other_key is greater than all the
1337+
// previous self BTreeMap keys) so we just insert other_key here
1338+
// at the end of the CursorMut
1339+
unsafe {
1340+
self_cursor.insert_after_unchecked(first_other_key, first_other_val);
1341+
}
1342+
}
1343+
1344+
for (other_key, other_val) in other_iter {
1345+
if self_cursor.peek_next().is_some() {
1346+
loop {
1347+
let self_entry = self_cursor.peek_next();
1348+
if let Some((self_key, self_val)) = self_entry {
1349+
match K::cmp(&other_key, self_key) {
1350+
Ordering::Equal => {
1351+
// SAFETY: We read in self_val's and hand it over to our conflict function
1352+
// which will always return a value that we can use to overwrite what's
1353+
// in self_val
1354+
unsafe {
1355+
let val = ptr::read(self_val);
1356+
let next_val = (conflict)(self_key, val, other_val);
1357+
ptr::write(self_val, next_val);
1358+
}
1359+
break;
1360+
}
1361+
Ordering::Less => {
1362+
// SAFETY: we know our other_key's ordering is less than self_key,
1363+
// so inserting before will guarantee sorted order
1364+
unsafe {
1365+
self_cursor.insert_before_unchecked(other_key, other_val);
1366+
}
1367+
break;
1368+
}
1369+
Ordering::Greater => {
1370+
self_cursor.next();
1371+
}
1372+
}
1373+
} else {
1374+
// SAFETY: if we reach here, that means our cursor has reached
1375+
// the end of self BTreeMap, (other_key is greater than all the
1376+
// previous self BTreeMap keys) so we just insert other_key here
1377+
// at the end of the Cursor
1378+
unsafe {
1379+
self_cursor.insert_after_unchecked(other_key, other_val);
1380+
}
1381+
self_cursor.next();
1382+
break;
1383+
}
1384+
}
1385+
} else {
1386+
// SAFETY: if we reach here, that means our cursor has reached
1387+
// the end of self BTreeMap, (other_key is greater than all the
1388+
// previous self BTreeMap keys) so we just insert the rest of
1389+
// other_keys here at the end of CursorMut
1390+
unsafe {
1391+
self_cursor.insert_after_unchecked(other_key, other_val);
1392+
}
1393+
self_cursor.next();
1394+
}
1395+
}
13161396
}
13171397

13181398
/// Constructs a double-ended iterator over a sub-range of elements in the map.

0 commit comments

Comments
 (0)