From 5a87c3bf431620c29724420a8c8442d6623a27e9 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Mon, 31 Mar 2025 21:42:18 +0400 Subject: [PATCH 01/11] bitwise_count wip --- native/core/src/execution/planner.rs | 6 +- native/proto/src/proto/expr.proto | 1 + .../src/bitwise_funcs/bitwise_count.rs | 156 ++++++++++++++++++ native/spark-expr/src/bitwise_funcs/mod.rs | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 8 + .../apache/comet/CometBitwiseCountSuite.scala | 43 +++++ .../apache/comet/CometExpressionSuite.scala | 15 ++ 7 files changed, 230 insertions(+), 1 deletion(-) create mode 100644 native/spark-expr/src/bitwise_funcs/bitwise_count.rs create mode 100644 spark/src/test/scala/org/apache/comet/CometBitwiseCountSuite.scala diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 60803dfeb8..2420171ca9 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -65,7 +65,7 @@ use datafusion::{ }, prelude::SessionContext, }; -use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}; +use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr, BitwiseCountExpr}; use crate::execution::operators::ExecutionError::GeneralError; use crate::execution::shuffle::CompressionCodec; @@ -616,6 +616,10 @@ impl PhysicalPlanner { let op = DataFusionOperator::BitwiseShiftLeft; Ok(Arc::new(BinaryExpr::new(left, op, right))) } + ExprStruct::BitwiseCount(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; + Ok(Arc::new(BitwiseCountExpr::new(child))) + } // https://github.com/apache/datafusion-comet/issues/666 // ExprStruct::Abs(expr) => { // let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 90fd08948c..634900a089 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -84,6 +84,7 @@ message Expr { GetArrayStructFields get_array_struct_fields = 57; ArrayInsert array_insert = 58; MathExpr integral_divide = 59; + UnaryExpr bitwiseCount = 60; } } diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs new file mode 100644 index 0000000000..0bb64c6338 --- /dev/null +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::hash::Hash; +use std::sync::Arc; +use arrow::{ + array::*, + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion::common::Result; +use datafusion::common::DataFusionError; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; + +macro_rules! check_overflow { + ($VALUE:expr, $TYPE:expr, $TYPE_NAME:expr) => {{ + if $VALUE == $TYPE::MIN { + if $TYPE_NAME == "byte" || $TYPE_NAME == "short" { + let msg = format!("{:?} caused", $VALUE); + return Err(arithmetic_overflow_error(msg.as_str()).into()); + } + return Err(arithmetic_overflow_error($TYPE_NAME).into()); + } + }}; +} + +macro_rules! compute_op { + ($OPERAND:expr, $DT:ident, $TYPE:expr, $TYPE_NAME:expr) => {{ + let operand = $OPERAND + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + + let result: $DT = operand + .iter() + .map(|x| { x.map(|y| { check_overflow!(y.count_ones(), $TYPE, $TYPE_NAME) as $TYPE })}) + .collect(); + + Ok(Arc::new(result)) + }}; +} + +/// BitwiseCount expression +#[derive(Debug, Eq)] +pub struct BitwiseCountExpr { + /// Input expression + arg: Arc, +} + +impl Hash for BitwiseCountExpr { + fn hash(&self, state: &mut H) { + self.arg.hash(state); + } +} + +impl PartialEq for BitwiseCountExpr { + fn eq(&self, other: &Self) -> bool { + self.arg.eq(&other.arg) + } +} + +impl BitwiseCountExpr { + /// Create new bitwise count expression + pub fn new(arg: Arc) -> Self { + Self { arg } + } + + /// Get the input expression + pub fn arg(&self) -> &Arc { + &self.arg + } +} + +impl std::fmt::Display for BitwiseCountExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "(~ {})", self.arg) + } +} + +impl PhysicalExpr for BitwiseCountExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result { + self.arg.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.arg.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg = self.arg.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let result: Result = match array.data_type() { + DataType::Int8 => compute_op!(array, Int8Array, i8, "byte"), + DataType::Int16 => compute_op!(array, Int16Array, i16, "short"), + DataType::Int32 => compute_op!(array, Int32Array, i32, "integer"), + DataType::Int64 => compute_op!(array, Int64Array, i64, "long"), + _ => Err(DataFusionError::Execution(format!( + "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed int", + self, + array.data_type(), + ))), + }; + result.map(ColumnarValue::Array) + } + ColumnarValue::Scalar(_) => Err(DataFusionError::Internal( + "shouldn't go to bitwise not scalar path".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.arg] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(BitwiseCountExpr::new(Arc::clone(&children[0])))) + } +} + +pub fn bitwise_count(arg: Arc) -> Result> { + Ok(Arc::new(BitwiseCountExpr::new(arg))) +} + +#[cfg(test)] +mod tests { + + #[test] + fn bitwise_count_op() -> datafusion::common::Result<()> { + Ok(()) + } +} \ No newline at end of file diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs index 9c26363319..4c75501e51 100644 --- a/native/spark-expr/src/bitwise_funcs/mod.rs +++ b/native/spark-expr/src/bitwise_funcs/mod.rs @@ -16,5 +16,7 @@ // under the License. mod bitwise_not; +mod bitwise_count; pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; +pub use bitwise_count::{bitwise_count, BitwiseCountExpr}; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index a8a3df0c17..48e55acb0b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1652,6 +1652,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim binding, (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr)) + case BitwiseCount(child) => + createUnaryExpr( + expr, + child, + inputs, + binding, + (builder, unaryExpr) => builder.setBitwiseCount(unaryExpr)) + case ShiftRight(left, right) => // DataFusion bitwise shift right expression requires // same data type between left and right side diff --git a/spark/src/test/scala/org/apache/comet/CometBitwiseCountSuite.scala b/spark/src/test/scala/org/apache/comet/CometBitwiseCountSuite.scala new file mode 100644 index 0000000000..fbb2c2e87c --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometBitwiseCountSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase + +class CometBitwiseCountSuite extends CometTestBase { + + test("bitwise_count") { + Seq(false, true).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "bitwise_count_test" + withTable(table) { + sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet") + sql(s"insert into $table values(1111, 2222, 17, 7)") + + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table")) + // checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col5) FROM $table")) + } + } + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index c3bd2efef3..0e7bc46837 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -90,6 +90,21 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("bitwise_count") { + Seq(false, true).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "bitwise_count_test" + withTable(table) { + sql(s"create table $table(col1 long) using parquet") + sql(s"insert into $table values(1111)") + sql(s"insert into $table values(1111)") + + checkSparkAnswer(sql(s"SELECT bit_count(col1) FROM $table")) + } + } + } + } + test("bitwise shift with different left/right types") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { From 6bb6cb11350c33c59188999c9f924d0180bf53cd Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Wed, 2 Apr 2025 22:10:18 +0400 Subject: [PATCH 02/11] bitwise_count impl --- native/core/src/execution/planner.rs | 4 +- .../src/bitwise_funcs/bitwise_count.rs | 81 +++++++++++++------ .../apache/comet/CometBitwiseCountSuite.scala | 43 ---------- .../apache/comet/CometExpressionSuite.scala | 18 +++-- 4 files changed, 70 insertions(+), 76 deletions(-) delete mode 100644 spark/src/test/scala/org/apache/comet/CometBitwiseCountSuite.scala diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 2420171ca9..19cb92a443 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -65,7 +65,7 @@ use datafusion::{ }, prelude::SessionContext, }; -use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr, BitwiseCountExpr}; +use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}; use crate::execution::operators::ExecutionError::GeneralError; use crate::execution::shuffle::CompressionCodec; @@ -103,7 +103,7 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, BitwiseNotExpr, Cast, CheckOverflow, Contains, Correlation, + ArrayInsert, Avg, AvgDecimal, BitwiseCountExpr, BitwiseNotExpr, Cast, CheckOverflow, Contains, Correlation, Covariance, CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField, HourExpr, IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr, SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs index 0bb64c6338..9a43ab1f19 100644 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -15,33 +15,19 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::hash::Hash; -use std::sync::Arc; use arrow::{ array::*, datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion::common::Result; -use datafusion::common::DataFusionError; -use datafusion::logical_expr::ColumnarValue; use datafusion::physical_expr::PhysicalExpr; - -macro_rules! check_overflow { - ($VALUE:expr, $TYPE:expr, $TYPE_NAME:expr) => {{ - if $VALUE == $TYPE::MIN { - if $TYPE_NAME == "byte" || $TYPE_NAME == "short" { - let msg = format!("{:?} caused", $VALUE); - return Err(arithmetic_overflow_error(msg.as_str()).into()); - } - return Err(arithmetic_overflow_error($TYPE_NAME).into()); - } - }}; -} +use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; macro_rules! compute_op { - ($OPERAND:expr, $DT:ident, $TYPE:expr, $TYPE_NAME:expr) => {{ + ($OPERAND:expr, $DT:ident, $TY:ty) => {{ let operand = $OPERAND .as_any() .downcast_ref::<$DT>() @@ -49,7 +35,7 @@ macro_rules! compute_op { let result: $DT = operand .iter() - .map(|x| { x.map(|y| { check_overflow!(y.count_ones(), $TYPE, $TYPE_NAME) as $TYPE })}) + .map(|x| { x.map(|y| { bit_count(y.into()) as $TY })}) .collect(); Ok(Arc::new(result)) @@ -99,7 +85,7 @@ impl PhysicalExpr for BitwiseCountExpr { self } - fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result { + fn data_type(&self, input_schema: &Schema) -> Result { self.arg.data_type(input_schema) } @@ -112,10 +98,10 @@ impl PhysicalExpr for BitwiseCountExpr { match arg { ColumnarValue::Array(array) => { let result: Result = match array.data_type() { - DataType::Int8 => compute_op!(array, Int8Array, i8, "byte"), - DataType::Int16 => compute_op!(array, Int16Array, i16, "short"), - DataType::Int32 => compute_op!(array, Int32Array, i32, "integer"), - DataType::Int64 => compute_op!(array, Int64Array, i64, "long"), + DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array, i8), + DataType::Int16 => compute_op!(array, Int16Array, i16), + DataType::Int32 => compute_op!(array, Int32Array, i32), + DataType::Int64 => compute_op!(array, Int64Array, i64), _ => Err(DataFusionError::Execution(format!( "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed int", self, @@ -125,7 +111,7 @@ impl PhysicalExpr for BitwiseCountExpr { result.map(ColumnarValue::Array) } ColumnarValue::Scalar(_) => Err(DataFusionError::Internal( - "shouldn't go to bitwise not scalar path".to_string(), + "shouldn't go to bitwise count scalar path".to_string(), )), } } @@ -146,11 +132,54 @@ pub fn bitwise_count(arg: Arc) -> Result Ok(Arc::new(BitwiseCountExpr::new(arg))) } +// Here’s the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType) +fn bit_count(i: i64) -> i32 { + let mut u = i as u64; + u = u - ((u >> 1) & 0x5555555555555555); + u = (u & 0x3333333333333333) + ((u >> 2) & 0x3333333333333333); + u = (u + (u >> 4)) & 0x0f0f0f0f0f0f0f0f; + u = u + (u >> 8); + u = u + (u >> 16); + u = u + (u >> 32); + (u as i32) & 0x7f +} + #[cfg(test)] mod tests { + use arrow::datatypes::*; + use datafusion::common::{cast::as_int32_array, Result}; + use datafusion::physical_expr::expressions::col; + + use super::*; #[test] - fn bitwise_count_op() -> datafusion::common::Result<()> { + fn bitwise_count_op() -> Result<()> { + let schema = Schema::new(vec![Field::new("field", DataType::Int32, true)]); + + let expr = bitwise_count(col("field", &schema)?)?; + + let input = Int32Array::from(vec![ + Some(1), + None, + Some(12345), + Some(89), + Some(-3456), + ]); + + let expected = &Int32Array::from(vec![ + Some(1), + None, + Some(6), + Some(4), + Some(54), + ]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_int32_array(&result).expect("failed to downcast to In32Array"); + assert_eq!(result, expected); + Ok(()) } } \ No newline at end of file diff --git a/spark/src/test/scala/org/apache/comet/CometBitwiseCountSuite.scala b/spark/src/test/scala/org/apache/comet/CometBitwiseCountSuite.scala deleted file mode 100644 index fbb2c2e87c..0000000000 --- a/spark/src/test/scala/org/apache/comet/CometBitwiseCountSuite.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet - -import org.apache.spark.sql.CometTestBase - -class CometBitwiseCountSuite extends CometTestBase { - - test("bitwise_count") { - Seq(false, true).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "bitwise_count_test" - withTable(table) { - sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet") - sql(s"insert into $table values(1111, 2222, 17, 7)") - - checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table")) - checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table")) - checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table")) - checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table")) - // checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col5) FROM $table")) - } - } - } - } -} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 0e7bc46837..0fc6ea35d0 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -95,11 +95,19 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { val table = "bitwise_count_test" withTable(table) { - sql(s"create table $table(col1 long) using parquet") - sql(s"insert into $table values(1111)") - sql(s"insert into $table values(1111)") - - checkSparkAnswer(sql(s"SELECT bit_count(col1) FROM $table")) + sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet") + sql(s"insert into $table values(1111, 2222, 17, 7)") + sql( + s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, ${Short.MaxValue}, ${Byte.MaxValue})") + sql( + s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, ${Short.MinValue}, ${Byte.MinValue})") + + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM $table")) } } } From 2e8acf81bdc04dabe12fdbdbaada14477e41e05e Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Wed, 2 Apr 2025 22:17:27 +0400 Subject: [PATCH 03/11] fix fmt --- native/core/src/execution/planner.rs | 8 +++---- .../src/bitwise_funcs/bitwise_count.rs | 21 ++++--------------- native/spark-expr/src/bitwise_funcs/mod.rs | 4 ++-- 3 files changed, 10 insertions(+), 23 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 19cb92a443..ecdbc6c2c6 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -103,10 +103,10 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, BitwiseCountExpr, BitwiseNotExpr, Cast, CheckOverflow, Contains, Correlation, - Covariance, CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField, - HourExpr, IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr, - SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, + ArrayInsert, Avg, AvgDecimal, BitwiseCountExpr, BitwiseNotExpr, Cast, CheckOverflow, Contains, + Correlation, Covariance, CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, + GetStructField, HourExpr, IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, + SecondExpr, SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance, }; use itertools::Itertools; diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs index 9a43ab1f19..9566477b26 100644 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -35,7 +35,7 @@ macro_rules! compute_op { let result: $DT = operand .iter() - .map(|x| { x.map(|y| { bit_count(y.into()) as $TY })}) + .map(|x| x.map(|y| bit_count(y.into()) as $TY)) .collect(); Ok(Arc::new(result)) @@ -158,21 +158,8 @@ mod tests { let expr = bitwise_count(col("field", &schema)?)?; - let input = Int32Array::from(vec![ - Some(1), - None, - Some(12345), - Some(89), - Some(-3456), - ]); - - let expected = &Int32Array::from(vec![ - Some(1), - None, - Some(6), - Some(4), - Some(54), - ]); + let input = Int32Array::from(vec![Some(1), None, Some(12345), Some(89), Some(-3456)]); + let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; @@ -182,4 +169,4 @@ mod tests { Ok(()) } -} \ No newline at end of file +} diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs index 4c75501e51..d04cfb832f 100644 --- a/native/spark-expr/src/bitwise_funcs/mod.rs +++ b/native/spark-expr/src/bitwise_funcs/mod.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -mod bitwise_not; mod bitwise_count; +mod bitwise_not; -pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; pub use bitwise_count::{bitwise_count, BitwiseCountExpr}; +pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; From 850510aa63a3c42e4cfa535ef591cb8857454058 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Mon, 28 Apr 2025 21:15:48 +0400 Subject: [PATCH 04/11] add tests with random values --- .../src/bitwise_funcs/bitwise_count.rs | 23 +++++++++------ .../apache/comet/CometExpressionSuite.scala | 29 ++++++++++++++++++- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs index 9566477b26..4a202a975d 100644 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -23,19 +23,20 @@ use arrow::{ use datafusion::common::Result; use datafusion::physical_expr::PhysicalExpr; use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; +use std::fmt::Formatter; use std::hash::Hash; use std::{any::Any, sync::Arc}; macro_rules! compute_op { - ($OPERAND:expr, $DT:ident, $TY:ty) => {{ + ($OPERAND:expr, $DT:ident) => {{ let operand = $OPERAND .as_any() .downcast_ref::<$DT>() .expect("compute_op failed to downcast array"); - let result: $DT = operand + let result: Int32Array = operand .iter() - .map(|x| x.map(|y| bit_count(y.into()) as $TY)) + .map(|x| x.map(|y| bit_count(y.into()))) .collect(); Ok(Arc::new(result)) @@ -85,8 +86,8 @@ impl PhysicalExpr for BitwiseCountExpr { self } - fn data_type(&self, input_schema: &Schema) -> Result { - self.arg.data_type(input_schema) + fn data_type(&self, _: &Schema) -> Result { + Ok(DataType::Int32) } fn nullable(&self, input_schema: &Schema) -> Result { @@ -98,10 +99,10 @@ impl PhysicalExpr for BitwiseCountExpr { match arg { ColumnarValue::Array(array) => { let result: Result = match array.data_type() { - DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array, i8), - DataType::Int16 => compute_op!(array, Int16Array, i16), - DataType::Int32 => compute_op!(array, Int32Array, i32), - DataType::Int64 => compute_op!(array, Int64Array, i64), + DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array), + DataType::Int16 => compute_op!(array, Int16Array), + DataType::Int32 => compute_op!(array, Int32Array), + DataType::Int64 => compute_op!(array, Int64Array), _ => Err(DataFusionError::Execution(format!( "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed int", self, @@ -126,6 +127,10 @@ impl PhysicalExpr for BitwiseCountExpr { ) -> Result> { Ok(Arc::new(BitwiseCountExpr::new(Arc::clone(&children[0])))) } + + fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } } pub fn bitwise_count(arg: Arc) -> Result> { diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 0fc6ea35d0..ba6182e672 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE import org.apache.spark.sql.types.{Decimal, DecimalType} import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus +import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ @@ -90,7 +91,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("bitwise_count") { + test("bitwise_count - min/max values") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { val table = "bitwise_count_test" @@ -113,6 +114,32 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("bitwise_count - random values") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + val filename = path.toString + val random = new Random(42) + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + ParquetGenerator.makeParquetFile( + random, + spark, + filename, + 10, + DataGenOptions( + allowNull = true, + generateNegativeZero = true, + generateArray = false, + generateStruct = false, + generateMap = false)) + } + val table = spark.read.parquet(filename) + val df = + table.selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", "bit_count(c4)") + + checkSparkAnswerAndOperator(df) + } + } + test("bitwise shift with different left/right types") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { From bb1bb5fbab815b25e023a1db268aba00c4e1954a Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Thu, 1 May 2025 18:09:34 +0400 Subject: [PATCH 05/11] Fix PR issues --- native/core/src/execution/planner.rs | 12 +- native/proto/src/proto/expr.proto | 1 - .../src/bitwise_funcs/bitwise_count.rs | 140 +++++------------- native/spark-expr/src/bitwise_funcs/mod.rs | 2 +- native/spark-expr/src/comet_scalar_funcs.rs | 11 +- .../apache/comet/serde/QueryPlanSerde.scala | 10 +- 6 files changed, 50 insertions(+), 126 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 98604fb26c..56fe029302 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -102,10 +102,10 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, BitwiseCountExpr, BitwiseNotExpr, Cast, CheckOverflow, Contains, - Correlation, Covariance, CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, - GetStructField, HourExpr, IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, - SecondExpr, SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, + ArrayInsert, Avg, AvgDecimal, BitwiseNotExpr, Cast, CheckOverflow, Contains, Correlation, + Covariance, CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField, + HourExpr, IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr, + SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance, }; use itertools::Itertools; @@ -609,10 +609,6 @@ impl PhysicalPlanner { let op = DataFusionOperator::BitwiseShiftLeft; Ok(Arc::new(BinaryExpr::new(left, op, right))) } - ExprStruct::BitwiseCount(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; - Ok(Arc::new(BitwiseCountExpr::new(child))) - } // https://github.com/apache/datafusion-comet/issues/666 // ExprStruct::Abs(expr) => { // let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 634900a089..90fd08948c 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -84,7 +84,6 @@ message Expr { GetArrayStructFields get_array_struct_fields = 57; ArrayInsert array_insert = 58; MathExpr integral_divide = 59; - UnaryExpr bitwiseCount = 60; } } diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs index 4a202a975d..d402e406ec 100644 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -15,17 +15,10 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::*, - datatypes::{DataType, Schema}, - record_batch::RecordBatch, -}; +use arrow::{array::*, datatypes::DataType}; use datafusion::common::Result; -use datafusion::physical_expr::PhysicalExpr; use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; -use std::fmt::Formatter; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; macro_rules! compute_op { ($OPERAND:expr, $DT:ident) => {{ @@ -43,100 +36,32 @@ macro_rules! compute_op { }}; } -/// BitwiseCount expression -#[derive(Debug, Eq)] -pub struct BitwiseCountExpr { - /// Input expression - arg: Arc, -} - -impl Hash for BitwiseCountExpr { - fn hash(&self, state: &mut H) { - self.arg.hash(state); - } -} - -impl PartialEq for BitwiseCountExpr { - fn eq(&self, other: &Self) -> bool { - self.arg.eq(&other.arg) - } -} - -impl BitwiseCountExpr { - /// Create new bitwise count expression - pub fn new(arg: Arc) -> Self { - Self { arg } - } - - /// Get the input expression - pub fn arg(&self) -> &Arc { - &self.arg - } -} - -impl std::fmt::Display for BitwiseCountExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "(~ {})", self.arg) - } -} - -impl PhysicalExpr for BitwiseCountExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self +pub fn spark_bit_count(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return Err(DataFusionError::Internal( + "hex expects exactly one argument".to_string(), + )); } - - fn data_type(&self, _: &Schema) -> Result { - Ok(DataType::Int32) - } - - fn nullable(&self, input_schema: &Schema) -> Result { - self.arg.nullable(input_schema) - } - - fn evaluate(&self, batch: &RecordBatch) -> Result { - let arg = self.arg.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let result: Result = match array.data_type() { - DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array), - DataType::Int16 => compute_op!(array, Int16Array), - DataType::Int32 => compute_op!(array, Int32Array), - DataType::Int64 => compute_op!(array, Int64Array), - _ => Err(DataFusionError::Execution(format!( - "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed int", - self, - array.data_type(), - ))), - }; - result.map(ColumnarValue::Array) - } - ColumnarValue::Scalar(_) => Err(DataFusionError::Internal( - "shouldn't go to bitwise count scalar path".to_string(), - )), + match &args[0] { + ColumnarValue::Array(array) => { + let result: Result = match array.data_type() { + DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array), + DataType::Int16 => compute_op!(array, Int16Array), + DataType::Int32 => compute_op!(array, Int32Array), + DataType::Int64 => compute_op!(array, Int64Array), + _ => Err(DataFusionError::Execution(format!( + "Can't be evaluated because the expression's type is {:?}, not signed int", + array.data_type(), + ))), + }; + result.map(ColumnarValue::Array) } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.arg] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(BitwiseCountExpr::new(Arc::clone(&children[0])))) - } - - fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result { - unimplemented!() + ColumnarValue::Scalar(_) => Err(DataFusionError::Internal( + "shouldn't go to bitwise count scalar path".to_string(), + )), } } -pub fn bitwise_count(arg: Arc) -> Result> { - Ok(Arc::new(BitwiseCountExpr::new(arg))) -} - // Here’s the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType) fn bit_count(i: i64) -> i32 { let mut u = i as u64; @@ -151,24 +76,25 @@ fn bit_count(i: i64) -> i32 { #[cfg(test)] mod tests { - use arrow::datatypes::*; use datafusion::common::{cast::as_int32_array, Result}; - use datafusion::physical_expr::expressions::col; use super::*; #[test] fn bitwise_count_op() -> Result<()> { - let schema = Schema::new(vec![Field::new("field", DataType::Int32, true)]); - - let expr = bitwise_count(col("field", &schema)?)?; - - let input = Int32Array::from(vec![Some(1), None, Some(12345), Some(89), Some(-3456)]); + let args = vec![ColumnarValue::Array(Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(12345), + Some(89), + Some(-3456), + ])))]; let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; + let ColumnarValue::Array(result) = spark_bit_count(&args)? else { + unreachable!() + }; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; let result = as_int32_array(&result).expect("failed to downcast to In32Array"); assert_eq!(result, expected); diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs index d04cfb832f..718cfc7ca8 100644 --- a/native/spark-expr/src/bitwise_funcs/mod.rs +++ b/native/spark-expr/src/bitwise_funcs/mod.rs @@ -18,5 +18,5 @@ mod bitwise_count; mod bitwise_not; -pub use bitwise_count::{bitwise_count, BitwiseCountExpr}; +pub use bitwise_count::spark_bit_count; pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index f954fdd8ce..e85247548c 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -17,9 +17,10 @@ use crate::hash_funcs::*; use crate::{ - spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_decimal_integral_div, - spark_floor, spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, - spark_rpad, spark_unhex, spark_unscaled_value, SparkChrFunc, + spark_bit_count, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, + spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal, + spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, + SparkChrFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -140,6 +141,10 @@ pub fn create_comet_physical_fun( let func = Arc::new(spark_date_sub); make_comet_scalar_udf!("date_sub", func, without data_type) } + "bit_count" => { + let func = Arc::new(spark_bit_count); + make_comet_scalar_udf!("bit_count", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index cd78b61002..075fef45ec 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1654,12 +1654,10 @@ object QueryPlanSerde extends Logging with CometExprShim { (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr)) case BitwiseCount(child) => - createUnaryExpr( - expr, - child, - inputs, - binding, - (builder, unaryExpr) => builder.setBitwiseCount(unaryExpr)) + val childProto = exprToProto(child, inputs, binding) + val bitCountScalarExpr = + scalarExprToProtoWithReturnType("bit_count", IntegerType, childProto) + optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*) case ShiftRight(left, right) => // DataFusion bitwise shift right expression requires From 50bf80c4fe9f60bc57fab5d9b58d1c9d971ac561 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Thu, 1 May 2025 18:15:27 +0400 Subject: [PATCH 06/11] Resolve conflicts --- native/spark-expr/src/comet_scalar_funcs.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 40fa0d6466..f852060008 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -17,10 +17,10 @@ use crate::hash_funcs::*; use crate::{ - spark_array_repeat, spark_bit_count, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, - spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal, - spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, - SparkChrFunc, + spark_array_repeat, spark_bit_count, spark_ceil, spark_date_add, spark_date_sub, + spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, + spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, + spark_unscaled_value, SparkChrFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; From c221cc4d241d7d53de6415d0273b2ab5ae2193db Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Thu, 1 May 2025 18:39:14 +0400 Subject: [PATCH 07/11] Resolve conflicts --- .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 87337bf96e..f4947706a3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1658,7 +1658,7 @@ object QueryPlanSerde extends Logging with CometExprShim { case BitwiseCount(child) => val childProto = exprToProto(child, inputs, binding) val bitCountScalarExpr = - scalarExprToProtoWithReturnType("bit_count", IntegerType, childProto) + scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto) optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*) case ShiftRight(left, right) => From 9c9b49a7844f20c6e38f411c975ba01563fe667b Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Thu, 1 May 2025 21:34:23 +0400 Subject: [PATCH 08/11] Micro refactoring --- native/spark-expr/src/bitwise_funcs/bitwise_count.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs index d402e406ec..d77105f81e 100644 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -39,7 +39,7 @@ macro_rules! compute_op { pub fn spark_bit_count(args: &[ColumnarValue]) -> Result { if args.len() != 1 { return Err(DataFusionError::Internal( - "hex expects exactly one argument".to_string(), + "bit_count expects exactly one argument".to_string(), )); } match &args[0] { @@ -57,7 +57,7 @@ pub fn spark_bit_count(args: &[ColumnarValue]) -> Result { result.map(ColumnarValue::Array) } ColumnarValue::Scalar(_) => Err(DataFusionError::Internal( - "shouldn't go to bitwise count scalar path".to_string(), + "shouldn't go to bit_count scalar path".to_string(), )), } } From e8eea56d4bb7bc02a030fb646762ed49b96d8f4a Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Wed, 14 May 2025 09:47:16 +0400 Subject: [PATCH 09/11] Fix PR issues --- .../src/bitwise_funcs/bitwise_count.rs | 10 +++++---- .../apache/comet/CometExpressionSuite.scala | 21 ++++++++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs index d77105f81e..f0a1b00737 100644 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -22,10 +22,12 @@ use std::sync::Arc; macro_rules! compute_op { ($OPERAND:expr, $DT:ident) => {{ - let operand = $OPERAND - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast array"); + let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| { + DataFusionError::Execution(format!( + "compute_op failed to downcast array to: {:?}", + stringify!($DT) + )) + })?; let result: Int32Array = operand .iter() diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index d775dfab6e..f6bf4c32ce 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -114,7 +114,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("bitwise_count - random values") { + test("bitwise_count - random values (spark gen)") { withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") val filename = path.toString @@ -140,6 +140,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("bitwise_count - random values (native parquet gen)") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false) + val table = spark.read.parquet(path.toString) + checkSparkAnswerAndOperator( + table + .selectExpr( + s"bit_count(_2)", + s"bit_count(_3)", + s"bit_count(_4)", + s"bit_count(_5)", + s"bit_count(_10)", + s"bit_count(_11)")) + } + } + } + test("bitwise shift with different left/right types") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { From 6c09a10ac3781b2c16eaab8532c588fbb35644f4 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Wed, 14 May 2025 16:39:25 +0400 Subject: [PATCH 10/11] Fix format --- .../org/apache/comet/CometExpressionSuite.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index f6bf4c32ce..b9c699cc60 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -149,12 +149,12 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswerAndOperator( table .selectExpr( - s"bit_count(_2)", - s"bit_count(_3)", - s"bit_count(_4)", - s"bit_count(_5)", - s"bit_count(_10)", - s"bit_count(_11)")) + "bit_count(_2)", + "bit_count(_3)", + "bit_count(_4)", + "bit_count(_5)", + "bit_count(_10)", + "bit_count(_11)")) } } } From 8e65528419ce17f87e5b218e9d756229dcc1de02 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Fri, 23 May 2025 21:41:43 +0400 Subject: [PATCH 11/11] Fix tests --- spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index b9c699cc60..7fc22357ec 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -153,7 +153,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { "bit_count(_3)", "bit_count(_4)", "bit_count(_5)", - "bit_count(_10)", "bit_count(_11)")) } }