From aab97cbfa405189b97186c1029bdb1620c95289f Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Tue, 6 Aug 2024 09:59:03 -0500 Subject: [PATCH 1/4] update rust deps --- Cargo.lock | 164 +++++++++++++++++++++++++++++------------------------ Cargo.toml | 10 ++-- 2 files changed, 95 insertions(+), 79 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c41ef771a..e59811210 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6127ea5e585a12ec9f742232442828ebaf264dfa5eefdd71282376c599562b77" +checksum = "05048a8932648b63f21c37d88b552ccc8a65afb6dfe9fc9f30ce79174c2e7a85" dependencies = [ "arrow-arith", "arrow-array", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7add7f39210b7d726e2a8efc0083e7bf06e8f2d15bdb4896b564dce4410fbf5d" +checksum = "1d8a57966e43bfe9a3277984a14c24ec617ad874e4c0e1d2a1b083a39cfbf22c" dependencies = [ "arrow-array", "arrow-buffer", @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81c16ec702d3898c2f5cfdc148443c6cd7dbe5bac28399859eb0a3d38f072827" +checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" dependencies = [ "ahash", "arrow-buffer", @@ -184,9 +184,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cae6970bab043c4fbc10aee1660ceb5b306d0c42c8cc5f6ae564efcd9759b663" +checksum = "c975484888fc95ec4a632cdc98be39c085b1bb518531b0c80c5d462063e5daa1" dependencies = [ "bytes", "half", @@ -195,9 +195,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c7ef44f26ef4f8edc392a048324ed5d757ad09135eff6d5509e6450d39e0398" +checksum = "da26719e76b81d8bc3faad1d4dbdc1bcc10d14704e63dc17fc9f3e7e1e567c8e" dependencies = [ "arrow-array", "arrow-buffer", @@ -216,9 +216,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f843490bd258c5182b66e888161bb6f198f49f3792f7c7f98198b924ae0f564" +checksum = "c13c36dc5ddf8c128df19bab27898eea64bf9da2b555ec1cd17a8ff57fba9ec2" dependencies = [ "arrow-array", "arrow-buffer", @@ -235,9 +235,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a769666ffac256dd301006faca1ca553d0ae7cffcf4cd07095f73f95eb226514" +checksum = "dd9d6f18c65ef7a2573ab498c374d8ae364b4a4edf67105357491c031f716ca5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -247,9 +247,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbf9c3fb57390a1af0b7bb3b5558c1ee1f63905f3eccf49ae7676a8d1e6e5a72" +checksum = "e786e1cdd952205d9a8afc69397b317cfbb6e0095e445c69cda7e8da5c1eeb0f" dependencies = [ "arrow-array", "arrow-buffer", @@ -262,9 +262,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "654e7f3724176b66ddfacba31af397c48e106fbe4d281c8144e7d237df5acfd7" +checksum = "fb22284c5a2a01d73cebfd88a33511a3234ab45d66086b2ca2d1228c3498e445" dependencies = [ "arrow-array", "arrow-buffer", @@ -282,9 +282,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8008370e624e8e3c68174faaf793540287106cfda8ad1da862fdc53d8e096b4" +checksum = "42745f86b1ab99ef96d1c0bcf49180848a64fe2c7a7a0d945bc64fa2b21ba9bc" dependencies = [ "arrow-array", "arrow-buffer", @@ -297,9 +297,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca5e3a6b7fda8d9fe03f3b18a2d946354ea7f3c8e4076dbdb502ad50d9d44824" +checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c" dependencies = [ "ahash", "arrow-array", @@ -307,23 +307,22 @@ dependencies = [ "arrow-data", "arrow-schema", "half", - "hashbrown", ] [[package]] name = "arrow-schema" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dab1c12b40e29d9f3b699e0203c2a73ba558444c05e388a4377208f8f9c97eee" +checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" dependencies = [ "bitflags 2.6.0", ] [[package]] name = "arrow-select" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e80159088ffe8c48965cb9b1a7c968b2729f29f37363df7eca177fc3281fe7c3" +checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3" dependencies = [ "ahash", "arrow-array", @@ -335,9 +334,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fd04a6ea7de183648edbcb7a6dd925bbd04c210895f6384c780e27a9b54afcd" +checksum = "f0dc1985b67cb45f6606a248ac2b4a288849f196bab8c657ea5589f47cdd55e6" dependencies = [ "arrow-array", "arrow-buffer", @@ -365,7 +364,7 @@ dependencies = [ "tokio", "xz2", "zstd 0.13.2", - "zstd-safe 7.2.0", + "zstd-safe 7.2.1", ] [[package]] @@ -516,9 +515,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.6.1" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" +checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "bzip2" @@ -543,9 +542,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" +checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" dependencies = [ "jobserver", "libc", @@ -1163,9 +1162,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "7f211bbe8e69bbd0cfdea405084f128ae8b4aaa6b0b522fc8f2b009084797920" dependencies = [ "crc32fast", "miniz_oxide", @@ -1458,9 +1457,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-channel", @@ -1511,9 +1510,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0" dependencies = [ "equivalent", "hashbrown", @@ -1995,9 +1994,9 @@ dependencies = [ [[package]] name = "parquet" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f22ba0d95db56dde8685e3fadcb915cdaadda31ab8abbe3ff7f0ad1ef333267" +checksum = "e977b9066b4d3b03555c22bdc442f3fadebd96a39111249113087d0edb2691cd" dependencies = [ "ahash", "arrow-array", @@ -2181,9 +2180,12 @@ checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "prettyplease" @@ -2347,9 +2349,9 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4ceeeeabace7857413798eb1ffa1e9c905a9946a57d81fb69b4b71c4d8eb3ad" +checksum = "b22d8e7369034b9a7132bc2008cac12f2013c8132b45e0554e6e20e2617f2156" dependencies = [ "bytes", "pin-project-lite", @@ -2357,6 +2359,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", + "socket2", "thiserror", "tokio", "tracing", @@ -2364,9 +2367,9 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.3" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe" +checksum = "ba92fb39ec7ad06ca2582c0ca834dfeadcaf06ddfc8e635c80aa7e1c05315fdd" dependencies = [ "bytes", "rand", @@ -2388,6 +2391,7 @@ dependencies = [ "libc", "once_cell", "socket2", + "tracing", "windows-sys 0.52.0", ] @@ -2441,9 +2445,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", @@ -2558,9 +2562,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustc-hash" -version = "1.1.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" [[package]] name = "rustc_version" @@ -2613,9 +2617,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.2" +version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" dependencies = [ "base64 0.22.1", "rustls-pki-types", @@ -2623,9 +2627,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" @@ -2769,11 +2773,12 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -3020,18 +3025,19 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" [[package]] name = "target-lexicon" -version = "0.12.15" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "b8fcd239983515c23a32fb82099f97d0b11b8c72f654ed659363a95c3dad7a53" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", "windows-sys 0.52.0", ] @@ -3093,9 +3099,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.1" +version = "1.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d040ac2b29ab03b09d4129c2f5bbd012a3ac2f79d38ff506a4bf8dd34b0eac8a" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" dependencies = [ "backtrace", "bytes", @@ -3363,9 +3369,9 @@ dependencies = [ [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "walkdir" @@ -3483,11 +3489,11 @@ dependencies = [ [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3517,6 +3523,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -3663,6 +3678,7 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] @@ -3698,7 +3714,7 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" dependencies = [ - "zstd-safe 7.2.0", + "zstd-safe 7.2.1", ] [[package]] @@ -3713,18 +3729,18 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.2.0" +version = "7.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa556e971e7b568dc775c136fc9de8c779b1c2fc3a63defaafadffdbd3181afa" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.11+zstd.1.5.6" +version = "2.0.12+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75652c55c0b6f3e6f12eb786fe1bc960396bf05a1eb3bf1f3691c3610ac2e6d4" +checksum = "0a4e40c320c3cb459d9a9ff6de98cff88f4751ee9275d140e2be94a2b74e4c13" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index d05a617a3..820118fa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ protoc = [ "datafusion-substrait/protoc" ] substrait = ["dep:datafusion-substrait"] [dependencies] -tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] } +tokio = { version = "1.39", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.8" pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] } arrow = { version = "52", feature = ["pyarrow"] } @@ -45,17 +45,17 @@ datafusion-functions-array = "40.0.0" datafusion-optimizer = "40.0.0" datafusion-sql = "40.0.0" datafusion-substrait = { version = "40.0.0", optional = true } -prost = "0.12" -prost-types = "0.12" +prost = "0.12" # keep in line with `datafusion-substrait` +prost-types = "0.12" # keep in line with `datafusion-substrait` uuid = { version = "1.9", features = ["v4"] } mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] } async-trait = "0.1" futures = "0.3" object_store = { version = "0.10.1", features = ["aws", "gcp", "azure"] } parking_lot = "0.12" -regex-syntax = "0.8.1" +regex-syntax = "0.8" syn = "2.0.68" -url = "2.2" +url = "2" sqlparser = "0.47.0" [build-dependencies] From 761ff2b869c517afbc0417acae1ab93f854269de Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Tue, 6 Aug 2024 10:15:20 -0500 Subject: [PATCH 2/4] fix: reenable num_centroids argument for approx_percentile_cont --- python/datafusion/functions.py | 7 +++---- src/functions.rs | 10 ++++++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 82b5056d7..2d3d87ee0 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1473,19 +1473,18 @@ def approx_median(arg: Expr, distinct: bool = False) -> Expr: def approx_percentile_cont( expression: Expr, percentile: Expr, + num_centroids: Expr | None = None, distinct: bool = False, ) -> Expr: """Returns the value that is approximately at a given percentile of ``expr``.""" - # Re-enable num_centroids: https://github.com/apache/datafusion-python/issues/777 - num_centroids = None if num_centroids is None: return Expr( - f.approx_percentile_cont(expression.expr, percentile.expr, distinct=distinct) + f.approx_percentile_cont(expression.expr, percentile.expr, distinct=distinct, num_centroids=None) ) return Expr( f.approx_percentile_cont( - expression.expr, percentile.expr, distinct=distinct + expression.expr, percentile.expr, distinct=distinct, num_centroids=num_centroids.expr ) ) diff --git a/src/functions.rs b/src/functions.rs index e60c63c8e..f8f478166 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -57,9 +57,15 @@ pub fn approx_percentile_cont( expression: PyExpr, percentile: PyExpr, distinct: bool, + num_centroids: Option, // enforces optional arguments at the end, currently ) -> PyResult { - let expr = - functions_aggregate::expr_fn::approx_percentile_cont(expression.expr, percentile.expr); + let args = if let Some(num_centroids) = num_centroids { + vec![expression.expr, percentile.expr, num_centroids.expr] + } else { + vec![expression.expr, percentile.expr] + }; + let udaf = functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf(); + let expr = udaf.call(args); if distinct { Ok(expr.distinct().build()?.into()) } else { From a02c9e2ff7f3b6318401530861ea97957a888a3c Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Tue, 6 Aug 2024 11:39:07 -0500 Subject: [PATCH 3/4] parametrize python aggregate tests --- python/datafusion/tests/test_aggregation.py | 111 ++++++++------------ 1 file changed, 43 insertions(+), 68 deletions(-) diff --git a/python/datafusion/tests/test_aggregation.py b/python/datafusion/tests/test_aggregation.py index c10e5f36c..2a42cb8d1 100644 --- a/python/datafusion/tests/test_aggregation.py +++ b/python/datafusion/tests/test_aggregation.py @@ -39,78 +39,53 @@ def df(): ) return ctx.create_dataframe([[batch]]) +@pytest.mark.parametrize("agg_expr, calc_expected", [ + (f.avg(column("a")), lambda a, b, c, d: np.array(np.average(a))), + (f.corr(column("a"), column("b")), lambda a, b, c, d: np.array(np.corrcoef(a, b)[0][1])), + (f.count(column("a")), lambda a, b, c, d: pa.array([len(a)])), + # Sample (co)variance -> ddof=1 + # Population (co)variance -> ddof=0 + (f.covar(column("a"), column("b")), lambda a, b, c, d: np.array(np.cov(a, b, ddof=1)[0][1])), + (f.covar_pop(column("a"), column("c")), lambda a, b, c, d: np.array(np.cov(a, c, ddof=0)[0][1])), + (f.covar_samp(column("b"), column("c")), lambda a, b, c, d: np.array(np.cov(b, c, ddof=1)[0][1])), + # f.grouping(col_a), # No physical plan implemented yet + (f.max(column("a")), lambda a, b, c, d: np.array(np.max(a))), + (f.mean(column("b")), lambda a, b, c, d: np.array(np.mean(b))), + (f.median(column("b")), lambda a, b, c, d: np.array(np.median(b))), + (f.min(column("a")), lambda a, b, c, d: np.array(np.min(a))), + (f.sum(column("b")), lambda a, b, c, d: np.array(np.sum(b.to_pylist()))), + # Sample stdev -> ddof=1 + # Population stdev -> ddof=0 + (f.stddev(column("a")), lambda a, b, c, d: np.array(np.std(a, ddof=1))), + (f.stddev_pop(column("b")), lambda a, b, c, d: np.array(np.std(b, ddof=0))), + (f.stddev_samp(column("c")), lambda a, b, c, d: np.array(np.std(c, ddof=1))), + (f.var(column("a")), lambda a, b, c, d: np.array(np.var(a, ddof=1))), + (f.var_pop(column("b")), lambda a, b, c, d: np.array(np.var(b, ddof=0))), + (f.var_samp(column("c")), lambda a, b, c, d: np.array(np.var(c, ddof=1))), +]) +def test_aggregation_stats(df, agg_expr, calc_expected): -def test_built_in_aggregation(df): - col_a = column("a") - col_b = column("b") - col_c = column("c") - - agg_df = df.aggregate( - [], - [ - f.approx_distinct(col_b), - f.approx_median(col_b), - f.approx_percentile_cont(col_b, lit(0.5)), - f.approx_percentile_cont_with_weight(col_b, lit(0.6), lit(0.5)), - f.array_agg(col_b), - f.avg(col_a), - f.corr(col_a, col_b), - f.count(col_a), - f.covar(col_a, col_b), - f.covar_pop(col_a, col_c), - f.covar_samp(col_b, col_c), - # f.grouping(col_a), # No physical plan implemented yet - f.max(col_a), - f.mean(col_b), - f.median(col_b), - f.min(col_a), - f.sum(col_b), - f.stddev(col_a), - f.stddev_pop(col_b), - f.stddev_samp(col_c), - f.var(col_a), - f.var_pop(col_b), - f.var_samp(col_c), - ], - ) + agg_df = df.aggregate([], [agg_expr]) result = agg_df.collect()[0] values_a, values_b, values_c, values_d = df.collect()[0] + expected = calc_expected(values_a, values_b, values_c, values_d) + np.testing.assert_array_almost_equal(result.column(0), expected) - assert result.column(0) == pa.array([2], type=pa.uint64()) - assert result.column(1) == pa.array([4]) - assert result.column(2) == pa.array([4]) - # Ref: https://github.com/apache/datafusion-python/issues/777 - # assert result.column(3) == pa.array([6]) - assert result.column(4) == pa.array([[4, 4, 6]]) - np.testing.assert_array_almost_equal(result.column(5), np.average(values_a)) - np.testing.assert_array_almost_equal( - result.column(6), np.corrcoef(values_a, values_b)[0][1] - ) - assert result.column(7) == pa.array([len(values_a)]) - # Sample (co)variance -> ddof=1 - # Population (co)variance -> ddof=0 - np.testing.assert_array_almost_equal( - result.column(8), np.cov(values_a, values_b, ddof=1)[0][1] - ) - np.testing.assert_array_almost_equal( - result.column(9), np.cov(values_a, values_c, ddof=0)[0][1] - ) - np.testing.assert_array_almost_equal( - result.column(10), np.cov(values_b, values_c, ddof=1)[0][1] - ) - np.testing.assert_array_almost_equal(result.column(11), np.max(values_a)) - np.testing.assert_array_almost_equal(result.column(12), np.mean(values_b)) - np.testing.assert_array_almost_equal(result.column(13), np.median(values_b)) - np.testing.assert_array_almost_equal(result.column(14), np.min(values_a)) - np.testing.assert_array_almost_equal( - result.column(15), np.sum(values_b.to_pylist()) - ) - np.testing.assert_array_almost_equal(result.column(16), np.std(values_a, ddof=1)) - np.testing.assert_array_almost_equal(result.column(17), np.std(values_b, ddof=0)) - np.testing.assert_array_almost_equal(result.column(18), np.std(values_c, ddof=1)) - np.testing.assert_array_almost_equal(result.column(19), np.var(values_a, ddof=1)) - np.testing.assert_array_almost_equal(result.column(20), np.var(values_b, ddof=0)) - np.testing.assert_array_almost_equal(result.column(21), np.var(values_c, ddof=1)) + +@pytest.mark.parametrize("agg_expr, expected", [ + (f.approx_distinct(column("b")), pa.array([2], type=pa.uint64())), + (f.approx_median(column("b")), pa.array([4])), + (f.approx_percentile_cont(column("b"), lit(0.5)), pa.array([4])), + ( + f.approx_percentile_cont_with_weight(column("b"), lit(0.6), lit(0.5)), + pa.array([6], type=pa.float64()) + ), + (f.array_agg(column("b")), pa.array([[4, 4, 6]])), +]) +def test_aggregation(df, agg_expr, expected): + agg_df = df.aggregate([], [agg_expr]) + result = agg_df.collect()[0] + assert result.column(0) == expected def test_bit_add_or_xor(df): From f217fbc79396ec3f3cdc959cd0376f101ea6afda Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Tue, 6 Aug 2024 16:02:03 -0500 Subject: [PATCH 4/4] add num_centroids test --- python/datafusion/tests/test_aggregation.py | 24 +++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/datafusion/tests/test_aggregation.py b/python/datafusion/tests/test_aggregation.py index 2a42cb8d1..03485da4b 100644 --- a/python/datafusion/tests/test_aggregation.py +++ b/python/datafusion/tests/test_aggregation.py @@ -39,6 +39,13 @@ def df(): ) return ctx.create_dataframe([[batch]]) +@pytest.fixture +def df_aggregate_100(): + ctx = SessionContext() + ctx.register_csv("aggregate_test_data", "./testing/data/csv/aggregate_test_100.csv") + return ctx.table("aggregate_test_data") + + @pytest.mark.parametrize("agg_expr, calc_expected", [ (f.avg(column("a")), lambda a, b, c, d: np.array(np.average(a))), (f.corr(column("a"), column("b")), lambda a, b, c, d: np.array(np.corrcoef(a, b)[0][1])), @@ -88,6 +95,23 @@ def test_aggregation(df, agg_expr, expected): assert result.column(0) == expected +def test_aggregate_100(df_aggregate_100): + # https://github.com/apache/datafusion/blob/bddb6415a50746d2803dd908d19c3758952d74f9/datafusion/sqllogictest/test_files/aggregate.slt#L1490-L1498 + + result = df_aggregate_100.aggregate( + [ + column("c1") + ], + [ + f.approx_percentile_cont(column("c3"), lit(0.95), lit(200)).alias("c3") + ] + ).sort(column("c1").sort(ascending=True)).collect() + + assert len(result) == 1 + result = result[0] + assert result.column("c1") == pa.array(["a", "b", "c", "d", "e"]) + assert result.column("c3") == pa.array([73, 68, 122, 124, 115]) + def test_bit_add_or_xor(df): df = df.aggregate( [],