|
18 | 18 | use crate::interleave::interleave;
|
19 | 19 | use ahash::RandomState;
|
20 | 20 | use arrow_array::builder::BooleanBufferBuilder;
|
21 |
| -use arrow_array::cast::AsArray; |
22 | 21 | use arrow_array::types::{
|
23 |
| - ArrowDictionaryKeyType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type, |
| 22 | + ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType, |
| 23 | + LargeUtf8Type, Utf8Type, |
24 | 24 | };
|
25 |
| -use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray}; |
26 |
| -use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer}; |
| 25 | +use arrow_array::{cast::AsArray, downcast_primitive}; |
| 26 | +use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray}; |
| 27 | +use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice}; |
27 | 28 | use arrow_schema::{ArrowError, DataType};
|
28 | 29 |
|
29 | 30 | /// A best effort interner that maintains a fixed number of buckets
|
@@ -120,7 +121,12 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
|
120 | 121 | LargeUtf8 => Box::new(bytes_ptr_eq::<LargeUtf8Type>),
|
121 | 122 | Binary => Box::new(bytes_ptr_eq::<BinaryType>),
|
122 | 123 | LargeBinary => Box::new(bytes_ptr_eq::<LargeBinaryType>),
|
123 |
| - _ => return false, |
| 124 | + dt => { |
| 125 | + if !dt.is_primitive() { |
| 126 | + return false; |
| 127 | + } |
| 128 | + Box::new(|a, b| a.to_data().ptr_eq(&b.to_data())) |
| 129 | + } |
124 | 130 | };
|
125 | 131 |
|
126 | 132 | let mut single_dictionary = true;
|
@@ -233,17 +239,43 @@ fn compute_values_mask<K: ArrowNativeType>(
|
233 | 239 | builder.finish()
|
234 | 240 | }
|
235 | 241 |
|
| 242 | +/// Process primitive array values to bytes |
| 243 | +fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>( |
| 244 | + array: &'a PrimitiveArray<T>, |
| 245 | + mask: &BooleanBuffer, |
| 246 | +) -> Vec<(usize, Option<&'a [u8]>)> |
| 247 | +where |
| 248 | + T::Native: ToByteSlice, |
| 249 | +{ |
| 250 | + let mut out = Vec::with_capacity(mask.count_set_bits()); |
| 251 | + let values = array.values(); |
| 252 | + for idx in mask.set_indices() { |
| 253 | + out.push(( |
| 254 | + idx, |
| 255 | + array.is_valid(idx).then_some(values[idx].to_byte_slice()), |
| 256 | + )) |
| 257 | + } |
| 258 | + out |
| 259 | +} |
| 260 | + |
| 261 | +macro_rules! masked_primitive_to_bytes_helper { |
| 262 | + ($t:ty, $array:expr, $mask:expr) => { |
| 263 | + masked_primitives_to_bytes::<$t>($array.as_primitive(), $mask) |
| 264 | + }; |
| 265 | +} |
| 266 | + |
236 | 267 | /// Return a Vec containing for each set index in `mask`, the index and byte value of that index
|
237 | 268 | fn get_masked_values<'a>(
|
238 | 269 | array: &'a dyn Array,
|
239 | 270 | mask: &BooleanBuffer,
|
240 | 271 | ) -> Vec<(usize, Option<&'a [u8]>)> {
|
241 |
| - match array.data_type() { |
| 272 | + downcast_primitive! { |
| 273 | + array.data_type() => (masked_primitive_to_bytes_helper, array, mask), |
242 | 274 | DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
|
243 | 275 | DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
|
244 | 276 | DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
|
245 | 277 | DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
|
246 |
| - _ => unimplemented!(), |
| 278 | + _ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()), |
247 | 279 | }
|
248 | 280 | }
|
249 | 281 |
|
|
0 commit comments