@@ -126,11 +126,15 @@ impl PyDataFrame {
126
126
}
127
127
128
128
fn _repr_html_ ( & self , py : Python ) -> PyDataFusionResult < String > {
129
- let ( batch , mut has_more) =
130
- wait_for_future ( py, get_first_record_batch ( self . df . as_ref ( ) . clone ( ) ) ) ?;
131
- let Some ( batch ) = batch else {
129
+ let ( batches , mut has_more) =
130
+ wait_for_future ( py, get_first_few_record_batches ( self . df . as_ref ( ) . clone ( ) ) ) ?;
131
+ let Some ( batches ) = batches else {
132
132
return Ok ( "No data to display" . to_string ( ) ) ;
133
133
} ;
134
+ if batches. is_empty ( ) {
135
+ // This should not be reached, but do it for safety since we index into the vector below
136
+ return Ok ( "No data to display" . to_string ( ) ) ;
137
+ }
134
138
135
139
let table_uuid = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
136
140
@@ -162,12 +166,11 @@ impl PyDataFrame {
162
166
}
163
167
</style>
164
168
165
-
166
169
<div style=\" width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\" >
167
170
<table style=\" border-collapse: collapse; min-width: 100%\" >
168
171
<thead>\n " . to_string ( ) ;
169
172
170
- let schema = batch . schema ( ) ;
173
+ let schema = batches [ 0 ] . schema ( ) ;
171
174
172
175
let mut header = Vec :: new ( ) ;
173
176
for field in schema. fields ( ) {
@@ -176,52 +179,75 @@ impl PyDataFrame {
176
179
let header_str = header. join ( "" ) ;
177
180
html_str. push_str ( & format ! ( "<tr>{}</tr></thead><tbody>\n " , header_str) ) ;
178
181
179
- let formatters = batch
180
- . columns ( )
182
+ let batch_formatters = batches
181
183
. iter ( )
182
- . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
183
- . map ( |c| c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) ) )
184
+ . map ( |batch| {
185
+ batch
186
+ . columns ( )
187
+ . iter ( )
188
+ . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
189
+ . map ( |c| {
190
+ c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) )
191
+ } )
192
+ . collect :: < Result < Vec < _ > , _ > > ( )
193
+ } )
184
194
. collect :: < Result < Vec < _ > , _ > > ( ) ?;
185
195
186
- let batch_size = batch. get_array_memory_size ( ) ;
187
- let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY {
196
+ let total_memory: usize = batches
197
+ . iter ( )
198
+ . map ( |batch| batch. get_array_memory_size ( ) )
199
+ . sum ( ) ;
200
+ let rows_per_batch = batches. iter ( ) . map ( |batch| batch. num_rows ( ) ) ;
201
+ let total_rows = rows_per_batch. clone ( ) . sum ( ) ;
202
+
203
+ // let (total_memory, total_rows) = batches.iter().fold((0, 0), |acc, batch| {
204
+ // (acc.0 + batch.get_array_memory_size(), acc.1 + batch.num_rows())
205
+ // });
206
+
207
+ let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY {
188
208
true => {
189
- let num_batch_rows = batch. num_rows ( ) ;
190
- let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32 ;
191
- let mut reduced_row_num = ( num_batch_rows as f32 * ratio) . round ( ) as usize ;
209
+ let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32 ;
210
+ let mut reduced_row_num = ( total_rows as f32 * ratio) . round ( ) as usize ;
192
211
if reduced_row_num < MIN_TABLE_ROWS_TO_DISPLAY {
193
- reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY . min ( num_batch_rows ) ;
212
+ reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY . min ( total_rows ) ;
194
213
}
195
214
196
- has_more = has_more || reduced_row_num < num_batch_rows ;
215
+ has_more = has_more || reduced_row_num < total_rows ;
197
216
reduced_row_num
198
217
}
199
- false => batch . num_rows ( ) ,
218
+ false => total_rows ,
200
219
} ;
201
220
202
- for row in 0 ..num_rows_to_display {
203
- let mut cells = Vec :: new ( ) ;
204
- for ( col, formatter) in formatters. iter ( ) . enumerate ( ) {
205
- let cell_data = formatter. value ( row) . to_string ( ) ;
206
- // From testing, primitive data types do not typically get larger than 21 characters
207
- if cell_data. len ( ) > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
208
- let short_cell_data = & cell_data[ 0 ..MAX_LENGTH_CELL_WITHOUT_MINIMIZE ] ;
209
- cells. push ( format ! ( "
210
- <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
211
- <div class=\" expandable-container\" >
212
- <span class=\" expandable\" id=\" {table_uuid}-min-text-{row}-{col}\" >{short_cell_data}</span>
213
- <span class=\" full-text\" id=\" {table_uuid}-full-text-{row}-{col}\" >{cell_data}</span>
214
- <button class=\" expand-btn\" onclick=\" toggleDataFrameCellText('{table_uuid}',{row},{col})\" >...</button>
215
- </div>
216
- </td>" ) ) ;
217
- } else {
218
- cells. push ( format ! ( "<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>" , formatter. value( row) ) ) ;
221
+ // We need to build up row by row for html
222
+ let mut table_row = 0 ;
223
+ for ( batch_formatter, num_rows_in_batch) in batch_formatters. iter ( ) . zip ( rows_per_batch) {
224
+ for batch_row in 0 ..num_rows_in_batch {
225
+ table_row += 1 ;
226
+ if table_row > num_rows_to_display {
227
+ break ;
228
+ }
229
+ let mut cells = Vec :: new ( ) ;
230
+ for ( col, formatter) in batch_formatter. iter ( ) . enumerate ( ) {
231
+ let cell_data = formatter. value ( batch_row) . to_string ( ) ;
232
+ // From testing, primitive data types do not typically get larger than 21 characters
233
+ if cell_data. len ( ) > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
234
+ let short_cell_data = & cell_data[ 0 ..MAX_LENGTH_CELL_WITHOUT_MINIMIZE ] ;
235
+ cells. push ( format ! ( "
236
+ <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
237
+ <div class=\" expandable-container\" >
238
+ <span class=\" expandable\" id=\" {table_uuid}-min-text-{table_row}-{col}\" >{short_cell_data}</span>
239
+ <span class=\" full-text\" id=\" {table_uuid}-full-text-{table_row}-{col}\" >{cell_data}</span>
240
+ <button class=\" expand-btn\" onclick=\" toggleDataFrameCellText('{table_uuid}',{table_row},{col})\" >...</button>
241
+ </div>
242
+ </td>" ) ) ;
243
+ } else {
244
+ cells. push ( format ! ( "<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>" , formatter. value( batch_row) ) ) ;
245
+ }
219
246
}
247
+ let row_str = cells. join ( "" ) ;
248
+ html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
220
249
}
221
- let row_str = cells. join ( "" ) ;
222
- html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
223
250
}
224
-
225
251
html_str. push_str ( "</tbody></table></div>\n " ) ;
226
252
227
253
html_str. push_str ( "
@@ -862,24 +888,36 @@ fn record_batch_into_schema(
862
888
/// It additionally returns a bool, which indicates if there are more record batches available.
863
889
/// We do this so we can determine if we should indicate to the user that the data has been
864
890
/// truncated.
865
- async fn get_first_record_batch (
891
+ async fn get_first_few_record_batches (
866
892
df : DataFrame ,
867
- ) -> Result < ( Option < RecordBatch > , bool ) , DataFusionError > {
893
+ ) -> Result < ( Option < Vec < RecordBatch > > , bool ) , DataFusionError > {
868
894
let mut stream = df. execute_stream ( ) . await ?;
869
- loop {
895
+ let mut size_estimate_so_far = 0 ;
896
+ let mut record_batches = Vec :: default ( ) ;
897
+ while size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY {
870
898
let rb = match stream. next ( ) . await {
871
- None => return Ok ( ( None , false ) ) ,
899
+ None => {
900
+ break ;
901
+ }
872
902
Some ( Ok ( r) ) => r,
873
903
Some ( Err ( e) ) => return Err ( e) ,
874
904
} ;
875
905
876
906
if rb. num_rows ( ) > 0 {
877
- let has_more = match stream. try_next ( ) . await {
878
- Ok ( None ) => false , // reached end
879
- Ok ( Some ( _) ) => true ,
880
- Err ( _) => false , // Stream disconnected
881
- } ;
882
- return Ok ( ( Some ( rb) , has_more) ) ;
907
+ size_estimate_so_far += rb. get_array_memory_size ( ) ;
908
+ record_batches. push ( rb) ;
883
909
}
884
910
}
911
+
912
+ if record_batches. is_empty ( ) {
913
+ return Ok ( ( None , false ) ) ;
914
+ }
915
+
916
+ let has_more = match stream. try_next ( ) . await {
917
+ Ok ( None ) => false , // reached end
918
+ Ok ( Some ( _) ) => true ,
919
+ Err ( _) => false , // Stream disconnected
920
+ } ;
921
+
922
+ Ok ( ( Some ( record_batches) , has_more) )
885
923
}
0 commit comments