Skip to content

Commit 62c4028

Browse files
committed
arrow-select: add support for merging primitive dictionary values
Previously, should_merge_dictionaries would always return false in the ptr_eq closure creation match arm for types that were not {Large}{Utf8,Binary}. This could lead to excessive memory usage.
1 parent 1a5999a commit 62c4028

File tree

2 files changed

+76
-7
lines changed

2 files changed

+76
-7
lines changed

arrow-select/src/concat.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,43 @@ mod tests {
10101010
assert!((30..40).contains(&values_len), "{values_len}")
10111011
}
10121012

1013+
#[test]
1014+
fn test_primitive_dictionary_merge() {
1015+
// Same value repeated 5 times.
1016+
let keys = vec![1; 5];
1017+
let values = (10..20).collect::<Vec<_>>();
1018+
let dict = DictionaryArray::new(
1019+
Int8Array::from(keys.clone()),
1020+
Arc::new(Int32Array::from(values.clone())),
1021+
);
1022+
let other = DictionaryArray::new(
1023+
Int8Array::from(keys.clone()),
1024+
Arc::new(Int32Array::from(values.clone())),
1025+
);
1026+
1027+
let result_same_dictionary = concat(&[&dict, &dict]).unwrap();
1028+
// Verify pointer equality check succeeds, and therefore the
1029+
// dictionaries are not merged. A single values buffer should be reused
1030+
// in this case.
1031+
assert_eq!(
1032+
result_same_dictionary
1033+
.as_dictionary::<Int8Type>()
1034+
.values()
1035+
.len(),
1036+
values.len(),
1037+
);
1038+
1039+
let result_cloned_dictionary = concat(&[&dict, &other]).unwrap();
1040+
// Should have only 1 underlying value since all keys reference it.
1041+
assert_eq!(
1042+
result_cloned_dictionary
1043+
.as_dictionary::<Int8Type>()
1044+
.values()
1045+
.len(),
1046+
1
1047+
);
1048+
}
1049+
10131050
#[test]
10141051
fn test_concat_string_sizes() {
10151052
let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();

arrow-select/src/dictionary.rs

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
use crate::interleave::interleave;
1919
use ahash::RandomState;
2020
use arrow_array::builder::BooleanBufferBuilder;
21-
use arrow_array::cast::AsArray;
2221
use arrow_array::types::{
23-
ArrowDictionaryKeyType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type,
22+
ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
23+
LargeUtf8Type, Utf8Type,
2424
};
25-
use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray};
26-
use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer};
25+
use arrow_array::{cast::AsArray, downcast_primitive};
26+
use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray};
27+
use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
2728
use arrow_schema::{ArrowError, DataType};
2829

2930
/// A best effort interner that maintains a fixed number of buckets
@@ -120,7 +121,12 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
120121
LargeUtf8 => Box::new(bytes_ptr_eq::<LargeUtf8Type>),
121122
Binary => Box::new(bytes_ptr_eq::<BinaryType>),
122123
LargeBinary => Box::new(bytes_ptr_eq::<LargeBinaryType>),
123-
_ => return false,
124+
dt => {
125+
if !dt.is_primitive() {
126+
return false;
127+
}
128+
Box::new(|a, b| a.to_data().ptr_eq(&b.to_data()))
129+
}
124130
};
125131

126132
let mut single_dictionary = true;
@@ -233,17 +239,43 @@ fn compute_values_mask<K: ArrowNativeType>(
233239
builder.finish()
234240
}
235241

242+
/// Process primitive array values to bytes
243+
fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
244+
array: &'a PrimitiveArray<T>,
245+
mask: &BooleanBuffer,
246+
) -> Vec<(usize, Option<&'a [u8]>)>
247+
where
248+
T::Native: ToByteSlice,
249+
{
250+
let mut out = Vec::with_capacity(mask.count_set_bits());
251+
let values = array.values();
252+
for idx in mask.set_indices() {
253+
out.push((
254+
idx,
255+
array.is_valid(idx).then_some(values[idx].to_byte_slice()),
256+
))
257+
}
258+
out
259+
}
260+
261+
macro_rules! masked_primitive_to_bytes_helper {
262+
($t:ty, $array:expr, $mask:expr) => {
263+
masked_primitives_to_bytes::<$t>($array.as_primitive(), $mask)
264+
};
265+
}
266+
236267
/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
237268
fn get_masked_values<'a>(
238269
array: &'a dyn Array,
239270
mask: &BooleanBuffer,
240271
) -> Vec<(usize, Option<&'a [u8]>)> {
241-
match array.data_type() {
272+
downcast_primitive! {
273+
array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
242274
DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
243275
DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
244276
DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
245277
DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
246-
_ => unimplemented!(),
278+
_ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
247279
}
248280
}
249281

0 commit comments

Comments
 (0)