Skip to content

feat: collect once during display() in jupyter notebooks #1167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::{
get_tokio_runtime, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
};
use crate::{
errors::PyDataFusionResult,
Expand Down Expand Up @@ -289,21 +289,33 @@ impl PyParquetColumnOptions {
// Python-facing wrapper around a DataFusion `DataFrame`.
//
// Besides the shared frame itself, the wrapper carries an optional cache of
// already-collected record batches so that rendering the frame for display
// does not have to execute the plan again.
#[derive(Clone)]
pub struct PyDataFrame {
// Shared handle to the wrapped frame; `new` wraps the frame it is given in an `Arc`.
df: Arc<DataFrame>,

// In an IPython environment, cache the collected batches between the
// `__repr__` and `_repr_html_` calls. The tuple is `(batches, has_more)` as
// produced by the display collection step. It starts out as `None`, and the
// display code removes the value with `take()` when it consumes it.
batches: Option<(Vec<RecordBatch>, bool)>,
}

impl PyDataFrame {
/// creates a new PyDataFrame
pub fn new(df: DataFrame) -> Self {
Self { df: Arc::new(df) }
Self {
df: Arc::new(df),
batches: None,
}
}

fn prepare_repr_string(&self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
// Get the Python formatter and config
let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
let (batches, has_more) = wait_for_future(
py,
collect_record_batches_to_display(self.df.as_ref().clone(), config),
)??;

let should_cache = *is_ipython_env(py) && self.batches.is_none();
let (batches, has_more) = match self.batches.take() {
Some(b) => b,
None => wait_for_future(
py,
collect_record_batches_to_display(self.df.as_ref().clone(), config),
)??,
};

if batches.is_empty() {
// This should not be reached, but do it for safety since we index into the vector below
return Ok("No data to display".to_string());
Expand All @@ -313,7 +325,7 @@ impl PyDataFrame {

// Convert record batches to PyObject list
let py_batches = batches
.into_iter()
.iter()
.map(|rb| rb.to_pyarrow(py))
.collect::<PyResult<Vec<PyObject>>>()?;

Expand All @@ -334,6 +346,10 @@ impl PyDataFrame {
let html_result = formatter.call_method(method_name, (), Some(&kwargs))?;
let html_str: String = html_result.extract()?;

if should_cache {
self.batches = Some((batches, has_more));
}

Ok(html_str)
}
}
Expand Down Expand Up @@ -361,7 +377,7 @@ impl PyDataFrame {
}
}

fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
fn __repr__(&mut self, py: Python) -> PyDataFusionResult<String> {
self.prepare_repr_string(py, false)
}

Expand Down Expand Up @@ -396,7 +412,7 @@ impl PyDataFrame {
Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
}

fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult<String> {
self.prepare_repr_string(py, true)
}

Expand Down
11 changes: 11 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
}

/// Reports whether the interpreter is running inside an IPython session.
///
/// The probe imports the `IPython` module and calls its `get_ipython()`
/// function; the environment counts as IPython only when that call succeeds
/// and returns something other than `None`. Any Python error raised along the
/// way (for example, the module not being installed) is treated as "not
/// IPython".
///
/// The answer is computed on the first call and stored in a process-wide
/// `OnceLock`, so later calls return the stored value without touching Python.
#[inline]
pub(crate) fn is_ipython_env(py: Python) -> &'static bool {
    static IS_IPYTHON_ENV: OnceLock<bool> = OnceLock::new();

    // Runs at most once: a failed import or a failed call both mean "no".
    let detect = || match py.import("IPython") {
        Ok(module) => match module.call_method0("get_ipython") {
            Ok(shell) => !shell.is_none(),
            Err(_) => false,
        },
        Err(_) => false,
    };

    IS_IPYTHON_ENV.get_or_init(detect)
}

/// Utility to get the global DataFusion context
#[inline]
pub(crate) fn get_global_ctx() -> &'static SessionContext {
Expand Down