Skip to content

[SPARK-52311][SQL] Fix bug with multiple self-references with UnresolvedSubqueries in rCTEs #51022

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
import org.apache.spark.sql.errors.QueryCompilationErrors

/**
* Updates CTE references with the resolved output attributes of corresponding CTE definitions.
Expand Down Expand Up @@ -107,7 +108,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
// The case of CTE name followed by a parenthesized list of column name(s), eg.
// WITH RECURSIVE t(n).
case alias @ SubqueryAlias(_,
columnAlias @ UnresolvedSubqueryColumnAliases(
_ @ UnresolvedSubqueryColumnAliases(
colNames,
Union(Seq(anchor, recursion), false, false)
)) =>
Expand All @@ -116,16 +117,16 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
} else {
val loop = UnionLoop(
cteDef.id,
anchor,
UnresolvedSubqueryColumnAliases(colNames, anchor),
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
maxDepth = cteDef.maxDepth)
cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
cteDef.copy(child = alias.copy(child = loop))
}

// The case of CTE name followed by a parenthesized list of column name(s), eg.
// WITH RECURSIVE t(n).
case alias @ SubqueryAlias(_,
columnAlias @ UnresolvedSubqueryColumnAliases(
_ @ UnresolvedSubqueryColumnAliases(
colNames,
withCTE @ WithCTE(Union(Seq(anchor, recursion), false, false), innerCteDefs)
)) =>
Expand All @@ -140,11 +141,11 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
}
val loop = UnionLoop(
cteDef.id,
anchor,
UnresolvedSubqueryColumnAliases(colNames, anchor),
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
maxDepth = cteDef.maxDepth)
cteDef.copy(child = alias.copy(child = columnAlias.copy(
child = withCTE.copy(plan = loop, cteDefs = newInnerCteDefs))))
cteDef.copy(child = alias.copy(child =
withCTE.copy(plan = loop, cteDefs = newInnerCteDefs)))
}

// If the recursion is described with a UNION (deduplicating) clause then the
Expand Down Expand Up @@ -202,7 +203,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {

// The case of CTE name followed by a parenthesized list of column name(s).
case alias @ SubqueryAlias(_,
columnAlias@UnresolvedSubqueryColumnAliases(
_ @ UnresolvedSubqueryColumnAliases(
colNames,
Distinct(Union(Seq(anchor, recursion), false, false))
)) =>
Expand All @@ -214,20 +215,20 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
} else {
val loop = UnionLoop(
cteDef.id,
Distinct(anchor),
UnresolvedSubqueryColumnAliases(colNames, Distinct(anchor)),
Except(
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
UnionLoopRef(cteDef.id, anchor.output, true),
isAll = false
),
maxDepth = cteDef.maxDepth
)
cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
cteDef.copy(child = alias.copy(child = loop))
}

// The case of CTE name followed by a parenthesized list of column name(s).
case alias @ SubqueryAlias(_,
columnAlias@UnresolvedSubqueryColumnAliases(
_ @ UnresolvedSubqueryColumnAliases(
colNames,
WithCTE(Distinct(Union(Seq(anchor, recursion), false, false)), innerCteDefs)
)) =>
Expand All @@ -245,16 +246,16 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
}
val loop = UnionLoop(
cteDef.id,
Distinct(anchor),
UnresolvedSubqueryColumnAliases(colNames, Distinct(anchor)),
Except(
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
UnionLoopRef(cteDef.id, anchor.output, true),
isAll = false
),
maxDepth = cteDef.maxDepth
)
cteDef.copy(child = alias.copy(child = columnAlias.copy(
child = withCTE.copy(plan = loop, cteDefs = newInnerCteDefs))))
cteDef.copy(child = alias.copy(child =
withCTE.copy(plan = loop, cteDefs = newInnerCteDefs)))
}

case other =>
Expand Down Expand Up @@ -298,8 +299,24 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
columnNames: Option[Seq[String]]) = {
recursion.transformUpWithSubqueriesAndPruning(_.containsPattern(CTE)) {
case r: CTERelationRef if r.recursive && r.cteId == cteDefId =>
val ref = UnionLoopRef(r.cteId, anchor.output, false)
columnNames.map(UnresolvedSubqueryColumnAliases(_, ref)).getOrElse(ref)
columnNames match {
case Some(names) =>
val outputAttrs = anchor.output
// Checks that the number of aliases is equal to the number of output columns
// in the subquery.
if (names.size != outputAttrs.size) {
throw QueryCompilationErrors.aliasNumberNotMatchColumnNumberError(
names.size, outputAttrs.size, r)
}
val aliases = outputAttrs.zip(names).map { case (attr, aliasName) =>
Alias(attr, aliasName)()
}
UnionLoopRef(r.cteId, aliases.map(_.toAttribute), accumulated = false)

case None =>
UnionLoopRef(r.cteId, anchor.output, accumulated = false)

}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ class ResolveRecursiveCTESuite extends AnalysisTest {

def getAfterPlan(): LogicalPlan = {
val col = anchor.output.head
val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false)
.select(col.as("n"))
val newAnchor = anchor.select(col.as("n"))
val recursion = UnionLoopRef(cteId, newAnchor.output, accumulated = false)
.subquery("t")
val cteDef = CTERelationDef(
UnionLoop(cteId, anchor, recursion).select(col.as("n")).subquery("t"),
UnionLoop(cteId, newAnchor, recursion).subquery("t"),
cteId)
val cteRef = CTERelationRef(
cteId,
Expand Down
Loading