diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 384b17878..718ebf69d 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import os +import re from typing import Any import pyarrow as pa @@ -1245,13 +1246,17 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame: def test_dataframe_repr_html(df) -> None: output = df._repr_html_() - ref_html = """<table border='1'> - <tr><th>a</td><th>b</td><th>c</td></tr> - <tr><td>1</td><td>4</td><td>8</td></tr> - <tr><td>2</td><td>5</td><td>5</td></tr> - <tr><td>3</td><td>6</td><td>8</td></tr> - </table> - """ + # Since we've added a fair bit of processing to the html output, lets just verify + # the values we are expecting in the table exist. Use regex and ignore everything + # between the <th></th> and <td></td>. We also don't want the closing > on the + # td and th segments because that is where the formatting data is written. - # Ignore whitespace just to make this test look cleaner - assert output.replace(" ", "") == ref_html.replace(" ", "") + headers = ["a", "b", "c"] + headers = [f"<th(.*?)>{v}</th>" for v in headers] + header_pattern = "(.*?)".join(headers) + assert len(re.findall(header_pattern, output, re.DOTALL)) == 1 + + body_data = [[1, 4, 8], [2, 5, 5], [3, 6, 8]] + body_lines = [f"<td(.*?)>{v}</td>" for inner in body_data for v in inner] + body_pattern = "(.*?)".join(body_lines) + assert len(re.findall(body_pattern, output, re.DOTALL)) == 1 diff --git a/src/dataframe.rs b/src/dataframe.rs index 243e2e14f..be10b8c28 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -31,9 +31,11 @@ use datafusion::common::UnnestOptions; use datafusion::config::{CsvOptions, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::datasource::TableProvider; +use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; use datafusion::prelude::*; +use futures::{StreamExt, TryStreamExt}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; @@ -70,6 +72,9 @@ impl PyTableProvider { PyTable::new(table_provider) } } +const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB +const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20; +const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. @@ -111,56 +116,151 @@ impl PyDataFrame { } fn __repr__(&self, py: Python) -> PyDataFusionResult<String> { - let df = self.df.as_ref().clone().limit(0, Some(10))?; - let batches = wait_for_future(py, df.collect())?; - let batches_as_string = pretty::pretty_format_batches(&batches); - match batches_as_string { - Ok(batch) => Ok(format!("DataFrame()\n{batch}")), - Err(err) => Ok(format!("Error: {:?}", err.to_string())), + let (batches, has_more) = wait_for_future( + py, + collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10), + )?; + if batches.is_empty() { + // This should not be reached, but do it for safety since we index into the vector below + return Ok("No data to display".to_string()); } - } - fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { - let mut html_str = "<table border='1'>\n".to_string(); + let batches_as_displ = + pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?; + + let additional_str = match has_more { + true => "\nData truncated.", + false => "", + }; - let df = self.df.as_ref().clone().limit(0, Some(10))?; - let batches = wait_for_future(py, df.collect())?; + Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) + } + fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { + let (batches, has_more) = wait_for_future( + py, + collect_record_batches_to_display( + self.df.as_ref().clone(), + MIN_TABLE_ROWS_TO_DISPLAY, + usize::MAX, + ), + )?; if batches.is_empty() { - html_str.push_str("</table>\n"); - return Ok(html_str); + // This should not be reached, but do it for safety since we index into the vector below + return Ok("No data to display".to_string()); } + let table_uuid = uuid::Uuid::new_v4().to_string(); + + let mut html_str = " + <style> + .expandable-container { + display: inline-block; + max-width: 200px; + } + .expandable { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + display: block; + } + .full-text { + display: none; + white-space: normal; + } + .expand-btn { + cursor: pointer; + color: blue; + text-decoration: underline; + border: none; + background: none; + font-size: inherit; + display: block; + margin-top: 5px; + } + </style> + + <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\"> + <table style=\"border-collapse: collapse; min-width: 100%\"> + <thead>\n".to_string(); + let schema = batches[0].schema(); let mut header = Vec::new(); for field in schema.fields() { - header.push(format!("<th>{}</td>", field.name())); + header.push(format!("<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;'>{}</th>", field.name())); } let header_str = header.join(""); - html_str.push_str(&format!("<tr>{}</tr>\n", header_str)); - - for batch in batches { - let formatters = batch - .columns() - .iter() - .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) - .map(|c| { - c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))) - }) - .collect::<Result<Vec<_>, _>>()?; - - for row in 0..batch.num_rows() { + html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str)); + + let batch_formatters = batches + .iter() + .map(|batch| { + batch + .columns() + .iter() + .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) + .map(|c| { + c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))) + }) + .collect::<Result<Vec<_>, _>>() + }) + .collect::<Result<Vec<_>, _>>()?; + + let rows_per_batch = batches.iter().map(|batch| batch.num_rows()); + + // We need to build up row by row for html + let mut table_row = 0; + for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) { + for batch_row in 0..num_rows_in_batch { + table_row += 1; let mut cells = Vec::new(); - for formatter in &formatters { - cells.push(format!("<td>{}</td>", formatter.value(row))); + for (col, formatter) in batch_formatter.iter().enumerate() { + let cell_data = formatter.value(batch_row).to_string(); + // From testing, primitive data types do not typically get larger than 21 characters + if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE { + let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE]; + cells.push(format!(" + <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'> + <div class=\"expandable-container\"> + <span class=\"expandable\" id=\"{table_uuid}-min-text-{table_row}-{col}\">{short_cell_data}</span> + <span class=\"full-text\" id=\"{table_uuid}-full-text-{table_row}-{col}\">{cell_data}</span> + <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{table_row},{col})\">...</button> + </div> + </td>")); + } else { + cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(batch_row))); + } } let row_str = cells.join(""); html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); } } + html_str.push_str("</tbody></table></div>\n"); + + html_str.push_str(" + <script> + function toggleDataFrameCellText(table_uuid, row, col) { + var shortText = document.getElementById(table_uuid + \"-min-text-\" + row + \"-\" + col); + var fullText = document.getElementById(table_uuid + \"-full-text-\" + row + \"-\" + col); + var button = event.target; + + if (fullText.style.display === \"none\") { + shortText.style.display = \"none\"; + fullText.style.display = \"inline\"; + button.textContent = \"(less)\"; + } else { + shortText.style.display = \"inline\"; + fullText.style.display = \"none\"; + button.textContent = \"...\"; + } + } + </script> + "); - html_str.push_str("</table>\n"); + if has_more { + html_str.push_str("Data truncated due to size."); + } Ok(html_str) } @@ -771,3 +871,83 @@ fn record_batch_into_schema( RecordBatch::try_new(schema, data_arrays) } + +/// This is a helper function to return the first non-empty record batch from executing a DataFrame. +/// It additionally returns a bool, which indicates if there are more record batches available. +/// We do this so we can determine if we should indicate to the user that the data has been +/// truncated. This collects until we have achived both of these two conditions +/// +/// - We have collected our minimum number of rows +/// - We have reached our limit, either data size or maximum number of rows +/// +/// Otherwise it will return when the stream has exhausted. If you want a specific number of +/// rows, set min_rows == max_rows. +async fn collect_record_batches_to_display( + df: DataFrame, + min_rows: usize, + max_rows: usize, +) -> Result<(Vec<RecordBatch>, bool), DataFusionError> { + let partitioned_stream = df.execute_stream_partitioned().await?; + let mut stream = futures::stream::iter(partitioned_stream).flatten(); + let mut size_estimate_so_far = 0; + let mut rows_so_far = 0; + let mut record_batches = Vec::default(); + let mut has_more = false; + + while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows) + || rows_so_far < min_rows + { + let mut rb = match stream.next().await { + None => { + break; + } + Some(Ok(r)) => r, + Some(Err(e)) => return Err(e), + }; + + let mut rows_in_rb = rb.num_rows(); + if rows_in_rb > 0 { + size_estimate_so_far += rb.get_array_memory_size(); + + if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY { + let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32; + let total_rows = rows_in_rb + rows_so_far; + + let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; + if reduced_row_num < min_rows { + reduced_row_num = min_rows.min(total_rows); + } + + let limited_rows_this_rb = reduced_row_num - rows_so_far; + if limited_rows_this_rb < rows_in_rb { + rows_in_rb = limited_rows_this_rb; + rb = rb.slice(0, limited_rows_this_rb); + has_more = true; + } + } + + if rows_in_rb + rows_so_far > max_rows { + rb = rb.slice(0, max_rows - rows_so_far); + has_more = true; + } + + rows_so_far += rb.num_rows(); + record_batches.push(rb); + } + } + + if record_batches.is_empty() { + return Ok((Vec::default(), false)); + } + + if !has_more { + // Data was not already truncated, so check to see if more record batches remain + has_more = match stream.try_next().await { + Ok(None) => false, // reached end + Ok(Some(_)) => true, + Err(_) => false, // Stream disconnected + }; + } + + Ok((record_batches, has_more)) +} diff --git a/src/utils.rs b/src/utils.rs index 999aad755..3487de21b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -42,7 +42,7 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { #[inline] pub(crate) fn get_global_ctx() -> &'static SessionContext { static CTX: OnceLock<SessionContext> = OnceLock::new(); - CTX.get_or_init(|| SessionContext::new()) + CTX.get_or_init(SessionContext::new) } /// Utility to collect rust futures with GIL released