diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs
new file mode 100644
index 0000000000..f0a1b00737
--- /dev/null
+++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{array::*, datatypes::DataType};
+use datafusion::common::Result;
+use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
+use std::sync::Arc;
+
+macro_rules! compute_op {
+    ($OPERAND:expr, $DT:ident) => {{
+        let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| {
+            DataFusionError::Execution(format!(
+                "compute_op failed to downcast array to: {:?}",
+                stringify!($DT)
+            ))
+        })?;
+
+        let result: Int32Array = operand
+            .iter()
+            .map(|x| x.map(|y| bit_count(y.into())))
+            .collect();
+
+        Ok(Arc::new(result))
+    }};
+}
+
+pub fn spark_bit_count(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 1 {
+        return Err(DataFusionError::Internal(
+            "bit_count expects exactly one argument".to_string(),
+        ));
+    }
+    match &args[0] {
+        ColumnarValue::Array(array) => {
+            let result: Result<ArrayRef> = match array.data_type() {
+                DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array),
+                DataType::Int16 => compute_op!(array, Int16Array),
+                DataType::Int32 => compute_op!(array, Int32Array),
+                DataType::Int64 => compute_op!(array, Int64Array),
+                _ => Err(DataFusionError::Execution(format!(
+                    "Can't be evaluated because the expression's type is {:?}, not signed int",
+                    array.data_type(),
+                ))),
+            };
+            result.map(ColumnarValue::Array)
+        }
+        ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
+            "shouldn't go to bit_count scalar path".to_string(),
+        )),
+    }
+}
+
+// SWAR popcount equivalent to Spark's bitCount for LongType (java.lang.Long.bitCount).
+fn bit_count(i: i64) -> i32 {
+    let mut u = i as u64;
+    u = u - ((u >> 1) & 0x5555555555555555);
+    u = (u & 0x3333333333333333) + ((u >> 2) & 0x3333333333333333);
+    u = (u + (u >> 4)) & 0x0f0f0f0f0f0f0f0f;
+    u = u + (u >> 8);
+    u = u + (u >> 16);
+    u = u + (u >> 32);
+    (u as i32) & 0x7f
+}
+
+#[cfg(test)]
+mod tests {
+    use datafusion::common::{cast::as_int32_array, Result};
+
+    use super::*;
+
+    #[test]
+    fn bitwise_count_op() -> Result<()> {
+        let args = vec![ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+            Some(1),
+            None,
+            Some(12345),
+            Some(89),
+            Some(-3456),
+        ])))];
+        let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]);
+
+        let ColumnarValue::Array(result) = spark_bit_count(&args)? else {
+            unreachable!()
+        };
+
+        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+}
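Note on the kernel above: compute_op! widens every element to i64 via y.into() before counting, so negative narrow values include their sign-extension bits. This matches Spark, whose BitwiseCount also widens byte/short/int operands to long before evaluating java.lang.Long.bitCount. A standalone sketch, not part of the patch, duplicating bit_count so it compiles on its own and cross-checking the SWAR popcount against Rust's built-in count_ones():

fn bit_count(i: i64) -> i32 {
    let mut u = i as u64;
    u = u - ((u >> 1) & 0x5555555555555555); // 2-bit partial sums
    u = (u & 0x3333333333333333) + ((u >> 2) & 0x3333333333333333); // 4-bit sums
    u = (u + (u >> 4)) & 0x0f0f0f0f0f0f0f0f; // one sum per byte
    u = u + (u >> 8); // fold the eight byte sums together
    u = u + (u >> 16);
    u = u + (u >> 32);
    (u as i32) & 0x7f // the count fits in 7 bits (max 64)
}

fn main() {
    for v in [0i64, 1, -1, 89, 12345, -3456, i64::MIN, i64::MAX] {
        assert_eq!(bit_count(v), v.count_ones() as i32);
    }
    // Sign extension makes narrow negatives "large": -3456 as i64 is
    // 0xFFFF_FFFF_FFFF_F280, so 54 of its 64 bits are set (cf. the unit test).
    assert_eq!(bit_count(-3456), 54);
    // An i8 of -1 widened with .into() counts all 64 bits, as Spark does.
    assert_eq!(bit_count((-1_i8).into()), 64);
}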
diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs
index 9c26363319..718cfc7ca8 100644
--- a/native/spark-expr/src/bitwise_funcs/mod.rs
+++ b/native/spark-expr/src/bitwise_funcs/mod.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod bitwise_count;
 mod bitwise_not;
 
+pub use bitwise_count::spark_bit_count;
 pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index cf06d36332..f852060008 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -17,10 +17,10 @@
 
 use crate::hash_funcs::*;
 use crate::{
-    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
-    SparkChrFunc,
+    spark_array_repeat, spark_bit_count, spark_ceil, spark_date_add, spark_date_sub,
+    spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan,
+    spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
+    spark_unscaled_value, SparkChrFunc,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -145,6 +145,10 @@ pub fn create_comet_physical_fun(
             let func = Arc::new(spark_array_repeat);
             make_comet_scalar_udf!("array_repeat", func, without data_type)
         }
+        "bit_count" => {
+            let func = Arc::new(spark_bit_count);
+            make_comet_scalar_udf!("bit_count", func, without data_type)
+        }
         _ => registry.udf(fun_name).map_err(|e| {
             DataFusionError::Execution(format!(
                 "Function {fun_name} not found in the registry: {e}",
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index a98234585d..5d988001fa 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1712,6 +1712,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
           binding,
           (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
 
+      case BitwiseCount(child) =>
+        val childProto = exprToProto(child, inputs, binding)
+        val bitCountScalarExpr =
+          scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto)
+        optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)
+
       case ShiftRight(left, right) =>
         // DataFusion bitwise shift right expression requires
         // same data type between left and right side
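For context on the wiring above: the Scala side serializes BitwiseCount(child) as a scalar function named "bit_count" with an explicit IntegerType return type, and create_comet_physical_fun resolves that name to spark_bit_count. Outside of Comet's make_comet_scalar_udf! macro, the equivalent registration with plain DataFusion APIs would look roughly like this sketch — the crate name in the import is assumed, and create_udf's signature is the one in recent DataFusion releases (older versions take Arc<DataType> for the return type):

use std::sync::Arc;

use arrow::datatypes::DataType;
use datafusion::logical_expr::{create_udf, Volatility};
use datafusion::prelude::SessionContext;
// Assumed crate name for the spark-expr crate in this repo:
use datafusion_comet_spark_expr::spark_bit_count;

fn register_bit_count(ctx: &SessionContext) {
    let udf = create_udf(
        "bit_count",
        vec![DataType::Int64],     // one signed-integer argument
        DataType::Int32,           // bit_count always returns int, as in Spark
        Volatility::Immutable,     // pure function of its input
        Arc::new(spark_bit_count), // the kernel from bitwise_count.rs
    );
    ctx.register_udf(udf);
}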
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 77c55daa1f..0fa4ea7148 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.sql.types.{Decimal, DecimalType}
 
 import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
@@ -99,6 +100,73 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("bitwise_count - min/max values") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "bitwise_count_test"
+        withTable(table) {
+          sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
+          sql(s"insert into $table values(1111, 2222, 17, 7)")
+          sql(
+            s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, ${Short.MaxValue}, ${Byte.MaxValue})")
+          sql(
+            s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, ${Short.MinValue}, ${Byte.MinValue})")
+
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM $table"))
+        }
+      }
+    }
+  }
+
+  test("bitwise_count - random values (spark gen)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          10,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      val df =
+        table.selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", "bit_count(c4)")
+
+      checkSparkAnswerAndOperator(df)
+    }
+  }
+
+  test("bitwise_count - random values (native parquet gen)") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false)
+        val table = spark.read.parquet(path.toString)
+        checkSparkAnswerAndOperator(
+          table
+            .selectExpr(
+              "bit_count(_2)",
+              "bit_count(_3)",
+              "bit_count(_4)",
+              "bit_count(_5)",
+              "bit_count(_11)"))
+      }
+    }
+  }
+
   test("bitwise shift with different left/right types") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
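The Scala suites above exercise the full plan-translation path; calling the kernel directly shows the same semantics in miniature. A sketch, with the crate-name import assumed:

use std::sync::Arc;

use arrow::array::{Int32Array, Int64Array};
use datafusion::common::cast::as_int32_array;
use datafusion::common::{Result, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::spark_bit_count; // assumed crate name

fn main() -> Result<()> {
    // NULLs propagate; negative values count their two's-complement bits.
    let input = ColumnarValue::Array(Arc::new(Int64Array::from(vec![
        Some(i64::MAX), // 63 ones
        Some(-1),       // all 64 bits set
        None,           // stays NULL
    ])));
    let ColumnarValue::Array(out) = spark_bit_count(&[input])? else {
        unreachable!("array input always yields an array")
    };
    assert_eq!(
        as_int32_array(&out)?,
        &Int32Array::from(vec![Some(63), Some(64), None])
    );

    // Only the array path is implemented; scalar inputs are rejected.
    let scalar = ColumnarValue::Scalar(ScalarValue::Int64(Some(5)));
    assert!(spark_bit_count(&[scalar]).is_err());
    Ok(())
}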