diff --git a/.github/workflows/pr_build.yml b/.github/workflows/pr_build.yml
index ce81ea2b24..9aa12206bc 100644
--- a/.github/workflows/pr_build.yml
+++ b/.github/workflows/pr_build.yml
@@ -75,6 +75,8 @@ jobs:
           maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }}
           # upload test reports only for java 17
           upload-test-reports: ${{ matrix.java_version == '17' }}
+        env:
+          COMET_FUZZ_TEST: "true"

   linux-test-with-spark4_0:
     strategy:
@@ -102,6 +104,8 @@ jobs:
         with:
           maven_opts: -Pspark-${{ matrix.spark-version }}
           upload-test-reports: true
+        env:
+          COMET_FUZZ_TEST: "true"

   linux-test-with-old-spark:
     strategy:
@@ -127,6 +131,8 @@ jobs:
         uses: ./.github/actions/java-test
         with:
           maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }}
+        env:
+          COMET_FUZZ_TEST: "true"

   macos-test:
     strategy:
@@ -155,6 +161,8 @@ jobs:
         uses: ./.github/actions/java-test
         with:
           maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }}
+        env:
+          COMET_FUZZ_TEST: "true"

   macos-aarch64-test:
     strategy:
@@ -188,6 +196,8 @@ jobs:
         uses: ./.github/actions/java-test
         with:
           maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }}
+        env:
+          COMET_FUZZ_TEST: "true"

   macos-test-with-spark4_0:
     strategy:
@@ -212,6 +222,8 @@ jobs:
         with:
           maven_opts: -Pspark-${{ matrix.spark-version }}
           upload-test-reports: true
+        env:
+          COMET_FUZZ_TEST: "true"

   macos-aarch64-test-with-spark4_0:
     strategy:
@@ -241,6 +253,8 @@ jobs:
         with:
           maven_opts: -Pspark-${{ matrix.spark-version }}
           upload-test-reports: true
+        env:
+          COMET_FUZZ_TEST: "true"

   macos-aarch64-test-with-old-spark:
     strategy:
@@ -269,4 +283,5 @@ jobs:
         uses: ./.github/actions/java-test
         with:
           maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }}
-
+        env:
+          COMET_FUZZ_TEST: "true"
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index aa1aba11dc..9688ba5765 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -73,6 +73,29 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     false
   }

+  def isMatch(dt: DataType, at: ArgType): Boolean = {
+    at match {
+      case AnyType => true
+      case IntegralType =>
+        dt match {
+          case _: ByteType | _: ShortType | _: IntegerType | _: LongType => true
+          case _ => false
+        }
+      case NumericType =>
+        dt match {
+          case _: ByteType | _: ShortType | _: IntegerType | _: LongType => true
+          case _: FloatType | _: DoubleType => true
+          case _: DecimalType => true
+          case _ => false
+        }
+      case OrderedType =>
+        // TODO exclude map or other complex types that contain maps
+        true
+      case _ =>
+        false
+    }
+  }
+
   /**
    * Serializes Spark datatype to protobuf. Note that, a datatype can be serialized by this method
    * doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return
@@ -377,6 +400,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           return None
         }

+        // check that Comet supports the input data types, using the same logic that is leveraged
+        // in fuzz testing
+        cometExpr.getSignature() match {
+          case Fixed(dataTypes) =>
+            val argTypes = aggExpr.aggregateFunction.children.map(_.dataType)
+            if (argTypes.length != dataTypes.length) {
+              withInfo(aggExpr, "Unsupported input argument count")
+              return None
+            }
+            val supportedTypes = dataTypes.zip(argTypes).forall { case (expected, provided) =>
+              isMatch(provided, expected)
+            }
+            if (!supportedTypes) {
+              withInfo(aggExpr, "Unsupported input types")
+              return None
+            }
+          case _ =>
+        }
+
         cometExpr.convert(aggExpr, aggExpr.aggregateFunction, inputs, binding, conf)
       }
@@ -3011,6 +3053,10 @@ trait CometExpressionSerde {
    */
 trait CometAggregateExpressionSerde {

+  /** The SQL name of this aggregate function, e.g. "min" */
+  def sql(): String
+
+  /** The number and types of arguments that this aggregate function accepts */
+  def getSignature(): Signature
+
   /**
    * Convert a Spark expression into a protocol buffer representation that can be passed into
    * native code.
@@ -3040,3 +3086,40 @@ trait CometAggregateExpressionSerde {

 /** Marker trait for an expression that is not guaranteed to be 100% compatible with Spark */
 trait IncompatExpr {}
+
+/** Represents the data type(s) that an argument accepts */
+sealed trait ArgType
+
+/** Supports any input type */
+case object AnyType extends ArgType
+
+/** Integral, floating-point, and decimal */
+case object NumericType extends ArgType
+
+/** Integral types (byte, short, int, long) */
+case object IntegralType extends ArgType
+
+/** Types that can be ordered. Includes struct and array but excludes maps */
+case object OrderedType extends ArgType
+
+/*
+case class ConcreteTypes(dataTypes: Seq[DataType]) extends ArgType
+ */
+
+// Base trait for expression signatures
+trait Signature
+
+// A fixed number of arguments with specific types
+case class Fixed(types: Seq[ArgType]) extends Signature
+
+/*
+// A mix of fixed and optional arguments
+case class FixedWithOptional(fixed: Seq[ArgType], optional: Seq[ArgType]) extends Signature
+
+// A variadic signature, allowing for a range of arguments
+case class Variadic(minArgs: Option[Int], maxArgs: Option[Int], argType: ArgType)
+    extends Signature
+
+// A generic function signature that supports multiple forms
+case class Overloaded(variants: Seq[Signature]) extends Signature
+ */
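
To illustrate the `isMatch` semantics introduced above, here is a minimal sketch (assuming a Spark test or REPL context with the Comet classes on the classpath; not part of this diff) showing how each `ArgType` matches concrete Spark types:

    import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, MapType, StringType}
    import org.apache.comet.serde.{AnyType, IntegralType, NumericType}
    import org.apache.comet.serde.QueryPlanSerde.isMatch

    assert(isMatch(LongType, IntegralType)) // all integral widths match IntegralType
    assert(isMatch(DecimalType(10, 2), NumericType)) // decimals count as numeric
    assert(!isMatch(StringType, NumericType)) // strings are not numeric
    assert(isMatch(MapType(StringType, IntegerType), AnyType)) // AnyType matches everything
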
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index da5e9ff534..e89b143e89 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -33,6 +33,10 @@ import org.apache.comet.shims.ShimQueryPlanSerde

 object CometMin extends CometAggregateExpressionSerde {

+  override def sql(): String = "min"
+
+  override def getSignature(): Signature = Fixed(Seq(OrderedType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -69,6 +73,10 @@ object CometMin extends CometAggregateExpressionSerde {

 object CometMax extends CometAggregateExpressionSerde {

+  override def sql(): String = "max"
+
+  override def getSignature(): Signature = Fixed(Seq(OrderedType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -104,6 +112,11 @@ object CometMax extends CometAggregateExpressionSerde {
 }

 object CometCount extends CometAggregateExpressionSerde {
+
+  override def sql(): String = "count"
+
+  override def getSignature(): Signature = Fixed(Seq(AnyType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -127,6 +140,11 @@ object CometCount extends CometAggregateExpressionSerde {
 }

 object CometAverage extends CometAggregateExpressionSerde with ShimQueryPlanSerde {
+
+  override def sql(): String = "avg"
+
+  override def getSignature(): Signature = Fixed(Seq(NumericType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -182,6 +200,11 @@ object CometAverage extends CometAggregateExpressionSerde with ShimQueryPlanSerd
 }

 object CometSum extends CometAggregateExpressionSerde with ShimQueryPlanSerde {
+
+  override def sql(): String = "sum"
+
+  override def getSignature(): Signature = Fixed(Seq(NumericType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -226,6 +249,11 @@ object CometSum extends CometAggregateExpressionSerde with ShimQueryPlanSerde {
 }

 object CometFirst extends CometAggregateExpressionSerde {
+
+  override def sql(): String = "first"
+
+  override def getSignature(): Signature = Fixed(Seq(OrderedType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -263,6 +291,11 @@ object CometFirst extends CometAggregateExpressionSerde {
 }

 object CometLast extends CometAggregateExpressionSerde {
+
+  override def sql(): String = "last"
+
+  override def getSignature(): Signature = Fixed(Seq(OrderedType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -300,6 +333,11 @@ object CometLast extends CometAggregateExpressionSerde {
 }

 object CometBitAndAgg extends CometAggregateExpressionSerde {
+
+  override def sql(): String = "bit_and"
+
+  override def getSignature(): Signature = Fixed(Seq(IntegralType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -335,6 +373,11 @@ object CometBitAndAgg extends CometAggregateExpressionSerde {
 }

 object CometBitOrAgg extends CometAggregateExpressionSerde {
+
+  override def sql(): String = "bit_or"
+
+  override def getSignature(): Signature = Fixed(Seq(IntegralType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -370,6 +413,11 @@ object CometBitOrAgg extends CometAggregateExpressionSerde {
 }

 object CometBitXOrAgg extends CometAggregateExpressionSerde {
+
+  override def sql(): String = "bit_xor"
+
+  override def getSignature(): Signature = Fixed(Seq(IntegralType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -405,6 +453,9 @@ object CometBitXOrAgg extends CometAggregateExpressionSerde {
 }

 trait CometCovBase extends CometAggregateExpressionSerde {
+
+  override def getSignature(): Signature = Fixed(Seq(NumericType, NumericType))
+
   def convertCov(
       aggExpr: AggregateExpression,
       cov: Covariance,
@@ -438,6 +489,9 @@ trait CometCovBase extends CometAggregateExpressionSerde {
 }

 object CometCovSample extends CometCovBase {
+
+  override def sql(): String = "covar_samp"
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -457,6 +511,9 @@ object CometCovSample extends CometCovBase {
 }

 object CometCovPopulation extends CometCovBase {
+
+  override def sql(): String = "covar_pop"
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -476,6 +533,9 @@ object CometCovPopulation extends CometCovBase {
 }

 trait CometVariance extends CometAggregateExpressionSerde {
+
+  override def getSignature(): Signature = Fixed(Seq(NumericType))
+
   def convertVariance(
       aggExpr: AggregateExpression,
       expr: CentralMomentAgg,
@@ -507,6 +567,9 @@ trait CometVariance extends CometAggregateExpressionSerde {
 }

 object CometVarianceSamp extends CometVariance {
+
+  override def sql(): String = "var_samp"
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -519,6 +582,9 @@ object CometVarianceSamp extends CometVariance {
 }

 object CometVariancePop extends CometVariance {
+
+  override def sql(): String = "var_pop"
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -531,6 +597,9 @@ object CometVariancePop extends CometVariance {
 }

 trait CometStddev extends CometAggregateExpressionSerde {
+
+  override def getSignature(): Signature = Fixed(Seq(NumericType))
+
   def convertStddev(
       aggExpr: AggregateExpression,
       stddev: CentralMomentAgg,
@@ -572,6 +641,11 @@ trait CometStddev extends CometAggregateExpressionSerde {
 }

 object CometStddevSamp extends CometStddev {
+
+  override def sql(): String = "stddev_samp"
+
+  override def getSignature(): Signature = Fixed(Seq(NumericType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -591,6 +665,11 @@ object CometStddevSamp extends CometStddev {
 }

 object CometStddevPop extends CometStddev {
+
+  override def sql(): String = "stddev_pop"
+
+  override def getSignature(): Signature = Fixed(Seq(NumericType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -610,6 +689,11 @@ object CometStddevPop extends CometStddev {
 }

 object CometCorr extends CometAggregateExpressionSerde {
+
+  override def sql(): String = "corr"
+
+  override def getSignature(): Signature = Fixed(Seq(NumericType, NumericType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
@@ -642,6 +726,10 @@ object CometCorr extends CometAggregateExpressionSerde {

 object CometBloomFilterAggregate extends CometAggregateExpressionSerde {

+  override def sql(): String = "bloom_filter_agg"
+
+  override def getSignature(): Signature = Fixed(Seq(AnyType))
+
   override def convert(
       aggExpr: AggregateExpression,
       expr: Expression,
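
For illustration, a declared signature is exactly what the fuzz suite consumes when it renders a call. A minimal sketch for `CometCorr`, whose signature is `Fixed(Seq(NumericType, NumericType))` (the column names below are hypothetical; not part of this diff):

    import org.apache.comet.serde.CometCorr

    // CometCorr declares two NumericType arguments, so the fuzz suite must
    // bind two numeric columns; c4 and c5 are hypothetical column names
    val aggSql = s"${CometCorr.sql()}(c4,c5)" // renders as "corr(c4,c5)"
    val query = s"SELECT $aggSql FROM t1" // the no-grouping query shape
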
diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala
new file mode 100644
index 0000000000..b9a0e68fdc
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import scala.util.Random
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.types.StructField
+
+import org.apache.comet.CometFuzzTestSuite.aggregateExpressions
+import org.apache.comet.serde.{CometAggregateExpressionSerde, CometAverage, CometCount, CometMax, CometMin, CometStddevPop, CometStddevSamp, CometSum, CometVariancePop, CometVarianceSamp, Fixed, NumericType}
+import org.apache.comet.serde.QueryPlanSerde.isMatch
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
+
+class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
+  private val fuzzTestEnabled: Boolean = sys.env.contains("COMET_FUZZ_TEST")
+
+  test("aggregates") {
+    assume(fuzzTestEnabled)
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(random, spark, filename, 10000, DataGenOptions())
+      }
+      val table = spark.read.parquet(filename).coalesce(1)
+      table.createOrReplaceTempView("t1")
+
+      val groupingFields: Array[StructField] =
+        table.schema.fields.filterNot(f => isMatch(f.dataType, NumericType))
+
+      // test grouping by each non-numeric column, grouping by all non-numeric columns,
+      // and no grouping
+      val groupByIndividualCols: Seq[Seq[String]] = groupingFields.map(f => Seq(f.name)).toSeq
+      val groupByAllCols: Seq[Seq[String]] = Seq(groupingFields.map(_.name).toSeq)
+      val noGroup: Seq[Seq[String]] = Seq(Seq.empty)
+      val groupings: Seq[Seq[String]] = groupByIndividualCols ++ groupByAllCols ++ noGroup
+
+      val scanTypes = Seq(
+        // TODO enable ParquetExec-based scans once they are complete
+        CometConf.SCAN_NATIVE_COMET
+        /* CometConf.SCAN_NATIVE_DATAFUSION,
+           CometConf.SCAN_NATIVE_ICEBERG_COMPAT */ )
+
+      for (scan <- scanTypes) {
+        for (shuffleMode <- Seq("auto", "jvm", "native")) {
+          withSQLConf(
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> scan,
+            CometConf.COMET_SHUFFLE_MODE.key -> shuffleMode) {
+            for (group <- groupings) {
+              for (agg <- aggregateExpressions) {
+                agg.getSignature() match {
+                  case Fixed(dataTypes) =>
+                    // pick all compatible columns for all input args
+                    val argFields: Seq[Array[StructField]] = dataTypes.map(argType =>
+                      table.schema.fields.filter(f => isMatch(f.dataType, argType)))
+
+                    // TODO: just pick the first compatible column for each type for now, but
+                    // should randomize this or test all combinations
+                    val args: Seq[StructField] = argFields.map(_.head)
+
+                    val aggSql = s"${agg.sql()}(${args.map(_.name).mkString(",")})"
+
+                    val sql = if (group.isEmpty) {
+                      s"SELECT $aggSql FROM t1"
+                    } else {
+                      val groupCols = group.mkString(", ")
+                      s"SELECT $groupCols, $aggSql FROM t1 GROUP BY $groupCols ORDER BY $groupCols"
+                    }
+
+                    // helps with debugging
+                    println(sql)
+
+                    try {
+                      checkSparkAnswerAndOperatorWithTol(sql)
+                    } catch {
+                      case e: Throwable =>
+                        logError(s"Fuzz test failed for query: $sql")
+                        throw e
+                    }
+
+                  case other =>
+                    // not supported by fuzz testing yet
+                    logWarning(
+                      s"Fuzz test skipped agg '${agg.sql()}' due to unsupported signature: $other")
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+}
+
+object CometFuzzTestSuite {
+
+  /**
+   * Aggregate expressions. Note that `first` and `last` are excluded because they are
+   * non-deterministic.
+   */
+  val aggregateExpressions: Seq[CometAggregateExpressionSerde] = Seq(
+    CometMin,
+    CometMax,
+    CometSum,
+    CometAverage,
+    CometCount,
+    // TODO: CometStddev,
+    CometStddevPop,
+    CometStddevSamp,
+    // TODO: CometVariance,
+    CometVarianceSamp,
+    CometVariancePop
+    // TODO: CometCorr,
+    // TODO: CometCovSample,
+    // TODO: CometCovPopulation
+  )
+}
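
The TODO in the test above always picks the first compatible column per argument. A minimal sketch of the suggested randomization, reusing the suite's seeded `Random` (an assumption about the eventual fix, not part of this diff):

    // choose a random compatible column for each argument instead of the
    // head, so different seeds exercise different column combinations
    val args: Seq[StructField] =
      argFields.map(candidates => candidates(random.nextInt(candidates.length)))
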
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 63763aa3b8..600b5059f9 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -116,12 +116,51 @@ abstract class CometTestBase
     require(absTol > 0 && absTol <= 1e-6, s"absTol $absTol is out of range (0, 1e-6]")

     actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach {
+      case (actual: Float, expected: Float) =>
+        def isPosInfinity(value: Float): Boolean = {
+          // treat Float.MaxValue as infinity to account for a difference between Java and Rust
+          value.isPosInfinity || value == Float.MaxValue
+        }
+
+        if ((actual.isNaN && expected.isNaN) ||
+          (isPosInfinity(actual) && isPosInfinity(expected)) ||
+          (actual.isNegInfinity && expected.isNegInfinity)) {
+          // ok
+        } else {
+          def almostEqual(a: Float, b: Float, tolerance: Float): Boolean = {
+            Math.abs(a - b) <= tolerance * Math.max(Math.abs(a), Math.abs(b))
+          }
+
+          assert(
+            almostEqual(actual, expected, absTol.toFloat),
+            s"actual answer $actual not within relative tolerance $absTol of correct answer $expected")
+        }
+
       case (actual: Double, expected: Double) =>
-        if (!actual.isNaN && !expected.isNaN) {
+        def isPosInfinity(value: Double): Boolean = {
+          // treat Double.MaxValue as infinity to account for a difference between Java and Rust
+          value.isPosInfinity || value == Double.MaxValue
+        }
+
+        if ((actual.isNaN && expected.isNaN) ||
+          (isPosInfinity(actual) && isPosInfinity(expected)) ||
+          (actual.isNegInfinity && expected.isNegInfinity)) {
+          // ok
+        } else {
+          def almostEqual(a: Double, b: Double, tolerance: Double): Boolean = {
+            Math.abs(a - b) <= tolerance * Math.max(Math.abs(a), Math.abs(b))
+          }
+
           assert(
-            math.abs(actual - expected) < absTol,
+            almostEqual(actual, expected, absTol),
             s"actual answer $actual not within relative tolerance $absTol of correct answer $expected")
         }
+
+      case (actual: Array[_], expected: Array[_]) =>
+        assert(actual.sameElements(expected), s"$actualAnswer did not equal $expectedAnswer")
+
       case (actual, expected) =>
         assert(actual == expected, s"$actualAnswer did not equal $expectedAnswer")
     }
@@ -229,6 +268,27 @@ abstract class CometTestBase
     checkAnswerWithTol(dfComet, expected, absTol: Double)
   }

+  /**
+   * Check the answer of a Comet SQL query against the Spark result, using a small tolerance
+   * for floating-point values, and verify that the plan uses Comet operators.
+   */
+  protected def checkSparkAnswerAndOperatorWithTol(query: String, absTol: Double = 1e-6): Unit = {
+    checkSparkAnswerAndOperatorWithTol(sql(query), absTol)
+  }
+
+  /**
+   * Check the answer of a Comet DataFrame against the Spark result, using a small tolerance
+   * for floating-point values, and verify that the plan uses Comet operators.
+   */
+  protected def checkSparkAnswerAndOperatorWithTol(df: => DataFrame, absTol: Double): Unit = {
+    var expected: Array[Row] = Array.empty
+    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+      val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
+      expected = dfSpark.collect()
+    }
+    val dfComet = Dataset.ofRows(spark, df.logicalPlan)
+    checkAnswerWithTol(dfComet, expected, absTol)
+    checkCometOperators(stripAQEPlan(dfComet.queryExecution.executedPlan))
+  }
+
   protected def checkSparkMaybeThrows(
       df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
     var expected: Option[Throwable] = None
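
To show why the comparison above moves from an absolute to a relative tolerance, a small standalone sketch (the values are illustrative only, not from this PR's test data):

    def almostEqual(a: Double, b: Double, tolerance: Double = 1e-6): Boolean =
      Math.abs(a - b) <= tolerance * Math.max(Math.abs(a), Math.abs(b))

    // large aggregate results may differ only in low-order bits; an absolute
    // check against 1e-6 rejects them, while the relative check accepts them
    val sparkAnswer = 1.0000000001e12
    val cometAnswer = 1.0000000002e12
    assert(Math.abs(sparkAnswer - cometAnswer) > 1e-6) // absolute check would fail
    assert(almostEqual(sparkAnswer, cometAnswer)) // relative check passes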