Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 80 additions & 42 deletions insights/insights/doctype/insights_data_source_v3/ibis_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ast
import re
import time
from contextlib import contextmanager
from datetime import date
Expand Down Expand Up @@ -500,29 +499,21 @@ def apply_sql(self, sql_args):
ds = frappe.get_doc("Insights Data Source v3", data_source)
db = ds._get_ibis_backend()

raw_sql = sqlparse.format(sql=raw_sql, strip_comments=True)
raw_sql = sqlparse.format(sql=raw_sql, strip_comments=True).strip()

# TODO: apply user permissions by default
check_permissions = frappe.db.get_single_value(
"Insights Settings", "enable_permissions"
) or frappe.db.get_single_value("Insights Settings", "apply_user_permissions")

if check_permissions:
parsed = sg.parse_one(raw_sql, dialect=db.dialect)

tables = set()
for table_exp in parsed.find_all(sg.exp.Table):
if table_exp.name:
tables.add(table_exp.name)

cte_aliases = set()
for cte_exp in parsed.find_all(sg.exp.CTE):
if cte_exp.alias:
cte_aliases.add(cte_exp.alias)
parsed = sg.parse_one(raw_sql, dialect=db.dialect)
tables = {t.name for t in parsed.find_all(sg.exp.Table) if t.name}
cte_aliases = {c.alias for c in parsed.find_all(sg.exp.CTE)}
tables -= cte_aliases | _VIRTUAL_TABLES

tables = tables - cte_aliases
ctes_to_prepend = {}

replace_map = {}
if check_permissions:
for table_name in tables:
t = InsightsTablev3.get_ibis_table(
data_source,
Expand All @@ -540,36 +531,23 @@ def apply_sql(self, sql_args):
t_parsed = sg.parse_one(t_sql, dialect=db.dialect)
if not t_parsed.find(sg.exp.Where):
continue
replace_map[table_name] = t_sql

with_clauses = []
for table_name, table_sql in replace_map.items():
quoted_table_name = sg.to_identifier(table_name)
with_clauses.append(f"{quoted_table_name} AS ({table_sql})")

if with_clauses:
with_clause_sql = ", ".join(with_clauses)
# Check if raw_sql already starts with WITH clause
raw_sql_stripped = raw_sql.strip()
if raw_sql_stripped.lower().startswith("with"):
# Insert new CTEs after the WITH keyword and before existing CTEs
# Use regex to handle both uppercase and lowercase "with"
raw_sql = re.sub(
r"(\bwith\b)",
f"WITH {with_clause_sql},",
raw_sql_stripped,
count=1,
flags=re.IGNORECASE,
)
else:
# Prepend WITH clause if it doesn't exist
raw_sql = f"WITH {with_clause_sql} {raw_sql_stripped}"
ctes_to_prepend[table_name] = t_sql

# Inject CTEs for built-in virtual tables (e.g., insights_calendar)
virtual_refs = {t.name for t in parsed.find_all(sg.exp.Table) if t.name in _VIRTUAL_TABLES}
virtual_refs -= cte_aliases
for vt in virtual_refs:
ctes_to_prepend[vt] = _get_virtual_table_sql(vt, ds.database_type)

if ctes_to_prepend:
parsed = _prepend_ctes(parsed, ctes_to_prepend, db.dialect)
raw_sql = parsed.sql(dialect=db.dialect)

supports_stored_procedure = ds.database_type in ["PostgreSQL", "MSSQL", "MariaDB"]
if (
supports_stored_procedure
and ds.enable_stored_procedure_execution
and raw_sql.strip().lower().startswith("exec")
and raw_sql.lower().startswith("exec")
):
current_date = date.today().strftime("%Y-%m-%d") # Format: 'YYYY-MM-DD'
raw_sql = raw_sql.replace("@Today", f"'{current_date}'")
Expand All @@ -583,7 +561,7 @@ def apply_sql(self, sql_args):

results = ibis.memtable(df)

elif raw_sql.strip().lower().startswith(("select", "with")):
elif raw_sql.lower().startswith(("select", "with")):
results = db.sql(raw_sql)

else:
Expand Down Expand Up @@ -841,6 +819,66 @@ def exec_with_return(
return safe_eval(output_expression, _globals, _locals)


_VIRTUAL_TABLES = frozenset(["insights_calendar"])


def _prepend_ctes(parsed, cte_map, dialect):
"""Prepend CTE definitions to a parsed SQL statement using the sqlglot AST.

New CTEs are placed before any existing CTEs so they can be referenced
by user-defined CTEs that follow.
"""
existing_with = parsed.args.get("with")
existing_ctes = list(existing_with.expressions) if existing_with else []
existing_names = {cte.alias for cte in existing_ctes}

new_ctes = []
for name, sql in cte_map.items():
if name in existing_names:
continue
cte_query = sg.parse_one(sql, dialect=dialect)
new_cte = sg.exp.CTE(
this=cte_query,
alias=sg.to_identifier(name),
)
new_ctes.append(new_cte)

if not new_ctes:
return parsed

all_ctes = new_ctes + existing_ctes
recursive = existing_with.args.get("recursive") if existing_with else False
parsed.set("with", sg.exp.With(expressions=all_ctes, recursive=recursive))
return parsed


def _get_virtual_table_sql(table_name: str, database_type: str) -> str:
if table_name == "insights_calendar":
return _get_calendar_sql(database_type)
frappe.throw(f"Unknown virtual table: {table_name}")


def _get_calendar_sql(database_type: str) -> str:
if database_type == "DuckDB":
return (
"SELECT CAST(range AS DATE) AS date "
"FROM range(DATE '1970-01-01', DATE '2051-01-01', INTERVAL 1 DAY)"
)
if database_type == "PostgreSQL":
return (
"SELECT gs::DATE AS date "
"FROM generate_series(DATE '1970-01-01', DATE '2050-12-31', INTERVAL '1 day') AS gs"
)
if database_type == "MariaDB":
# Uses MariaDB's built-in Sequence Storage Engine (seq_0_to_N).
# seq 0 = 1970-01-01, seq 29584 = 2050-12-31.
return "SELECT DATE('1970-01-01') + INTERVAL seq DAY AS date FROM seq_0_to_29584"
frappe.throw(
f"insights_calendar is not supported for {database_type} databases",
title="Not Implemented",
)


def get_ibis_table_name(table: IbisQuery):
dt = table.op().find_topmost(DatabaseTable)
if not dt:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def apply_user_permissions(t: Table, data_source, table_name):
return t
return t.filter(t.doctype.isin(allowed_single_doctypes))

if not table_name.startswith("tab"):
return t

permission_query = get_permission_query_for_table(table_name)
if not permission_query:
return t.filter(False)
Expand Down
Loading