Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions crates/polars-plan/src/dsl/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::BitAnd;

use super::*;
use crate::plans::conversion::is_regex_projection;
use crate::plans::ir::tree_format::TreeFmtVisitor;
use crate::plans::tree_format::TreeFmtVisitor;
use crate::plans::visitor::{AexprNode, TreeWalker};
use crate::prelude::tree_format::TreeFmtVisitorDisplay;

Expand All @@ -12,7 +12,7 @@ pub struct MetaNameSpace(pub(crate) Expr);

impl MetaNameSpace {
/// Pop latest expression and return the input(s) of the popped expression.
pub fn pop(self) -> PolarsResult<Vec<Expr>> {
pub fn pop(self, _schema: Option<&Schema>) -> PolarsResult<Vec<Expr>> {
let mut arena = Arena::with_capacity(8);
let node = to_aexpr(self.0, &mut arena)?;
let ae = arena.get(node);
Expand All @@ -30,7 +30,7 @@ impl MetaNameSpace {
}

/// A projection that only takes a column or a column + alias.
pub fn is_simple_projection(&self) -> bool {
pub fn is_simple_projection(&self, _schema: Option<&Schema>) -> bool {
let mut arena = Arena::with_capacity(8);
to_aexpr(self.0.clone(), &mut arena)
.map(|node| aexpr_is_simple_projection(node, &arena))
Expand Down Expand Up @@ -172,16 +172,18 @@ impl MetaNameSpace {

/// Get a hold to an implementor of the `Display` trait that will format as
/// the expression as a tree
pub fn into_tree_formatter(self, display_as_dot: bool) -> PolarsResult<impl Display> {
pub fn into_tree_formatter(
self,
display_as_dot: bool,
_schema: Option<&Schema>,
) -> PolarsResult<impl Display> {
let mut arena = Default::default();
let node = to_aexpr(self.0, &mut arena)?;
let mut visitor = TreeFmtVisitor::default();
if display_as_dot {
visitor.display = TreeFmtVisitorDisplay::DisplayDot;
}

AexprNode::new(node).visit(&mut visitor, &arena)?;

Ok(visitor)
}
}
28 changes: 20 additions & 8 deletions crates/polars-python/src/expr/meta.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
use polars::prelude::Schema;
use pyo3::prelude::*;

use crate::PyExpr;
use crate::error::PyPolarsErr;
use crate::expr::ToPyExprs;
use crate::prelude::Wrap;

#[pymethods]
impl PyExpr {
fn meta_eq(&self, other: Self) -> bool {
self.inner == other.inner
}

fn meta_pop(&self) -> PyResult<Vec<Self>> {
let exprs = self.inner.clone().meta().pop().map_err(PyPolarsErr::from)?;
fn meta_pop(&self, schema: Option<Wrap<Schema>>) -> PyResult<Vec<Self>> {
let schema = schema.as_ref().map(|s| &s.0);
let exprs = self
.inner
.clone()
.meta()
.pop(schema)
.map_err(PyPolarsErr::from)?;
Ok(exprs.to_pyexprs())
}

Expand Down Expand Up @@ -106,21 +114,25 @@ impl PyExpr {
self.inner.clone().meta()._into_selector().into()
}

fn compute_tree_format(&self, display_as_dot: bool) -> Result<String, PyErr> {
fn compute_tree_format(
&self,
display_as_dot: bool,
schema: Option<Wrap<Schema>>,
) -> Result<String, PyErr> {
let e = self
.inner
.clone()
.meta()
.into_tree_formatter(display_as_dot)
.into_tree_formatter(display_as_dot, schema.as_ref().map(|s| &s.0))
.map_err(PyPolarsErr::from)?;
Ok(format!("{e}"))
}

fn meta_tree_format(&self) -> PyResult<String> {
self.compute_tree_format(false)
fn meta_tree_format(&self, schema: Option<Wrap<Schema>>) -> PyResult<String> {
self.compute_tree_format(false, schema)
}

fn meta_show_graph(&self) -> PyResult<String> {
self.compute_tree_format(true)
fn meta_show_graph(&self, schema: Option<Wrap<Schema>>) -> PyResult<String> {
self.compute_tree_format(true, schema)
}
}
2 changes: 1 addition & 1 deletion crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1313,7 +1313,7 @@ impl SQLContext {

// Note: if simple aliased expression we defer aliasing until after the group_by.
if let Expr::Alias(expr, alias) = e {
if e.clone().meta().is_simple_projection() {
if e.clone().meta().is_simple_projection(Some(&schema_before)) {
group_key_aliases.insert(alias.as_ref());
e = expr
} else if let Expr::Function {
Expand Down
23 changes: 15 additions & 8 deletions py-polars/polars/expr/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pathlib import Path

from polars import Expr
from polars._typing import SerializationFormat
from polars._typing import SchemaDict, SerializationFormat

if sys.version_info >= (3, 13):
from warnings import deprecated
Expand Down Expand Up @@ -213,7 +213,7 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None:
return None
raise

def pop(self) -> list[Expr]:
def pop(self, *, schema: SchemaDict | None = None) -> list[Expr]:
"""
Pop the latest expression and return the input(s) of the popped expression.

Expand All @@ -233,7 +233,7 @@ def pop(self) -> list[Expr]:
>>> first.meta == pl.col("bar")
False
"""
return [wrap_expr(e) for e in self._pyexpr.meta_pop()]
return [wrap_expr(e) for e in self._pyexpr.meta_pop(schema)]

def root_names(self) -> list[str]:
"""
Expand Down Expand Up @@ -375,12 +375,18 @@ def write_json(self, file: IOBase | str | Path | None = None) -> str | None:
return self.serialize(file, format="json")

@overload
def tree_format(self, *, return_as_string: Literal[False]) -> None: ...
def tree_format(
self, *, return_as_string: Literal[False], schema: None | SchemaDict = None
) -> None: ...

@overload
def tree_format(self, *, return_as_string: Literal[True]) -> str: ...
def tree_format(
self, *, return_as_string: Literal[True], schema: None | SchemaDict = None
) -> str: ...

def tree_format(self, *, return_as_string: bool = False) -> str | None:
def tree_format(
self, *, return_as_string: bool = False, schema: None | SchemaDict = None
) -> str | None:
"""
Format the expression as a tree.

Expand All @@ -394,7 +400,7 @@ def tree_format(self, *, return_as_string: bool = False) -> str | None:
>>> e = (pl.col("foo") * pl.col("bar")).sum().over(pl.col("ham")) / 2
>>> e.meta.tree_format(return_as_string=True) # doctest: +SKIP
"""
s = self._pyexpr.meta_tree_format()
s = self._pyexpr.meta_tree_format(schema)
if return_as_string:
return s
else:
Expand All @@ -408,6 +414,7 @@ def show_graph(
output_path: str | Path | None = None,
raw_output: bool = False,
figsize: tuple[float, float] = (16.0, 12.0),
schema: None | SchemaDict = None,
) -> str | None:
"""
Format the expression as a Graphviz graph.
Expand All @@ -431,7 +438,7 @@ def show_graph(
>>> e = (pl.col("foo") * pl.col("bar")).sum().over(pl.col("ham")) / 2
>>> e.meta.show_graph() # doctest: +SKIP
"""
dot = self._pyexpr.meta_show_graph()
dot = self._pyexpr.meta_show_graph(schema)
return display_dot_graph(
dot=dot,
show=show,
Expand Down
Loading