Commit 53c724e

Feat: support bit_count function (#1602)
## Which issue does this PR close?

Related to Epic: #240

bit_count: `SELECT bit_count(0)` => 0

DataFusion Comet's bit_count has the same behavior as Spark's bit_count function.
Spark: https://spark.apache.org/docs/latest/api/sql/index.html#bit_count

Closes #.

## Rationale for this change

Defined under Epic: #240

## What changes are included in this PR?

- bitwise_count.rs: implementation of the bit_count function
- planner.rs: maps Spark's bit_count expression to the DataFusion Comet bit_count physical expression
- expr.proto: bit_count has been added
- QueryPlanSerde.scala: a bit_count pattern-matching case has been added
- CometExpressionSuite.scala: a new unit test has been added for the bit_count function

## How are these changes tested?

A new unit test has been added.
1 parent 2ce969e commit 53c724e
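As a quick illustration of the intended semantics (a standalone sketch, not part of this commit; the helper name `spark_style_bit_count` is hypothetical), Spark's bit_count for integral types is the population count of the 64-bit sign-extended value, as computed by java.lang.Long.bitCount. The expected values below come from this PR's description and unit test.

```rust
// Standalone sketch (not part of this commit): the semantics bit_count is
// expected to match, i.e. the number of set bits in the 64-bit sign-extended value.
fn spark_style_bit_count(v: i64) -> i32 {
    v.count_ones() as i32
}

fn main() {
    assert_eq!(spark_style_bit_count(0), 0); // SELECT bit_count(0) => 0
    assert_eq!(spark_style_bit_count(1), 1);
    assert_eq!(spark_style_bit_count(12345), 6);
    assert_eq!(spark_style_bit_count(-3456), 54); // negatives are sign-extended to 64 bits
    println!("bit_count semantics check passed");
}
```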

File tree: 5 files changed, +189 −4 lines
native/spark-expr/src/bitwise_funcs/bitwise_count.rs

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::{array::*, datatypes::DataType};
use datafusion::common::Result;
use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
use std::sync::Arc;

macro_rules! compute_op {
    ($OPERAND:expr, $DT:ident) => {{
        let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| {
            DataFusionError::Execution(format!(
                "compute_op failed to downcast array to: {:?}",
                stringify!($DT)
            ))
        })?;

        let result: Int32Array = operand
            .iter()
            .map(|x| x.map(|y| bit_count(y.into())))
            .collect();

        Ok(Arc::new(result))
    }};
}

pub fn spark_bit_count(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    if args.len() != 1 {
        return Err(DataFusionError::Internal(
            "bit_count expects exactly one argument".to_string(),
        ));
    }
    match &args[0] {
        ColumnarValue::Array(array) => {
            let result: Result<ArrayRef> = match array.data_type() {
                DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array),
                DataType::Int16 => compute_op!(array, Int16Array),
                DataType::Int32 => compute_op!(array, Int32Array),
                DataType::Int64 => compute_op!(array, Int64Array),
                _ => Err(DataFusionError::Execution(format!(
                    "Can't be evaluated because the expression's type is {:?}, not signed int",
                    array.data_type(),
                ))),
            };
            result.map(ColumnarValue::Array)
        }
        ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
            "shouldn't go to bit_count scalar path".to_string(),
        )),
    }
}

// Here's the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType)
fn bit_count(i: i64) -> i32 {
    let mut u = i as u64;
    u = u - ((u >> 1) & 0x5555555555555555);
    u = (u & 0x3333333333333333) + ((u >> 2) & 0x3333333333333333);
    u = (u + (u >> 4)) & 0x0f0f0f0f0f0f0f0f;
    u = u + (u >> 8);
    u = u + (u >> 16);
    u = u + (u >> 32);
    (u as i32) & 0x7f
}

#[cfg(test)]
mod tests {
    use datafusion::common::{cast::as_int32_array, Result};

    use super::*;

    #[test]
    fn bitwise_count_op() -> Result<()> {
        let args = vec![ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(12345),
            Some(89),
            Some(-3456),
        ])))];
        let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]);

        let ColumnarValue::Array(result) = spark_bit_count(&args)? else {
            unreachable!()
        };

        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
        assert_eq!(result, expected);

        Ok(())
    }
}
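For reference, a minimal self-contained check (not part of the diff; assumed to be compiled as its own binary) that reproduces the SWAR popcount above and compares it against Rust's built-in `i64::count_ones` on representative values, including the min/max boundaries exercised by the Scala test later in this commit.

```rust
// Standalone sanity check, not part of this commit.
fn bit_count(i: i64) -> i32 {
    let mut u = i as u64;
    u = u - ((u >> 1) & 0x5555555555555555);
    u = (u & 0x3333333333333333) + ((u >> 2) & 0x3333333333333333);
    u = (u + (u >> 4)) & 0x0f0f0f0f0f0f0f0f;
    u = u + (u >> 8);
    u = u + (u >> 16);
    u = u + (u >> 32);
    (u as i32) & 0x7f
}

fn main() {
    for v in [0i64, 1, -1, 89, 12345, -3456, i64::MIN, i64::MAX] {
        assert_eq!(bit_count(v), v.count_ones() as i32, "mismatch for {v}");
    }
    println!("SWAR bit_count agrees with i64::count_ones");
}
```

This is the same portable bit-twiddling trick used by java.lang.Long.bitCount, which is why the results line up with Spark across the full i64 range.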

native/spark-expr/src/bitwise_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod bitwise_count;
 mod bitwise_not;
 
+pub use bitwise_count::spark_bit_count;
 pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
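For context, a hedged sketch of how the new re-export might be consumed from outside the module; the external crate name `datafusion_comet_spark_expr` is an assumption for illustration and is not taken from this commit.

```rust
// Hypothetical consumer of the re-exported function (crate name assumed).
use std::sync::Arc;

use arrow::array::Int64Array;
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::spark_bit_count; // assumed crate/lib name

fn example() -> datafusion::common::Result<ColumnarValue> {
    let arg = ColumnarValue::Array(Arc::new(Int64Array::from(vec![Some(0), Some(-1), None])));
    // Expected output: an Int32 array containing [0, 64, NULL].
    spark_bit_count(&[arg])
}
```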

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 8 additions & 4 deletions
@@ -17,10 +17,10 @@
 
 use crate::hash_funcs::*;
 use crate::{
-    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
-    SparkChrFunc,
+    spark_array_repeat, spark_bit_count, spark_ceil, spark_date_add, spark_date_sub,
+    spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan,
+    spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
+    spark_unscaled_value, SparkChrFunc,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -145,6 +145,10 @@ pub fn create_comet_physical_fun(
            let func = Arc::new(spark_array_repeat);
            make_comet_scalar_udf!("array_repeat", func, without data_type)
        }
+        "bit_count" => {
+            let func = Arc::new(spark_bit_count);
+            make_comet_scalar_udf!("bit_count", func, without data_type)
+        }
        _ => registry.udf(fun_name).map_err(|e| {
            DataFusionError::Execution(format!(
                "Function {fun_name} not found in the registry: {e}",

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 6 additions & 0 deletions
@@ -1634,6 +1634,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
           binding,
           (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
 
+      case BitwiseCount(child) =>
+        val childProto = exprToProto(child, inputs, binding)
+        val bitCountScalarExpr =
+          scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto)
+        optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)
+
       case ShiftRight(left, right) =>
         // DataFusion bitwise shift right expression requires
         // same data type between left and right side

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 68 additions & 0 deletions
@@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.sql.types.{Decimal, DecimalType}
 
 import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
@@ -99,6 +100,73 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("bitwise_count - min/max values") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "bitwise_count_test"
+        withTable(table) {
+          sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
+          sql(s"insert into $table values(1111, 2222, 17, 7)")
+          sql(
+            s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, ${Short.MaxValue}, ${Byte.MaxValue})")
+          sql(
+            s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, ${Short.MinValue}, ${Byte.MinValue})")
+
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM $table"))
+        }
+      }
+    }
+  }
+
+  test("bitwise_count - random values (spark gen)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          10,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      val df =
+        table.selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", "bit_count(c4)")
+
+      checkSparkAnswerAndOperator(df)
+    }
+  }
+
+  test("bitwise_count - random values (native parquet gen)") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false)
+        val table = spark.read.parquet(path.toString)
+        checkSparkAnswerAndOperator(
+          table
+            .selectExpr(
+              "bit_count(_2)",
+              "bit_count(_3)",
+              "bit_count(_4)",
+              "bit_count(_5)",
+              "bit_count(_11)"))
+      }
+    }
+  }
+
   test("bitwise shift with different left/right types") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
