Hi, I have a scenario where one node gets a bunch of record batches from worker nodes and then needs to aggregate them. I have working code, but I want to understand whether this is the best way to do it or whether there are other abstractions available in the DataFusion library that could reduce the boilerplate here.
My approach is as follows:
It uses DataFusion's StreamingTable to process data arriving from a Tokio MPSC channel.
A custom struct StreamPartition implements the PartitionStream trait. Its execute() method is key:
It takes a ReceiverStream<DFResult<RecordBatch>> (from the MPSC channel).
It uses RecordBatchStreamAdapter to wrap this stream, making it a SendableRecordBatchStream that DataFusion's execution engine can consume.
An instance of StreamingTable is created using:
The SchemaRef of the incoming data.
A Vec of Arc<dyn PartitionStream> wrapping our StreamPartition (one per partition, though here it's effectively a single stream source).
The resulting Arc<StreamingTable> (which implements TableProvider) is then registered with the SessionContext using ctx.register_table("table_name", table_provider).
This setup lets DataFusion query the MPSC channel as if it were a SQL table, with StreamingTable and our StreamPartition implementation handling the mechanics of pulling RecordBatches from the channel when the query executes. The condensed sketch below shows that core wiring; the full sample follows it.
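For orientation, here is a minimal sketch of just the registration path. It reuses the StreamPartition type defined in the full sample below, and the table name "streaming_sensor_data" plus the SQL query are placeholders, not part of my actual code:

```rust
use std::sync::{Arc, Mutex};

use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion::catalog::streaming::StreamingTable;
use datafusion::error::Result as DFResult;
use datafusion::physical_plan::streaming::PartitionStream;
use datafusion::prelude::*;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

/// Condensed wiring only: wrap the channel in a PartitionStream, expose it as a
/// StreamingTable, register it, and run an aggregation query against it.
/// `StreamPartition` is the wrapper struct from the full sample below.
async fn register_and_query(
    ctx: &SessionContext,
    input_schema: SchemaRef,
    rx: mpsc::Receiver<DFResult<RecordBatch>>,
) -> DFResult<DataFrame> {
    let partition: Arc<dyn PartitionStream> = Arc::new(StreamPartition {
        schema: input_schema.clone(),
        stream_mutex: Arc::new(Mutex::new(Some(ReceiverStream::new(rx)))),
    });
    let table = StreamingTable::try_new(input_schema, vec![partition])?;
    ctx.register_table("streaming_sensor_data", Arc::new(table))?;
    // Placeholder query; the full sample builds the real one from the schema.
    ctx.sql("SELECT COUNT(*) FROM \"streaming_sensor_data\"").await
}
```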
Here is the working code sample:
use std::env; // Command-line arguments
use std::fmt::Debug;
use std::fs::File;
use std::path::Path; // Path checking
use std::sync::{Arc, Mutex};
use std::time::SystemTime;

// Explicit imports from arrow sub-crates
use arrow_array::builder::{
    Float32Builder, Float64Builder, Int32Builder, Int64Builder, StringBuilder,
    TimestampNanosecondBuilder,
};
use arrow_array::RecordBatch;
use arrow_ipc::reader::StreamReader;
use arrow_ipc::writer::{IpcWriteOptions, StreamWriter};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef, TimeUnit};
use async_trait::async_trait;
use datafusion::arrow::compute::concat_batches;
use datafusion::catalog::streaming::StreamingTable;
use datafusion::error::{DataFusionError, Result as DFResult};
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::stream::{EmptyRecordBatchStream, RecordBatchStreamAdapter};
use datafusion::physical_plan::streaming::PartitionStream;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::prelude::*;
use futures::{StreamExt, TryStreamExt};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

const DEFAULT_ARROW_STREAM_FILE_PATH: &str = "input.arrow_stream";
const NUM_STREAM_REPETITIONS: usize = 10;
const CHANNEL_BUFFER_SIZE: usize = 100;

/// Reads the schema from the Arrow IPC stream file.
fn get_schema_from_file(file_path: &str) -> DFResult<SchemaRef> {
    let file = File::open(file_path)?;
    let reader = StreamReader::try_new(file, None)?;
    Ok(reader.schema())
}

/// Producer task: reads RecordBatches and sends them over the MPSC channel.
async fn producer_task(
    tx: mpsc::Sender<DFResult<RecordBatch>>,
    file_path: String,
) -> DFResult<()> {
    println!("Producer: Starting to read {} and send batches.", file_path);
    for i in 0..NUM_STREAM_REPETITIONS {
        let file = File::open(&file_path)?;
        let mut reader = StreamReader::try_new(file, None)?;
        println!("Producer: Sending repetition #{}", i + 1);
        while let Some(batch_result) = reader.next() {
            match batch_result {
                Ok(batch) => {
                    if tx.send(Ok(batch)).await.is_err() {
                        eprintln!("Producer: Receiver closed, stopping.");
                        return Ok(());
                    }
                }
                Err(e) => {
                    eprintln!("Producer: Error reading batch: {:?}", e);
                    let df_error = DataFusionError::ArrowError(e, None);
                    if tx.send(Err(df_error)).await.is_err() {
                        eprintln!("Producer: Receiver closed while sending error, stopping.");
                        return Ok(());
                    }
                }
            }
        }
    }
    println!("Producer: Finished sending all batches.");
    Ok(())
}

/// Builds the SQL aggregation query string based on the schema and naming rules.
fn build_sql_query(schema: &Schema, table_name: &str) -> Result<String, String> {
    let mut select_clauses = Vec::new();
    let mut groupby_clauses = Vec::new();
    for field in schema.fields() {
        let name = field.name();
        let quoted_name = format!("\"{}\"", name.replace("\"", "\"\""));
        match field.data_type() {
            // String and timestamp columns become GROUP BY keys.
            DataType::Utf8 | DataType::LargeUtf8 | DataType::Timestamp(_, _) => {
                select_clauses.push(quoted_name.clone());
                groupby_clauses.push(quoted_name);
            }
            // Numeric columns are aggregated according to their name suffix.
            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
            | DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
            | DataType::Float16 | DataType::Float32 | DataType::Float64 => {
                if name.ends_with("_sum") || name.ends_with("_count") {
                    select_clauses.push(format!("SUM({}) AS {}", quoted_name, quoted_name));
                } else if name.ends_with("_min") {
                    select_clauses.push(format!("MIN({}) AS {}", quoted_name, quoted_name));
                } else if name.ends_with("_max") {
                    select_clauses.push(format!("MAX({}) AS {}", quoted_name, quoted_name));
                } else {
                    select_clauses.push(format!("SUM({}) AS {}", quoted_name, quoted_name));
                }
            }
            _ => {
                println!(
                    "Skipping column '{}' with unhandled type: {:?}",
                    name,
                    field.data_type()
                );
            }
        }
    }
    if select_clauses.is_empty() {
        return Err("No columns suitable for selecting or grouping found.".to_string());
    }
    let select_statement = select_clauses.join(", ");
    if groupby_clauses.is_empty() {
        Ok(format!("SELECT {} FROM \"{}\"", select_statement, table_name))
    } else {
        Ok(format!(
            "SELECT {} FROM \"{}\" GROUP BY {}",
            select_statement,
            table_name,
            groupby_clauses.join(", ")
        ))
    }
}

/// PartitionStream wrapper that holds the stream within a Mutex
/// to ensure it's Sync and consumed only once.
/// It takes the stream yielding DFResult<RecordBatch> and wraps it with RecordBatchStreamAdapter.
struct StreamPartition {
    schema: SchemaRef,
    // Store the original stream yielding DFResult<RecordBatch>.
    stream_mutex: Arc<Mutex<Option<ReceiverStream<DFResult<RecordBatch>>>>>,
}

impl Debug for StreamPartition {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StreamPartition")
            .field("schema", &self.schema)
            .field(
                "stream_mutex",
                &"<Mutex containing Option<ReceiverStream<DFResult<RecordBatch>>>>",
            )
            .finish()
    }
}

#[async_trait]
impl PartitionStream for StreamPartition {
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
        let mut stream_guard = self.stream_mutex.lock().unwrap();
        match stream_guard.take() {
            Some(stream) => {
                let record_batch_stream_adapter =
                    RecordBatchStreamAdapter::new(self.schema.clone(), stream);
                Box::pin(record_batch_stream_adapter) as SendableRecordBatchStream
            }
            None => {
                println!("Warning: PartitionStream executed more than once. Returning empty stream.");
                Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))
            }
        }
    }
}

/// Aggregator task: receives batches, sets up the DataFusion streaming table,
/// executes the aggregation query, and prints results.
async fn aggregator_task(
    rx: mpsc::Receiver<DFResult<RecordBatch>>,
    input_schema: SchemaRef,
) -> DFResult<()> {
    println!("Aggregator: Task started. Waiting for batches...");
    let config = SessionConfig::new().with_batch_size(8192);
    let ctx = SessionContext::new_with_config(config);

    let stream_from_channel = ReceiverStream::new(rx);
    let partition_stream: Arc<dyn PartitionStream> = Arc::new(StreamPartition {
        schema: input_schema.clone(),
        stream_mutex: Arc::new(Mutex::new(Some(stream_from_channel))),
    });

    let table_name = "streaming_sensor_data";
    let table_provider = Arc::new(StreamingTable::try_new(
        input_schema.clone(),
        vec![partition_stream],
    )?);
    ctx.register_table(table_name, table_provider)?;
    println!("Aggregator: Streaming table '{}' registered.", table_name);

    let sql_query = match build_sql_query(&input_schema, table_name) {
        Ok(q) => q,
        Err(e) => {
            eprintln!("Aggregator: Failed to build SQL query: {}", e);
            return Err(DataFusionError::Execution(e));
        }
    };
    println!("Aggregator: Executing SQL query:\n{}", sql_query);

    let df = ctx.sql(&sql_query).await?;
    let result_batches = df.collect().await?;

    println!("\n--- Aggregated Results ---");
    // Merge the result batches into a single batch for display.
    let merged_batch = concat_batches(&input_schema, &result_batches)?;
    let total_result_rows = merged_batch.num_rows();
    println!("\n--- Merged Aggregated Results ---");
    if merged_batch.num_rows() > 0 {
        println!("Merged batch schema: {:?}", merged_batch.schema());
        datafusion::arrow::util::pretty::print_batches(&[merged_batch])?;
    } else {
        println!("No rows in the merged aggregated result.");
    }
    if total_result_rows == 0 {
        println!("Aggregator: No rows in the aggregated result.");
    } else {
        println!(
            "Aggregator: Finished processing. Total aggregated rows: {}",
            total_result_rows
        );
    }
    Ok(())
}

#[tokio::main]
async fn main() -> DFResult<()> {
    // Use the path from the first CLI argument, falling back to the default.
    let file_path_to_use: String;
    if let Some(arg_path) = env::args().nth(1) {
        file_path_to_use = arg_path;
    } else {
        file_path_to_use = DEFAULT_ARROW_STREAM_FILE_PATH.to_string();
        println!("No file path argument provided. Using default path: {}", file_path_to_use);
    }

    let input_schema = match get_schema_from_file(&file_path_to_use) {
        Ok(s) => s,
        Err(e) => {
            eprintln!("Critical: Failed to read schema from {}: {:?}", file_path_to_use, e);
            return Err(e);
        }
    };
    println!("Read input schema: {:?} from {}", input_schema.fields(), file_path_to_use);

    let (tx, rx) = mpsc::channel::<DFResult<RecordBatch>>(CHANNEL_BUFFER_SIZE);
    let producer_file_path_clone = file_path_to_use.clone();
    let producer_handle = tokio::spawn(async move {
        if let Err(e) = producer_task(tx, producer_file_path_clone).await {
            eprintln!("Producer task error: {:?}", e);
        }
    });
    let aggregator_handle = tokio::spawn(async move {
        if let Err(e) = aggregator_task(rx, input_schema).await {
            eprintln!("Aggregator task error: {:?}", e);
        }
    });

    let (producer_res, aggregator_res) = tokio::join!(producer_handle, aggregator_handle);
    if producer_res.is_err() {
        eprintln!("Producer task panicked or encountered an unhandled error.");
    }
    if aggregator_res.is_err() {
        eprintln!("Aggregator task panicked or encountered an unhandled error.");
    }
    println!("Processing complete.");
    Ok(())
}
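To make the generated query concrete, here is a small, hypothetical example of what build_sql_query produces. The field names and types are invented purely for illustration, and the snippet assumes the build_sql_query function from the sample above is in scope:

```rust
use arrow_schema::{DataType, Field, Schema, TimeUnit};

fn main() {
    // Hypothetical schema: string/timestamp columns become GROUP BY keys,
    // numeric columns are aggregated according to their name suffix.
    let schema = Schema::new(vec![
        Field::new("sensor_id", DataType::Utf8, false),
        Field::new("window_start", DataType::Timestamp(TimeUnit::Nanosecond, None), false),
        Field::new("temp_sum", DataType::Float64, true),
        Field::new("temp_min", DataType::Float64, true),
        Field::new("temp_max", DataType::Float64, true),
    ]);
    // Prints roughly:
    // SELECT "sensor_id", "window_start", SUM("temp_sum") AS "temp_sum",
    //        MIN("temp_min") AS "temp_min", MAX("temp_max") AS "temp_max"
    // FROM "streaming_sensor_data" GROUP BY "sensor_id", "window_start"
    println!("{}", build_sql_query(&schema, "streaming_sensor_data").unwrap());
}
```

With an Arrow IPC stream file on disk, the sample itself can be run as `cargo run -- path/to/input.arrow_stream`, or with no argument to use input.arrow_stream in the current directory.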