@@ -31,9 +31,11 @@ use datafusion::common::UnnestOptions;
31
31
use datafusion:: config:: { CsvOptions , TableParquetOptions } ;
32
32
use datafusion:: dataframe:: { DataFrame , DataFrameWriteOptions } ;
33
33
use datafusion:: datasource:: TableProvider ;
34
+ use datafusion:: error:: DataFusionError ;
34
35
use datafusion:: execution:: SendableRecordBatchStream ;
35
36
use datafusion:: parquet:: basic:: { BrotliLevel , Compression , GzipLevel , ZstdLevel } ;
36
37
use datafusion:: prelude:: * ;
38
+ use futures:: { StreamExt , TryStreamExt } ;
37
39
use pyo3:: exceptions:: PyValueError ;
38
40
use pyo3:: prelude:: * ;
39
41
use pyo3:: pybacked:: PyBackedStr ;
@@ -70,6 +72,7 @@ impl PyTableProvider {
70
72
PyTable :: new ( table_provider)
71
73
}
72
74
}
75
// Upper bound on how much record-batch memory `_repr_html_` will render.
// Batches larger than this are proportionally sampled down and the output
// is flagged as truncated.
const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB
73
76
74
77
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
75
78
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -121,46 +124,57 @@ impl PyDataFrame {
121
124
}
122
125
123
126
fn _repr_html_ ( & self , py : Python ) -> PyDataFusionResult < String > {
124
- let mut html_str = "<table border='1'>\n " . to_string ( ) ;
125
-
126
- let df = self . df . as_ref ( ) . clone ( ) . limit ( 0 , Some ( 10 ) ) ?;
127
- let batches = wait_for_future ( py, df. collect ( ) ) ?;
127
+ let ( batch, mut has_more) =
128
+ wait_for_future ( py, get_first_record_batch ( self . df . as_ref ( ) . clone ( ) ) ) ?;
129
+ let Some ( batch) = batch else {
130
+ return Ok ( "No data to display" . to_string ( ) ) ;
131
+ } ;
128
132
129
- if batches . is_empty ( ) {
130
- html_str . push_str ( "</table> \n " ) ;
131
- return Ok ( html_str ) ;
132
- }
133
+ let mut html_str = "
134
+ <div style= \" width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc; \" >
135
+ <table style= \" border-collapse: collapse; min-width: 100% \" >
136
+ <thead> \n " . to_string ( ) ;
133
137
134
- let schema = batches [ 0 ] . schema ( ) ;
138
+ let schema = batch . schema ( ) ;
135
139
136
140
let mut header = Vec :: new ( ) ;
137
141
for field in schema. fields ( ) {
138
- header. push ( format ! ( "<th>{}</td >" , field. name( ) ) ) ;
142
+ header. push ( format ! ( "<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;' >{}</th >" , field. name( ) ) ) ;
139
143
}
140
144
let header_str = header. join ( "" ) ;
141
- html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , header_str) ) ;
145
+ html_str. push_str ( & format ! ( "<tr>{}</tr></thead><tbody>\n " , header_str) ) ;
146
+
147
+ let formatters = batch
148
+ . columns ( )
149
+ . iter ( )
150
+ . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
151
+ . map ( |c| c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) ) )
152
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
153
+
154
+ let batch_size = batch. get_array_memory_size ( ) ;
155
+ let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY {
156
+ true => {
157
+ has_more = true ;
158
+ let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32 ;
159
+ ( batch. num_rows ( ) as f32 * ratio) . round ( ) as usize
160
+ }
161
+ false => batch. num_rows ( ) ,
162
+ } ;
142
163
143
- for batch in batches {
144
- let formatters = batch
145
- . columns ( )
146
- . iter ( )
147
- . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
148
- . map ( |c| {
149
- c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) )
150
- } )
151
- . collect :: < Result < Vec < _ > , _ > > ( ) ?;
152
-
153
- for row in 0 ..batch. num_rows ( ) {
154
- let mut cells = Vec :: new ( ) ;
155
- for formatter in & formatters {
156
- cells. push ( format ! ( "<td>{}</td>" , formatter. value( row) ) ) ;
157
- }
158
- let row_str = cells. join ( "" ) ;
159
- html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
164
+ for row in 0 ..num_rows_to_display {
165
+ let mut cells = Vec :: new ( ) ;
166
+ for formatter in & formatters {
167
+ cells. push ( format ! ( "<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>" , formatter. value( row) ) ) ;
160
168
}
169
+ let row_str = cells. join ( "" ) ;
170
+ html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
161
171
}
162
172
163
- html_str. push_str ( "</table>\n " ) ;
173
+ html_str. push_str ( "</tbody></table></div>\n " ) ;
174
+
175
+ if has_more {
176
+ html_str. push_str ( "Data truncated due to size." ) ;
177
+ }
164
178
165
179
Ok ( html_str)
166
180
}
@@ -771,3 +785,29 @@ fn record_batch_into_schema(
771
785
772
786
RecordBatch :: try_new ( schema, data_arrays)
773
787
}
788
+
789
+ /// This is a helper function to return the first non-empty record batch from executing a DataFrame.
790
+ /// It additionally returns a bool, which indicates if there are more record batches available.
791
+ /// We do this so we can determine if we should indicate to the user that the data has been
792
+ /// truncated.
793
+ async fn get_first_record_batch (
794
+ df : DataFrame ,
795
+ ) -> Result < ( Option < RecordBatch > , bool ) , DataFusionError > {
796
+ let mut stream = df. execute_stream ( ) . await ?;
797
+ loop {
798
+ let rb = match stream. next ( ) . await {
799
+ None => return Ok ( ( None , false ) ) ,
800
+ Some ( Ok ( r) ) => r,
801
+ Some ( Err ( e) ) => return Err ( e) ,
802
+ } ;
803
+
804
+ if rb. num_rows ( ) > 0 {
805
+ let has_more = match stream. try_next ( ) . await {
806
+ Ok ( None ) => false , // reached end
807
+ Ok ( Some ( _) ) => true ,
808
+ Err ( _) => false , // Stream disconnected
809
+ } ;
810
+ return Ok ( ( Some ( rb) , has_more) ) ;
811
+ }
812
+ }
813
+ }
0 commit comments