From 5dac633874469f8dd86682f67687ce7971729936 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 1 Mar 2025 14:45:46 +0100 Subject: [PATCH 1/8] Improve table readout of a dataframe in jupyter notebooks by making the table scrollable and displaying the first record batch up to 2MB --- src/dataframe.rs | 98 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 29 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 243e2e14f..09fd92a78 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -31,9 +31,11 @@ use datafusion::common::UnnestOptions; use datafusion::config::{CsvOptions, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::datasource::TableProvider; +use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; use datafusion::prelude::*; +use futures::{StreamExt, TryStreamExt}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; @@ -70,6 +72,7 @@ impl PyTableProvider { PyTable::new(table_provider) } } +const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. 
@@ -121,46 +124,57 @@ impl PyDataFrame { } fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { - let mut html_str = "<table border='1'>\n".to_string(); - - let df = self.df.as_ref().clone().limit(0, Some(10))?; - let batches = wait_for_future(py, df.collect())?; + let (batch, mut has_more) = + wait_for_future(py, get_first_record_batch(self.df.as_ref().clone()))?; + let Some(batch) = batch else { + return Ok("No data to display".to_string()); + }; - if batches.is_empty() { - html_str.push_str("</table>\n"); - return Ok(html_str); - } + let mut html_str = " + <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\"> + <table style=\"border-collapse: collapse; min-width: 100%\"> + <thead>\n".to_string(); - let schema = batches[0].schema(); + let schema = batch.schema(); let mut header = Vec::new(); for field in schema.fields() { - header.push(format!("<th>{}</td>", field.name())); + header.push(format!("<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;'>{}</th>", field.name())); } let header_str = header.join(""); - html_str.push_str(&format!("<tr>{}</tr>\n", header_str)); + html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str)); + + let formatters = batch + .columns() + .iter() + .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) + .map(|c| c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))) + .collect::<Result<Vec<_>, _>>()?; + + let batch_size = batch.get_array_memory_size(); + let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY { + true => { + has_more = true; + let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32; + (batch.num_rows() as f32 * ratio).round() as usize + } + false => batch.num_rows(), + }; - for batch in batches { - let formatters = batch - .columns() - .iter() - .map(|c| 
ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) - .map(|c| { - c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))) - }) - .collect::<Result<Vec<_>, _>>()?; - - for row in 0..batch.num_rows() { - let mut cells = Vec::new(); - for formatter in &formatters { - cells.push(format!("<td>{}</td>", formatter.value(row))); - } - let row_str = cells.join(""); - html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); + for row in 0..num_rows_to_display { + let mut cells = Vec::new(); + for formatter in &formatters { + cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row))); } + let row_str = cells.join(""); + html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); } - html_str.push_str("</table>\n"); + html_str.push_str("</tbody></table></div>\n"); + + if has_more { + html_str.push_str("Data truncated due to size."); + } Ok(html_str) } @@ -771,3 +785,29 @@ fn record_batch_into_schema( RecordBatch::try_new(schema, data_arrays) } + +/// This is a helper function to return the first non-empty record batch from executing a DataFrame. +/// It additionally returns a bool, which indicates if there are more record batches available. +/// We do this so we can determine if we should indicate to the user that the data has been +/// truncated. 
+async fn get_first_record_batch( + df: DataFrame, +) -> Result<(Option<RecordBatch>, bool), DataFusionError> { + let mut stream = df.execute_stream().await?; + loop { + let rb = match stream.next().await { + None => return Ok((None, false)), + Some(Ok(r)) => r, + Some(Err(e)) => return Err(e), + }; + + if rb.num_rows() > 0 { + let has_more = match stream.try_next().await { + Ok(None) => false, // reached end + Ok(Some(_)) => true, + Err(_) => false, // Stream disconnected + }; + return Ok((Some(rb), has_more)); + } + } +} From 63074981773da2cd0232b7fb07526e4f3bc7bb9e Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 1 Mar 2025 15:48:16 +0100 Subject: [PATCH 2/8] Add option to only display a portion of a cell data and the user can click on a button to toggle showing more or less --- src/dataframe.rs | 80 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 09fd92a78..211c6a4e7 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -73,6 +73,8 @@ impl PyTableProvider { } } const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB +const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20; +const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. 
@@ -130,7 +132,37 @@ impl PyDataFrame { return Ok("No data to display".to_string()); }; + let table_uuid = uuid::Uuid::new_v4().to_string(); + let mut html_str = " + <style> + .expandable-container { + display: inline-block; + max-width: 200px; + } + .expandable { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + display: block; + } + .full-text { + display: none; + white-space: normal; + } + .expand-btn { + cursor: pointer; + color: blue; + text-decoration: underline; + border: none; + background: none; + font-size: inherit; + display: block; + margin-top: 5px; + } + </style> + + <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\"> <table style=\"border-collapse: collapse; min-width: 100%\"> <thead>\n".to_string(); @@ -154,17 +186,37 @@ impl PyDataFrame { let batch_size = batch.get_array_memory_size(); let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY { true => { - has_more = true; + let num_batch_rows = batch.num_rows(); let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32; - (batch.num_rows() as f32 * ratio).round() as usize + let mut reduced_row_num = (num_batch_rows as f32 * ratio).round() as usize; + if reduced_row_num < MIN_TABLE_ROWS_TO_DISPLAY { + reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY.min(num_batch_rows); + } + + has_more = has_more || reduced_row_num < num_batch_rows; + reduced_row_num } false => batch.num_rows(), }; for row in 0..num_rows_to_display { let mut cells = Vec::new(); - for formatter in &formatters { - cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row))); + for (col, formatter) in formatters.iter().enumerate() { + let cell_data = formatter.value(row).to_string(); + // From testing, primitive data types do not typically get larger than 21 characters + if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE { + let short_cell_data = 
&cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE]; + cells.push(format!(" + <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'> + <div class=\"expandable-container\"> + <span class=\"expandable\" id=\"{table_uuid}-min-text-{row}-{col}\">{short_cell_data}</span> + <span class=\"full-text\" id=\"{table_uuid}-full-text-{row}-{col}\">{cell_data}</span> + <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{row},{col})\">...</button> + </div> + </td>")); + } else { + cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row))); + } } let row_str = cells.join(""); html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); @@ -172,6 +224,26 @@ impl PyDataFrame { html_str.push_str("</tbody></table></div>\n"); + html_str.push_str(" + <script> + function toggleDataFrameCellText(table_uuid, row, col) { + var shortText = document.getElementById(table_uuid + \"-min-text-\" + row + \"-\" + col); + var fullText = document.getElementById(table_uuid + \"-full-text-\" + row + \"-\" + col); + var button = event.target; + + if (fullText.style.display === \"none\") { + shortText.style.display = \"none\"; + fullText.style.display = \"inline\"; + button.textContent = \"(less)\"; + } else { + shortText.style.display = \"inline\"; + fullText.style.display = \"none\"; + button.textContent = \"...\"; + } + } + </script> + "); + if has_more { html_str.push_str("Data truncated due to size."); } From f7c1861106f40eccc5fb4ff5f53c11ea30229d70 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 1 Mar 2025 16:31:23 +0100 Subject: [PATCH 3/8] We cannot expect that the first non-empty batch is sufficient for our 2MB limit, so switch over to collecting until we run out or use up the size --- src/dataframe.rs | 132 ++++++++++++++++++++++++++++++----------------- 1 file changed, 85 insertions(+), 47 deletions(-) diff --git a/src/dataframe.rs 
b/src/dataframe.rs index 211c6a4e7..eb071272a 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -126,11 +126,15 @@ impl PyDataFrame { } fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { - let (batch, mut has_more) = - wait_for_future(py, get_first_record_batch(self.df.as_ref().clone()))?; - let Some(batch) = batch else { + let (batches, mut has_more) = + wait_for_future(py, get_first_few_record_batches(self.df.as_ref().clone()))?; + let Some(batches) = batches else { return Ok("No data to display".to_string()); }; + if batches.is_empty() { + // This should not be reached, but do it for safety since we index into the vector below + return Ok("No data to display".to_string()); + } let table_uuid = uuid::Uuid::new_v4().to_string(); @@ -162,12 +166,11 @@ impl PyDataFrame { } </style> - <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\"> <table style=\"border-collapse: collapse; min-width: 100%\"> <thead>\n".to_string(); - let schema = batch.schema(); + let schema = batches[0].schema(); let mut header = Vec::new(); for field in schema.fields() { @@ -176,52 +179,75 @@ impl PyDataFrame { let header_str = header.join(""); html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str)); - let formatters = batch - .columns() + let batch_formatters = batches .iter() - .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) - .map(|c| c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))) + .map(|batch| { + batch + .columns() + .iter() + .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) + .map(|c| { + c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))) + }) + .collect::<Result<Vec<_>, _>>() + }) .collect::<Result<Vec<_>, _>>()?; - let batch_size = batch.get_array_memory_size(); - let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY { + let total_memory: usize = batches + .iter() + 
.map(|batch| batch.get_array_memory_size()) + .sum(); + let rows_per_batch = batches.iter().map(|batch| batch.num_rows()); + let total_rows = rows_per_batch.clone().sum(); + + // let (total_memory, total_rows) = batches.iter().fold((0, 0), |acc, batch| { + // (acc.0 + batch.get_array_memory_size(), acc.1 + batch.num_rows()) + // }); + + let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY { true => { - let num_batch_rows = batch.num_rows(); - let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32; - let mut reduced_row_num = (num_batch_rows as f32 * ratio).round() as usize; + let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32; + let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; if reduced_row_num < MIN_TABLE_ROWS_TO_DISPLAY { - reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY.min(num_batch_rows); + reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY.min(total_rows); } - has_more = has_more || reduced_row_num < num_batch_rows; + has_more = has_more || reduced_row_num < total_rows; reduced_row_num } - false => batch.num_rows(), + false => total_rows, }; - for row in 0..num_rows_to_display { - let mut cells = Vec::new(); - for (col, formatter) in formatters.iter().enumerate() { - let cell_data = formatter.value(row).to_string(); - // From testing, primitive data types do not typically get larger than 21 characters - if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE { - let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE]; - cells.push(format!(" - <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'> - <div class=\"expandable-container\"> - <span class=\"expandable\" id=\"{table_uuid}-min-text-{row}-{col}\">{short_cell_data}</span> - <span class=\"full-text\" id=\"{table_uuid}-full-text-{row}-{col}\">{cell_data}</span> - <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{row},{col})\">...</button> - </div> - </td>")); - } else 
{ - cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row))); + // We need to build up row by row for html + let mut table_row = 0; + for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) { + for batch_row in 0..num_rows_in_batch { + table_row += 1; + if table_row > num_rows_to_display { + break; + } + let mut cells = Vec::new(); + for (col, formatter) in batch_formatter.iter().enumerate() { + let cell_data = formatter.value(batch_row).to_string(); + // From testing, primitive data types do not typically get larger than 21 characters + if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE { + let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE]; + cells.push(format!(" + <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'> + <div class=\"expandable-container\"> + <span class=\"expandable\" id=\"{table_uuid}-min-text-{table_row}-{col}\">{short_cell_data}</span> + <span class=\"full-text\" id=\"{table_uuid}-full-text-{table_row}-{col}\">{cell_data}</span> + <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{table_row},{col})\">...</button> + </div> + </td>")); + } else { + cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(batch_row))); + } } + let row_str = cells.join(""); + html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); } - let row_str = cells.join(""); - html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); } - html_str.push_str("</tbody></table></div>\n"); html_str.push_str(" @@ -862,24 +888,36 @@ fn record_batch_into_schema( /// It additionally returns a bool, which indicates if there are more record batches available. /// We do this so we can determine if we should indicate to the user that the data has been /// truncated. 
-async fn get_first_record_batch( +async fn get_first_few_record_batches( df: DataFrame, -) -> Result<(Option<RecordBatch>, bool), DataFusionError> { +) -> Result<(Option<Vec<RecordBatch>>, bool), DataFusionError> { let mut stream = df.execute_stream().await?; - loop { + let mut size_estimate_so_far = 0; + let mut record_batches = Vec::default(); + while size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY { let rb = match stream.next().await { - None => return Ok((None, false)), + None => { + break; + } Some(Ok(r)) => r, Some(Err(e)) => return Err(e), }; if rb.num_rows() > 0 { - let has_more = match stream.try_next().await { - Ok(None) => false, // reached end - Ok(Some(_)) => true, - Err(_) => false, // Stream disconnected - }; - return Ok((Some(rb), has_more)); + size_estimate_so_far += rb.get_array_memory_size(); + record_batches.push(rb); } } + + if record_batches.is_empty() { + return Ok((None, false)); + } + + let has_more = match stream.try_next().await { + Ok(None) => false, // reached end + Ok(Some(_)) => true, + Err(_) => false, // Stream disconnected + }; + + Ok((Some(record_batches), has_more)) } From ee3864b8864e195e533425aa5d5c6452fbb1f592 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 1 Mar 2025 20:36:04 +0100 Subject: [PATCH 4/8] Update python unit test to allow the additional formatting data to exist and only check the table contents --- python/tests/test_dataframe.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 384b17878..718ebf69d 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import os +import re from typing import Any import pyarrow as pa @@ -1245,13 +1246,17 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame: def test_dataframe_repr_html(df) -> None: output = df._repr_html_() - ref_html = """<table border='1'> - <tr><th>a</td><th>b</td><th>c</td></tr> - <tr><td>1</td><td>4</td><td>8</td></tr> - <tr><td>2</td><td>5</td><td>5</td></tr> - <tr><td>3</td><td>6</td><td>8</td></tr> - </table> - """ + # Since we've added a fair bit of processing to the html output, lets just verify + # the values we are expecting in the table exist. Use regex and ignore everything + # between the <th></th> and <td></td>. We also don't want the closing > on the + # td and th segments because that is where the formatting data is written. - # Ignore whitespace just to make this test look cleaner - assert output.replace(" ", "") == ref_html.replace(" ", "") + headers = ["a", "b", "c"] + headers = [f"<th(.*?)>{v}</th>" for v in headers] + header_pattern = "(.*?)".join(headers) + assert len(re.findall(header_pattern, output, re.DOTALL)) == 1 + + body_data = [[1, 4, 8], [2, 5, 5], [3, 6, 8]] + body_lines = [f"<td(.*?)>{v}</td>" for inner in body_data for v in inner] + body_pattern = "(.*?)".join(body_lines) + assert len(re.findall(body_pattern, output, re.DOTALL)) == 1 From 933d48ce72b9d9d458f86ce9fa4791bc0740c469 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Wed, 12 Mar 2025 19:05:39 -0400 Subject: [PATCH 5/8] Combining collection for repr and repr_html into one function --- src/dataframe.rs | 107 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 80 insertions(+), 27 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index eb071272a..4cf61d23a 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -116,21 +116,37 @@ impl PyDataFrame { } fn __repr__(&self, py: Python) -> PyDataFusionResult<String> { + let (batches, has_more) = wait_for_future( + py, + 
collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10), + )?; + if batches.is_empty() { + // This should not be reached, but do it for safety since we index into the vector below + return Ok("No data to display".to_string()); + } + let df = self.df.as_ref().clone().limit(0, Some(10))?; let batches = wait_for_future(py, df.collect())?; - let batches_as_string = pretty::pretty_format_batches(&batches); - match batches_as_string { - Ok(batch) => Ok(format!("DataFrame()\n{batch}")), - Err(err) => Ok(format!("Error: {:?}", err.to_string())), - } + let batches_as_displ = + pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?; + + let additional_str = match has_more { + true => "\nData truncated.", + false => "", + }; + + Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) } fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { - let (batches, mut has_more) = - wait_for_future(py, get_first_few_record_batches(self.df.as_ref().clone()))?; - let Some(batches) = batches else { - return Ok("No data to display".to_string()); - }; + let (batches, mut has_more) = wait_for_future( + py, + collect_record_batches_to_display( + self.df.as_ref().clone(), + MIN_TABLE_ROWS_TO_DISPLAY, + usize::MAX, + ), + )?; if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below return Ok("No data to display".to_string()); @@ -200,10 +216,6 @@ impl PyDataFrame { let rows_per_batch = batches.iter().map(|batch| batch.num_rows()); let total_rows = rows_per_batch.clone().sum(); - // let (total_memory, total_rows) = batches.iter().fold((0, 0), |acc, batch| { - // (acc.0 + batch.get_array_memory_size(), acc.1 + batch.num_rows()) - // }); - let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY { true => { let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32; @@ -887,15 +899,28 @@ fn record_batch_into_schema( /// This is a helper function to return the first non-empty 
record batch from executing a DataFrame. /// It additionally returns a bool, which indicates if there are more record batches available. /// We do this so we can determine if we should indicate to the user that the data has been -/// truncated. -async fn get_first_few_record_batches( +/// truncated. This collects until we have achieved both of these two conditions +/// +/// - We have collected our minimum number of rows +/// - We have reached our limit, either data size or maximum number of rows +/// +/// Otherwise it will return when the stream has exhausted. If you want a specific number of +/// rows, set min_rows == max_rows. +async fn collect_record_batches_to_display( df: DataFrame, -) -> Result<(Option<Vec<RecordBatch>>, bool), DataFusionError> { + min_rows: usize, + max_rows: usize, +) -> Result<(Vec<RecordBatch>, bool), DataFusionError> { let mut stream = df.execute_stream().await?; let mut size_estimate_so_far = 0; + let mut rows_so_far = 0; let mut record_batches = Vec::default(); - while size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY { - let rb = match stream.next().await { + let mut has_more = false; + + while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows) + || rows_so_far < min_rows + { + let mut rb = match stream.next().await { None => { break; } @@ -903,21 +928,49 @@ async fn get_first_few_record_batches( Some(Err(e)) => return Err(e), }; - if rb.num_rows() > 0 { + let mut rows_in_rb = rb.num_rows(); + if rows_in_rb > 0 { size_estimate_so_far += rb.get_array_memory_size(); + + if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY { + let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32; + let total_rows = rows_in_rb + rows_so_far; + + let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; + if reduced_row_num < min_rows { + reduced_row_num = min_rows.min(total_rows); + } + + let limited_rows_this_rb = reduced_row_num - rows_so_far; + if limited_rows_this_rb < rows_in_rb { 
rows_in_rb = limited_rows_this_rb; + rb = rb.slice(0, limited_rows_this_rb); + has_more = true; + } + } + + if rows_in_rb + rows_so_far > max_rows { + rb = rb.slice(0, max_rows - rows_so_far); + has_more = true; + } + + rows_so_far += rb.num_rows(); record_batches.push(rb); } } if record_batches.is_empty() { - return Ok((None, false)); + return Ok((Vec::default(), false)); } - let has_more = match stream.try_next().await { - Ok(None) => false, // reached end - Ok(Some(_)) => true, - Err(_) => false, // Stream disconnected - }; + if !has_more { + // Data was not already truncated, so check to see if more record batches remain + has_more = match stream.try_next().await { + Ok(None) => false, // reached end + Ok(Some(_)) => true, + Err(_) => false, // Stream disconnected + }; + } - Ok((Some(record_batches), has_more)) + Ok((record_batches, has_more)) } From 3ba4d9d395f92b320e7e31b283721e37c6360ee9 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Wed, 12 Mar 2025 19:06:24 -0400 Subject: [PATCH 6/8] Small clippy suggestion --- src/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.rs b/src/utils.rs index 999aad755..3487de21b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -42,7 +42,7 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { #[inline] pub(crate) fn get_global_ctx() -> &'static SessionContext { static CTX: OnceLock<SessionContext> = OnceLock::new(); - CTX.get_or_init(|| SessionContext::new()) + CTX.get_or_init(SessionContext::new) } /// Utility to collect rust futures with GIL released From cacba9b471d63ff63dc0292ae9fd5389a331bc54 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Thu, 13 Mar 2025 07:36:16 -0400 Subject: [PATCH 7/8] Collect was occurring twice on repr --- src/dataframe.rs | 26 +------------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/src/dataframe.rs index 4cf61d23a..49fd08642 100644 --- a/src/dataframe.rs +++ 
b/src/dataframe.rs @@ -125,8 +125,6 @@ impl PyDataFrame { return Ok("No data to display".to_string()); } - let df = self.df.as_ref().clone().limit(0, Some(10))?; - let batches = wait_for_future(py, df.collect())?; let batches_as_displ = pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?; @@ -139,7 +137,7 @@ impl PyDataFrame { } fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { - let (batches, mut has_more) = wait_for_future( + let (batches, has_more) = wait_for_future( py, collect_record_batches_to_display( self.df.as_ref().clone(), @@ -209,35 +207,13 @@ impl PyDataFrame { }) .collect::<Result<Vec<_>, _>>()?; - let total_memory: usize = batches - .iter() - .map(|batch| batch.get_array_memory_size()) - .sum(); let rows_per_batch = batches.iter().map(|batch| batch.num_rows()); - let total_rows = rows_per_batch.clone().sum(); - - let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY { - true => { - let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32; - let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; - if reduced_row_num < MIN_TABLE_ROWS_TO_DISPLAY { - reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY.min(total_rows); - } - - has_more = has_more || reduced_row_num < total_rows; - reduced_row_num - } - false => total_rows, - }; // We need to build up row by row for html let mut table_row = 0; for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) { for batch_row in 0..num_rows_in_batch { table_row += 1; - if table_row > num_rows_to_display { - break; - } let mut cells = Vec::new(); for (col, formatter) in batch_formatter.iter().enumerate() { let cell_data = formatter.value(batch_row).to_string(); From 2882050e95c6baa5b0e4250786e6fd537685d1e8 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 22 Mar 2025 08:57:24 -0400 Subject: [PATCH 8/8] Switch to execute_stream_partitioned --- src/dataframe.rs | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 49fd08642..be10b8c28 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -887,7 +887,8 @@ async fn collect_record_batches_to_display( min_rows: usize, max_rows: usize, ) -> Result<(Vec<RecordBatch>, bool), DataFusionError> { - let mut stream = df.execute_stream().await?; + let partitioned_stream = df.execute_stream_partitioned().await?; + let mut stream = futures::stream::iter(partitioned_stream).flatten(); let mut size_estimate_so_far = 0; let mut rows_so_far = 0; let mut record_batches = Vec::default();