From 5dac633874469f8dd86682f67687ce7971729936 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 1 Mar 2025 14:45:46 +0100 Subject: [PATCH 1/8] Improve table readout of a dataframe in jupyter notebooks by making the table scrollable and displaying the first record batch up to 2MB --- src/dataframe.rs | 98 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 29 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 243e2e14f..09fd92a78 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -31,9 +31,11 @@ use datafusion::common::UnnestOptions; use datafusion::config::{CsvOptions, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::datasource::TableProvider; +use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; use datafusion::prelude::*; +use futures::{StreamExt, TryStreamExt}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; @@ -70,6 +72,7 @@ impl PyTableProvider { PyTable::new(table_provider) } } +const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. 
@@ -121,46 +124,57 @@ impl PyDataFrame { } fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { - let mut html_str = "<table border='1'>\n".to_string(); - - let df = self.df.as_ref().clone().limit(0, Some(10))?; - let batches = wait_for_future(py, df.collect())?; + let (batch, mut has_more) = + wait_for_future(py, get_first_record_batch(self.df.as_ref().clone()))?; + let Some(batch) = batch else { + return Ok("No data to display".to_string()); + }; - if batches.is_empty() { - html_str.push_str("</table>\n"); - return Ok(html_str); - } + let mut html_str = " + <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\"> + <table style=\"border-collapse: collapse; min-width: 100%\"> + <thead>\n".to_string(); - let schema = batches[0].schema(); + let schema = batch.schema(); let mut header = Vec::new(); for field in schema.fields() { - header.push(format!("<th>{}</td>", field.name())); + header.push(format!("<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;'>{}</th>", field.name())); } let header_str = header.join(""); - html_str.push_str(&format!("<tr>{}</tr>\n", header_str)); + html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str)); + + let formatters = batch + .columns() + .iter() + .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) + .map(|c| c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))) + .collect::<Result<Vec<_>, _>>()?; + + let batch_size = batch.get_array_memory_size(); + let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY { + true => { + has_more = true; + let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32; + (batch.num_rows() as f32 * ratio).round() as usize + } + false => batch.num_rows(), + }; - for batch in batches { - let formatters = batch - .columns() - .iter() - .map(|c| 
ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) - .map(|c| { - c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))) - }) - .collect::<Result<Vec<_>, _>>()?; - - for row in 0..batch.num_rows() { - let mut cells = Vec::new(); - for formatter in &formatters { - cells.push(format!("<td>{}</td>", formatter.value(row))); - } - let row_str = cells.join(""); - html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); + for row in 0..num_rows_to_display { + let mut cells = Vec::new(); + for formatter in &formatters { + cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row))); } + let row_str = cells.join(""); + html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); } - html_str.push_str("</table>\n"); + html_str.push_str("</tbody></table></div>\n"); + + if has_more { + html_str.push_str("Data truncated due to size."); + } Ok(html_str) } @@ -771,3 +785,29 @@ fn record_batch_into_schema( RecordBatch::try_new(schema, data_arrays) } + +/// This is a helper function to return the first non-empty record batch from executing a DataFrame. +/// It additionally returns a bool, which indicates if there are more record batches available. +/// We do this so we can determine if we should indicate to the user that the data has been +/// truncated. 
+async fn get_first_record_batch( + df: DataFrame, +) -> Result<(Option<RecordBatch>, bool), DataFusionError> { + let mut stream = df.execute_stream().await?; + loop { + let rb = match stream.next().await { + None => return Ok((None, false)), + Some(Ok(r)) => r, + Some(Err(e)) => return Err(e), + }; + + if rb.num_rows() > 0 { + let has_more = match stream.try_next().await { + Ok(None) => false, // reached end + Ok(Some(_)) => true, + Err(_) => false, // Stream disconnected + }; + return Ok((Some(rb), has_more)); + } + } +} From 63074981773da2cd0232b7fb07526e4f3bc7bb9e Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 1 Mar 2025 15:48:16 +0100 Subject: [PATCH 2/8] Add option to only display a portion of a cell data and the user can click on a button to toggle showing more or less --- src/dataframe.rs | 80 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 09fd92a78..211c6a4e7 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -73,6 +73,8 @@ impl PyTableProvider { } } const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB +const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20; +const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. 
@@ -130,7 +132,37 @@ impl PyDataFrame { return Ok("No data to display".to_string()); }; + let table_uuid = uuid::Uuid::new_v4().to_string(); + let mut html_str = " + <style> + .expandable-container { + display: inline-block; + max-width: 200px; + } + .expandable { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + display: block; + } + .full-text { + display: none; + white-space: normal; + } + .expand-btn { + cursor: pointer; + color: blue; + text-decoration: underline; + border: none; + background: none; + font-size: inherit; + display: block; + margin-top: 5px; + } + </style> + + <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\"> <table style=\"border-collapse: collapse; min-width: 100%\"> <thead>\n".to_string(); @@ -154,17 +186,37 @@ impl PyDataFrame { let batch_size = batch.get_array_memory_size(); let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY { true => { - has_more = true; + let num_batch_rows = batch.num_rows(); let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32; - (batch.num_rows() as f32 * ratio).round() as usize + let mut reduced_row_num = (num_batch_rows as f32 * ratio).round() as usize; + if reduced_row_num < MIN_TABLE_ROWS_TO_DISPLAY { + reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY.min(num_batch_rows); + } + + has_more = has_more || reduced_row_num < num_batch_rows; + reduced_row_num } false => batch.num_rows(), }; for row in 0..num_rows_to_display { let mut cells = Vec::new(); - for formatter in &formatters { - cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row))); + for (col, formatter) in formatters.iter().enumerate() { + let cell_data = formatter.value(row).to_string(); + // From testing, primitive data types do not typically get larger than 21 characters + if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE { + let short_cell_data = 
&cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE]; + cells.push(format!(" + <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'> + <div class=\"expandable-container\"> + <span class=\"expandable\" id=\"{table_uuid}-min-text-{row}-{col}\">{short_cell_data}</span> + <span class=\"full-text\" id=\"{table_uuid}-full-text-{row}-{col}\">{cell_data}</span> + <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{row},{col})\">...</button> + </div> + </td>")); + } else { + cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row))); + } } let row_str = cells.join(""); html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); @@ -172,6 +224,26 @@ impl PyDataFrame { html_str.push_str("</tbody></table></div>\n"); + html_str.push_str(" + <script> + function toggleDataFrameCellText(table_uuid, row, col) { + var shortText = document.getElementById(table_uuid + \"-min-text-\" + row + \"-\" + col); + var fullText = document.getElementById(table_uuid + \"-full-text-\" + row + \"-\" + col); + var button = event.target; + + if (fullText.style.display === \"none\") { + shortText.style.display = \"none\"; + fullText.style.display = \"inline\"; + button.textContent = \"(less)\"; + } else { + shortText.style.display = \"inline\"; + fullText.style.display = \"none\"; + button.textContent = \"...\"; + } + } + </script> + "); + if has_more { html_str.push_str("Data truncated due to size."); } From f7c1861106f40eccc5fb4ff5f53c11ea30229d70 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 1 Mar 2025 16:31:23 +0100 Subject: [PATCH 3/8] We cannot expect that the first non-empty batch is sufficient for our 2MB limit, so switch over to collecting until we run out or use up the size --- src/dataframe.rs | 132 ++++++++++++++++++++++++++++++----------------- 1 file changed, 85 insertions(+), 47 deletions(-) diff --git a/src/dataframe.rs 
b/src/dataframe.rs index 211c6a4e7..eb071272a 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -126,11 +126,15 @@ impl PyDataFrame { } fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { - let (batch, mut has_more) = - wait_for_future(py, get_first_record_batch(self.df.as_ref().clone()))?; - let Some(batch) = batch else { + let (batches, mut has_more) = + wait_for_future(py, get_first_few_record_batches(self.df.as_ref().clone()))?; + let Some(batches) = batches else { return Ok("No data to display".to_string()); }; + if batches.is_empty() { + // This should not be reached, but do it for safety since we index into the vector below + return Ok("No data to display".to_string()); + } let table_uuid = uuid::Uuid::new_v4().to_string(); @@ -162,12 +166,11 @@ impl PyDataFrame { } </style> - <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\"> <table style=\"border-collapse: collapse; min-width: 100%\"> <thead>\n".to_string(); - let schema = batch.schema(); + let schema = batches[0].schema(); let mut header = Vec::new(); for field in schema.fields() { @@ -176,52 +179,75 @@ impl PyDataFrame { let header_str = header.join(""); html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str)); - let formatters = batch - .columns() + let batch_formatters = batches .iter() - .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) - .map(|c| c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))) + .map(|batch| { + batch + .columns() + .iter() + .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) + .map(|c| { + c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))) + }) + .collect::<Result<Vec<_>, _>>() + }) .collect::<Result<Vec<_>, _>>()?; - let batch_size = batch.get_array_memory_size(); - let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY { + let total_memory: usize = batches + .iter() + 
.map(|batch| batch.get_array_memory_size()) + .sum(); + let rows_per_batch = batches.iter().map(|batch| batch.num_rows()); + let total_rows = rows_per_batch.clone().sum(); + + // let (total_memory, total_rows) = batches.iter().fold((0, 0), |acc, batch| { + // (acc.0 + batch.get_array_memory_size(), acc.1 + batch.num_rows()) + // }); + + let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY { true => { - let num_batch_rows = batch.num_rows(); - let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32; - let mut reduced_row_num = (num_batch_rows as f32 * ratio).round() as usize; + let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32; + let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; if reduced_row_num < MIN_TABLE_ROWS_TO_DISPLAY { - reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY.min(num_batch_rows); + reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY.min(total_rows); } - has_more = has_more || reduced_row_num < num_batch_rows; + has_more = has_more || reduced_row_num < total_rows; reduced_row_num } - false => batch.num_rows(), + false => total_rows, }; - for row in 0..num_rows_to_display { - let mut cells = Vec::new(); - for (col, formatter) in formatters.iter().enumerate() { - let cell_data = formatter.value(row).to_string(); - // From testing, primitive data types do not typically get larger than 21 characters - if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE { - let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE]; - cells.push(format!(" - <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'> - <div class=\"expandable-container\"> - <span class=\"expandable\" id=\"{table_uuid}-min-text-{row}-{col}\">{short_cell_data}</span> - <span class=\"full-text\" id=\"{table_uuid}-full-text-{row}-{col}\">{cell_data}</span> - <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{row},{col})\">...</button> - </div> - </td>")); - } else 
{ - cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row))); + // We need to build up row by row for html + let mut table_row = 0; + for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) { + for batch_row in 0..num_rows_in_batch { + table_row += 1; + if table_row > num_rows_to_display { + break; + } + let mut cells = Vec::new(); + for (col, formatter) in batch_formatter.iter().enumerate() { + let cell_data = formatter.value(batch_row).to_string(); + // From testing, primitive data types do not typically get larger than 21 characters + if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE { + let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE]; + cells.push(format!(" + <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'> + <div class=\"expandable-container\"> + <span class=\"expandable\" id=\"{table_uuid}-min-text-{table_row}-{col}\">{short_cell_data}</span> + <span class=\"full-text\" id=\"{table_uuid}-full-text-{table_row}-{col}\">{cell_data}</span> + <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{table_row},{col})\">...</button> + </div> + </td>")); + } else { + cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(batch_row))); + } } + let row_str = cells.join(""); + html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); } - let row_str = cells.join(""); - html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); } - html_str.push_str("</tbody></table></div>\n"); html_str.push_str(" @@ -862,24 +888,36 @@ fn record_batch_into_schema( /// It additionally returns a bool, which indicates if there are more record batches available. /// We do this so we can determine if we should indicate to the user that the data has been /// truncated. 
-async fn get_first_record_batch( +async fn get_first_few_record_batches( df: DataFrame, -) -> Result<(Option<RecordBatch>, bool), DataFusionError> { +) -> Result<(Option<Vec<RecordBatch>>, bool), DataFusionError> { let mut stream = df.execute_stream().await?; - loop { + let mut size_estimate_so_far = 0; + let mut record_batches = Vec::default(); + while size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY { let rb = match stream.next().await { - None => return Ok((None, false)), + None => { + break; + } Some(Ok(r)) => r, Some(Err(e)) => return Err(e), }; if rb.num_rows() > 0 { - let has_more = match stream.try_next().await { - Ok(None) => false, // reached end - Ok(Some(_)) => true, - Err(_) => false, // Stream disconnected - }; - return Ok((Some(rb), has_more)); + size_estimate_so_far += rb.get_array_memory_size(); + record_batches.push(rb); } } + + if record_batches.is_empty() { + return Ok((None, false)); + } + + let has_more = match stream.try_next().await { + Ok(None) => false, // reached end + Ok(Some(_)) => true, + Err(_) => false, // Stream disconnected + }; + + Ok((Some(record_batches), has_more)) } From ee3864b8864e195e533425aa5d5c6452fbb1f592 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 1 Mar 2025 20:36:04 +0100 Subject: [PATCH 4/8] Update python unit test to allow the additional formatting data to exist and only check the table contents --- python/tests/test_dataframe.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 384b17878..718ebf69d 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import os +import re from typing import Any import pyarrow as pa @@ -1245,13 +1246,17 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame: def test_dataframe_repr_html(df) -> None: output = df._repr_html_() - ref_html = """<table border='1'> - <tr><th>a</td><th>b</td><th>c</td></tr> - <tr><td>1</td><td>4</td><td>8</td></tr> - <tr><td>2</td><td>5</td><td>5</td></tr> - <tr><td>3</td><td>6</td><td>8</td></tr> - </table> - """ + # Since we've added a fair bit of processing to the html output, lets just verify + # the values we are expecting in the table exist. Use regex and ignore everything + # between the <th></th> and <td></td>. We also don't want the closing > on the + # td and th segments because that is where the formatting data is written. - # Ignore whitespace just to make this test look cleaner - assert output.replace(" ", "") == ref_html.replace(" ", "") + headers = ["a", "b", "c"] + headers = [f"<th(.*?)>{v}</th>" for v in headers] + header_pattern = "(.*?)".join(headers) + assert len(re.findall(header_pattern, output, re.DOTALL)) == 1 + + body_data = [[1, 4, 8], [2, 5, 5], [3, 6, 8]] + body_lines = [f"<td(.*?)>{v}</td>" for inner in body_data for v in inner] + body_pattern = "(.*?)".join(body_lines) + assert len(re.findall(body_pattern, output, re.DOTALL)) == 1 From 933d48ce72b9d9d458f86ce9fa4791bc0740c469 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Wed, 12 Mar 2025 19:05:39 -0400 Subject: [PATCH 5/8] Combining collection for repr and repr_html into one function --- src/dataframe.rs | 107 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 80 insertions(+), 27 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index eb071272a..4cf61d23a 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -116,21 +116,37 @@ impl PyDataFrame { } fn __repr__(&self, py: Python) -> PyDataFusionResult<String> { + let (batches, has_more) = wait_for_future( + py, + 
collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10), + )?; + if batches.is_empty() { + // This should not be reached, but do it for safety since we index into the vector below + return Ok("No data to display".to_string()); + } + let df = self.df.as_ref().clone().limit(0, Some(10))?; let batches = wait_for_future(py, df.collect())?; - let batches_as_string = pretty::pretty_format_batches(&batches); - match batches_as_string { - Ok(batch) => Ok(format!("DataFrame()\n{batch}")), - Err(err) => Ok(format!("Error: {:?}", err.to_string())), - } + let batches_as_displ = + pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?; + + let additional_str = match has_more { + true => "\nData truncated.", + false => "", + }; + + Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) } fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { - let (batches, mut has_more) = - wait_for_future(py, get_first_few_record_batches(self.df.as_ref().clone()))?; - let Some(batches) = batches else { - return Ok("No data to display".to_string()); - }; + let (batches, mut has_more) = wait_for_future( + py, + collect_record_batches_to_display( + self.df.as_ref().clone(), + MIN_TABLE_ROWS_TO_DISPLAY, + usize::MAX, + ), + )?; if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below return Ok("No data to display".to_string()); @@ -200,10 +216,6 @@ impl PyDataFrame { let rows_per_batch = batches.iter().map(|batch| batch.num_rows()); let total_rows = rows_per_batch.clone().sum(); - // let (total_memory, total_rows) = batches.iter().fold((0, 0), |acc, batch| { - // (acc.0 + batch.get_array_memory_size(), acc.1 + batch.num_rows()) - // }); - let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY { true => { let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32; @@ -887,15 +899,28 @@ fn record_batch_into_schema( /// This is a helper function to return the first non-empty 
record batch from executing a DataFrame. /// It additionally returns a bool, which indicates if there are more record batches available. /// We do this so we can determine if we should indicate to the user that the data has been -/// truncated. -async fn get_first_few_record_batches( +/// truncated. This collects until we have achieved both of these two conditions +/// +/// - We have collected our minimum number of rows +/// - We have reached our limit, either data size or maximum number of rows +/// +/// Otherwise it will return when the stream has exhausted. If you want a specific number of +/// rows, set min_rows == max_rows. +async fn collect_record_batches_to_display( df: DataFrame, -) -> Result<(Option<Vec<RecordBatch>>, bool), DataFusionError> { + min_rows: usize, + max_rows: usize, +) -> Result<(Vec<RecordBatch>, bool), DataFusionError> { let mut stream = df.execute_stream().await?; let mut size_estimate_so_far = 0; + let mut rows_so_far = 0; let mut record_batches = Vec::default(); - while size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY { - let rb = match stream.next().await { + let mut has_more = false; + + while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows) + || rows_so_far < min_rows + { + let mut rb = match stream.next().await { None => { break; } @@ -903,21 +928,49 @@ async fn get_first_few_record_batches( Some(Err(e)) => return Err(e), }; - if rb.num_rows() > 0 { + let mut rows_in_rb = rb.num_rows(); + if rows_in_rb > 0 { size_estimate_so_far += rb.get_array_memory_size(); + + if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY { + let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32; + let total_rows = rows_in_rb + rows_so_far; + + let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; + if reduced_row_num < min_rows { + reduced_row_num = min_rows.min(total_rows); + } + + let limited_rows_this_rb = reduced_row_num - rows_so_far; + if limited_rows_this_rb < rows_in_rb { 
rows_in_rb = limited_rows_this_rb; + rb = rb.slice(0, limited_rows_this_rb); + has_more = true; + } + } + + if rows_in_rb + rows_so_far > max_rows { + rb = rb.slice(0, max_rows - rows_so_far); + has_more = true; + } + + rows_so_far += rb.num_rows(); record_batches.push(rb); } } if record_batches.is_empty() { - return Ok((None, false)); + return Ok((Vec::default(), false)); } - let has_more = match stream.try_next().await { - Ok(None) => false, // reached end - Ok(Some(_)) => true, - Err(_) => false, // Stream disconnected - }; + if !has_more { + // Data was not already truncated, so check to see if more record batches remain + has_more = match stream.try_next().await { + Ok(None) => false, // reached end + Ok(Some(_)) => true, + Err(_) => false, // Stream disconnected + }; + } - Ok((Some(record_batches), has_more)) + Ok((record_batches, has_more)) } From 3ba4d9d395f92b320e7e31b283721e37c6360ee9 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Wed, 12 Mar 2025 19:06:24 -0400 Subject: [PATCH 6/8] Small clippy suggestion --- src/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.rs b/src/utils.rs index 999aad755..3487de21b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -42,7 +42,7 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { #[inline] pub(crate) fn get_global_ctx() -> &'static SessionContext { static CTX: OnceLock<SessionContext> = OnceLock::new(); - CTX.get_or_init(|| SessionContext::new()) + CTX.get_or_init(SessionContext::new) } /// Utility to collect rust futures with GIL released From cacba9b471d63ff63dc0292ae9fd5389a331bc54 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Thu, 13 Mar 2025 07:36:16 -0400 Subject: [PATCH 7/8] Collect was occurring twice on repr --- src/dataframe.rs | 26 +------------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/src/dataframe.rs index 4cf61d23a..49fd08642 100644 --- a/src/dataframe.rs +++ 
b/src/dataframe.rs @@ -125,8 +125,6 @@ impl PyDataFrame { return Ok("No data to display".to_string()); } - let df = self.df.as_ref().clone().limit(0, Some(10))?; - let batches = wait_for_future(py, df.collect())?; let batches_as_displ = pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?; @@ -139,7 +137,7 @@ impl PyDataFrame { } fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { - let (batches, mut has_more) = wait_for_future( + let (batches, has_more) = wait_for_future( py, collect_record_batches_to_display( self.df.as_ref().clone(), @@ -209,35 +207,13 @@ impl PyDataFrame { }) .collect::<Result<Vec<_>, _>>()?; - let total_memory: usize = batches - .iter() - .map(|batch| batch.get_array_memory_size()) - .sum(); let rows_per_batch = batches.iter().map(|batch| batch.num_rows()); - let total_rows = rows_per_batch.clone().sum(); - - let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY { - true => { - let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32; - let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; - if reduced_row_num < MIN_TABLE_ROWS_TO_DISPLAY { - reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY.min(total_rows); - } - - has_more = has_more || reduced_row_num < total_rows; - reduced_row_num - } - false => total_rows, - }; // We need to build up row by row for html let mut table_row = 0; for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) { for batch_row in 0..num_rows_in_batch { table_row += 1; - if table_row > num_rows_to_display { - break; - } let mut cells = Vec::new(); for (col, formatter) in batch_formatter.iter().enumerate() { let cell_data = formatter.value(batch_row).to_string(); From 2882050e95c6baa5b0e4250786e6fd537685d1e8 Mon Sep 17 00:00:00 2001 From: Tim Saucer <timsaucer@gmail.com> Date: Sat, 22 Mar 2025 08:57:24 -0400 Subject: [PATCH 8/8] Switch to execute_stream_partitioned --- src/dataframe.rs | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 49fd08642..be10b8c28 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -887,7 +887,8 @@ async fn collect_record_batches_to_display( min_rows: usize, max_rows: usize, ) -> Result<(Vec<RecordBatch>, bool), DataFusionError> { - let mut stream = df.execute_stream().await?; + let partitioned_stream = df.execute_stream_partitioned().await?; + let mut stream = futures::stream::iter(partitioned_stream).flatten(); let mut size_estimate_so_far = 0; let mut rows_so_far = 0; let mut record_batches = Vec::default();