Skip to content

Commit 893c092

Browse files
committed
Improve table readout of a dataframe in jupyter notebooks by making the table scrollable and displaying the first record batch up to 2MB
1 parent b194a87 commit 893c092

File tree

1 file changed

+69
-29
lines changed

1 file changed

+69
-29
lines changed

src/dataframe.rs

Lines changed: 69 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ use datafusion::common::UnnestOptions;
3131
use datafusion::config::{CsvOptions, TableParquetOptions};
3232
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
3333
use datafusion::datasource::TableProvider;
34+
use datafusion::error::DataFusionError;
3435
use datafusion::execution::SendableRecordBatchStream;
3536
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
3637
use datafusion::prelude::*;
38+
use futures::{StreamExt, TryStreamExt};
3739
use pyo3::exceptions::PyValueError;
3840
use pyo3::prelude::*;
3941
use pyo3::pybacked::PyBackedStr;
@@ -70,6 +72,7 @@ impl PyTableProvider {
7072
PyTable::new(table_provider)
7173
}
7274
}
75+
/// Upper bound on the in-memory size of the record batch rendered by
/// `_repr_html_`; batches larger than this are shown only partially.
const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB
7376

7477
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
7578
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -121,46 +124,57 @@ impl PyDataFrame {
121124
}
122125

123126
fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
124-
let mut html_str = "<table border='1'>\n".to_string();
125-
126-
let df = self.df.as_ref().clone().limit(0, Some(10))?;
127-
let batches = wait_for_future(py, df.collect())?;
127+
let (batch, mut has_more) =
128+
wait_for_future(py, get_first_record_batch(self.df.as_ref().clone()))?;
129+
let Some(batch) = batch else {
130+
return Ok("No data to display".to_string());
131+
};
128132

129-
if batches.is_empty() {
130-
html_str.push_str("</table>\n");
131-
return Ok(html_str);
132-
}
133+
let mut html_str = "
134+
<div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\">
135+
<table style=\"border-collapse: collapse; min-width: 100%\">
136+
<thead>\n".to_string();
133137

134-
let schema = batches[0].schema();
138+
let schema = batch.schema();
135139

136140
let mut header = Vec::new();
137141
for field in schema.fields() {
138-
header.push(format!("<th>{}</td>", field.name()));
142+
header.push(format!("<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;'>{}</th>", field.name()));
139143
}
140144
let header_str = header.join("");
141-
html_str.push_str(&format!("<tr>{}</tr>\n", header_str));
145+
html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str));
146+
147+
let formatters = batch
148+
.columns()
149+
.iter()
150+
.map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
151+
.map(|c| c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))))
152+
.collect::<Result<Vec<_>, _>>()?;
153+
154+
let batch_size = batch.get_array_memory_size();
155+
let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY {
156+
true => {
157+
has_more = true;
158+
let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32;
159+
(batch.num_rows() as f32 * ratio).round() as usize
160+
}
161+
false => batch.num_rows(),
162+
};
142163

143-
for batch in batches {
144-
let formatters = batch
145-
.columns()
146-
.iter()
147-
.map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
148-
.map(|c| {
149-
c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
150-
})
151-
.collect::<Result<Vec<_>, _>>()?;
152-
153-
for row in 0..batch.num_rows() {
154-
let mut cells = Vec::new();
155-
for formatter in &formatters {
156-
cells.push(format!("<td>{}</td>", formatter.value(row)));
157-
}
158-
let row_str = cells.join("");
159-
html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
164+
for row in 0..num_rows_to_display {
165+
let mut cells = Vec::new();
166+
for formatter in &formatters {
167+
cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row)));
160168
}
169+
let row_str = cells.join("");
170+
html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
161171
}
162172

163-
html_str.push_str("</table>\n");
173+
html_str.push_str("</tbody></table></div>\n");
174+
175+
if has_more {
176+
html_str.push_str("Data truncated due to size.");
177+
}
164178

165179
Ok(html_str)
166180
}
@@ -771,3 +785,29 @@ fn record_batch_into_schema(
771785

772786
RecordBatch::try_new(schema, data_arrays)
773787
}
788+
789+
/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
790+
/// It additionally returns a bool, which indicates if there are more record batches available.
791+
/// We do this so we can determine if we should indicate to the user that the data has been
792+
/// truncated.
793+
async fn get_first_record_batch(
794+
df: DataFrame,
795+
) -> Result<(Option<RecordBatch>, bool), DataFusionError> {
796+
let mut stream = df.execute_stream().await?;
797+
loop {
798+
let rb = match stream.next().await {
799+
None => return Ok((None, false)),
800+
Some(Ok(r)) => r,
801+
Some(Err(e)) => return Err(e),
802+
};
803+
804+
if rb.num_rows() > 0 {
805+
let has_more = match stream.try_next().await {
806+
Ok(None) => false, // reached end
807+
Ok(Some(_)) => true,
808+
Err(_) => false, // Stream disconnected
809+
};
810+
return Ok((Some(rb), has_more));
811+
}
812+
}
813+
}

0 commit comments

Comments (0)