Commit 5201a66

pavle-martinovic_data and Pajaraja authored and committed
[SPARK-52311][SQL] Redefine UnionLoop output to not be duplicated if the anchor output is duplicated
### What changes were proposed in this pull request?

Redefine the output of UnionLoop and UnionLoopRef to use new expression IDs.

### Why are the changes needed?

Currently, rCTEs do not behave properly when the anchor references the same column multiple times, which leads to the wrong columns being identified in the recursion. For example, this rCTE:

```
WITH RECURSIVE tmp(x) AS (
  values (1), (2), (3), (4), (5)
), rcte(x, y) AS (
  SELECT x, x FROM tmp WHERE x = 1
  UNION ALL
  SELECT x + 1, x FROM rcte WHERE x < 5
)
SELECT * FROM rcte;
```

will return:

```
1 1
2 2
3 3
4 4
5 5
```

instead of:

```
1 1
2 1
3 2
4 3
5 4
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests in golden file cte-recursion.sql.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51041 from Pajaraja/pavle-martinovic_data/UnionLoopOutput.

Lead-authored-by: pavle-martinovic_data <[email protected]>
Co-authored-by: Pavle Martinovic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 6422256 commit 5201a66
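
The failure mode described in the commit message can be illustrated with a small, self-contained Scala model. This is only a sketch under stated assumptions: `Attr` and `project` are hypothetical stand-ins, not Spark's `Attribute` class or its actual reference-binding code. It assumes that a reference is bound to the first input position whose expression ID matches, which is enough to reproduce the `2 2` versus `2 1` difference for the first recursive row.

```scala
// Illustrative model only: `Attr` and `project` are hypothetical stand-ins,
// not Spark's Attribute class or its reference-binding code.
object DuplicateExprIdSketch {
  final case class Attr(name: String, exprId: Long)

  // Resolve each reference to the first input position with a matching
  // expression ID, then read that position from the row.
  def project(refs: Seq[Attr], input: Seq[Attr], row: Seq[Int]): Seq[Int] =
    refs.map(ref => row(input.indexWhere(_.exprId == ref.exprId)))

  // Row produced by the first recursive step: (x + 1, x) = (2, 1).
  val row: Seq[Int] = Seq(2, 1)

  // Anchor `SELECT x, x`: both loop output columns reuse one expression ID.
  val x: Attr = Attr("x", 1L)
  val duplicated: Seq[Attr] = Seq(x, x)
  val before: Seq[Int] = project(duplicated, duplicated, row)

  // One fresh ID per output position keeps the columns distinguishable.
  val fresh: Seq[Attr] = Seq(Attr("x", 10L), Attr("x", 11L))
  val after: Seq[Int] = project(fresh, fresh, row)

  def main(args: Array[String]): Unit = {
    println(before) // List(2, 2): both references read column 0
    println(after)  // List(2, 1): each reference reads its own column
  }
}
```

In this model the duplicated IDs collapse the second column onto the first, which matches the incorrect rows quoted above; the diff below avoids that by giving every output position its own ID.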

8 files changed, +173 -8 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala

Lines changed: 9 additions & 1 deletion
@@ -79,6 +79,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
     cteDef.id,
     anchor,
     rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, None),
+    anchor.output.map(_.newInstance().exprId),
     maxDepth = cteDef.maxDepth)
   cteDef.copy(child = alias.copy(child = loop))
 }
@@ -99,6 +100,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
     cteDef.id,
     anchor,
     rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, None),
+    anchor.output.map(_.newInstance().exprId),
     maxDepth = cteDef.maxDepth)
   cteDef.copy(child = alias.copy(child = withCTE.copy(
     plan = loop, cteDefs = newInnerCteDefs)))
@@ -118,6 +120,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
     cteDef.id,
     anchor,
     rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
+    anchor.output.map(_.newInstance().exprId),
     maxDepth = cteDef.maxDepth)
   cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
 }
@@ -142,6 +145,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
     cteDef.id,
     anchor,
     rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
+    anchor.output.map(_.newInstance().exprId),
     maxDepth = cteDef.maxDepth)
   cteDef.copy(child = alias.copy(child = columnAlias.copy(
     child = withCTE.copy(plan = loop, cteDefs = newInnerCteDefs))))
@@ -166,6 +170,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
       UnionLoopRef(cteDef.id, anchor.output, true),
       isAll = false
     ),
+    anchor.output.map(_.newInstance().exprId),
     maxDepth = cteDef.maxDepth
   )
   cteDef.copy(child = alias.copy(child = loop))
@@ -194,6 +199,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
       UnionLoopRef(cteDef.id, anchor.output, true),
       isAll = false
     ),
+    anchor.output.map(_.newInstance().exprId),
     maxDepth = cteDef.maxDepth
   )
   cteDef.copy(child = alias.copy(child = withCTE.copy(
@@ -220,6 +226,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
       UnionLoopRef(cteDef.id, anchor.output, true),
       isAll = false
     ),
+    anchor.output.map(_.newInstance().exprId),
     maxDepth = cteDef.maxDepth
   )
   cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
@@ -251,6 +258,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
       UnionLoopRef(cteDef.id, anchor.output, true),
       isAll = false
     ),
+    anchor.output.map(_.newInstance().exprId),
     maxDepth = cteDef.maxDepth
   )
   cteDef.copy(child = alias.copy(child = columnAlias.copy(
@@ -298,7 +306,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
       columnNames: Option[Seq[String]]) = {
     recursion.transformUpWithSubqueriesAndPruning(_.containsPattern(CTE)) {
       case r: CTERelationRef if r.recursive && r.cteId == cteDefId =>
-        val ref = UnionLoopRef(r.cteId, anchor.output, false)
+        val ref = UnionLoopRef(r.cteId, anchor.output.map(_.newInstance()), false)
         columnNames.map(UnresolvedSubqueryColumnAliases(_, ref)).getOrElse(ref)
     }
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 1 addition & 1 deletion
@@ -510,7 +510,7 @@ abstract class UnionBase extends LogicalPlan {

   private lazy val lazyOutput: Seq[Attribute] = computeOutput()

-  private def computeOutput(): Seq[Attribute] = Union.mergeChildOutputs(children.map(_.output))
+  protected def computeOutput(): Seq[Attribute] = Union.mergeChildOutputs(children.map(_.output))

   /**
    * Maps the constraints containing a given (original) sequence of attributes to those with a

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala

Lines changed: 12 additions & 0 deletions
@@ -34,19 +34,31 @@ import org.apache.spark.sql.internal.SQLConf
  * @param id The id of the loop, inherited from [[CTERelationDef]] within which the Union lived.
  * @param anchor The plan of the initial element of the loop.
  * @param recursion The plan that describes the recursion with an [[UnionLoopRef]] node.
+ * @param outputAttrIds The ids of UnionLoop's output attributes.
  * @param limit An optional limit that can be pushed down to the node to stop the loop earlier.
  * @param maxDepth Maximal number of iterations before we report an error.
  */
 case class UnionLoop(
     id: Long,
     anchor: LogicalPlan,
     recursion: LogicalPlan,
+    outputAttrIds: Seq[ExprId],
     limit: Option[Int] = None,
     maxDepth: Option[Int] = None) extends UnionBase {
   override def children: Seq[LogicalPlan] = Seq(anchor, recursion)

   override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): UnionLoop =
     copy(anchor = newChildren(0), recursion = newChildren(1))
+
+  override protected def computeOutput(): Seq[Attribute] =
+    Union.mergeChildOutputs(children.map(_.output)).zip(outputAttrIds).map { case (x, id) =>
+      x.withExprId(id)
+    }
+
+  override def argString(maxFields: Int): String = {
+    id.toString + limit.map(", " + _.toString).getOrElse("") +
+      maxDepth.map(", " + _.toString).getOrElse("")
+  }
 }

 /**
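
The `computeOutput()` override above pairs each merged child attribute with an ID taken from `outputAttrIds`. The following standalone sketch mimics that zip-and-relabel step. It is a simplified model, not the Catalyst code: `Attr`, `mergeChildOutputs`, and the concrete ID values are hypothetical, and the real `Union.mergeChildOutputs` may do more than take the first child's attributes. One plausible reason for storing the IDs on the node (an assumption, not stated in the commit) is that recomputing the output then yields the same IDs every time.

```scala
// Illustrative model only: `Attr`, `mergeChildOutputs`, and `computeOutput`
// here are simplified stand-ins, not the Catalyst implementations.
object UnionLoopOutputSketch {
  final case class Attr(name: String, exprId: Long) {
    def withExprId(id: Long): Attr = copy(exprId = id)
  }

  // Simplification: take the first child's attributes as the merged output.
  def mergeChildOutputs(childOutputs: Seq[Seq[Attr]]): Seq[Attr] = childOutputs.head

  // Mirrors the shape of the override in the diff: pair each merged attribute
  // with a pre-allocated ID.
  def computeOutput(childOutputs: Seq[Seq[Attr]], outputAttrIds: Seq[Long]): Seq[Attr] =
    mergeChildOutputs(childOutputs).zip(outputAttrIds).map { case (a, id) => a.withExprId(id) }

  val x: Attr = Attr("x", 1L)
  val anchorOutput: Seq[Attr] = Seq(x, x)                      // SELECT x, x
  val recursionOutput: Seq[Attr] = Seq(Attr("(x + 1)", 2L), x) // SELECT x + 1, x
  val outputAttrIds: Seq[Long] = Seq(100L, 101L)               // hypothetical IDs, allocated once

  val output: Seq[Attr] = computeOutput(Seq(anchorOutput, recursionOutput), outputAttrIds)

  def main(args: Array[String]): Unit = {
    println(output.map(_.exprId))                              // List(100, 101)
    println(output.map(_.exprId).distinct.size == output.size) // true
  }
}
```

Even though the anchor's two columns share ID 1 in this model, the relabelled output carries two distinct IDs, so a parent operator can refer to each column unambiguously.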

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveRecursiveCTESuite.scala

Lines changed: 26 additions & 5 deletions
@@ -42,9 +42,20 @@ class ResolveRecursiveCTESuite extends AnalysisTest {
         Seq(CTERelationDef(anchor.union(recursion).subquery("t"), cteId)))
     }

+    val analyzed = getAnalyzer.execute(getBeforePlan())
+
+    val outputExprIds = analyzed match {
+      case WithCTE(_, cteDefs) =>
+        cteDefs.head.child match {
+          case SubqueryAlias(_, UnionLoop(_, _, _, exprIds, _, _)) =>
+            exprIds
+        }
+    }
+
     def getAfterPlan(): LogicalPlan = {
       val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false).subquery("t")
-      val cteDef = CTERelationDef(UnionLoop(cteId, anchor, recursion).subquery("t"), cteId)
+      val cteDef = CTERelationDef(UnionLoop(cteId, anchor, recursion,
+        outputExprIds).subquery("t"), cteId)
       val cteRef = CTERelationRef(
         cteId,
         _resolved = true,
@@ -53,7 +64,7 @@ class ResolveRecursiveCTESuite extends AnalysisTest {
       WithCTE(cteRef, Seq(cteDef))
     }

-    comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan())
+    comparePlans(analyzed, getAfterPlan())
   }

   // Motivated by:
@@ -75,14 +86,24 @@ class ResolveRecursiveCTESuite extends AnalysisTest {
       WithCTE(cteRef.copy(recursive = false), Seq(cteDef))
     }

+    val analyzed = getAnalyzer.execute(getBeforePlan())
+
+    val outputExprIds = analyzed match {
+      case WithCTE(_, cteDefs) =>
+        cteDefs.head.child match {
+          case SubqueryAlias(_, Project(_, UnionLoop(_, _, _, exprIds, _, _))) =>
+            exprIds
+        }
+    }
+
     def getAfterPlan(): LogicalPlan = {
       val col = anchor.output.head
       val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false)
         .select(col.as("n"))
         .subquery("t")
       val cteDef = CTERelationDef(
-        UnionLoop(cteId, anchor, recursion).select(col.as("n")).subquery("t"),
-        cteId)
+        UnionLoop(cteId, anchor, recursion, outputExprIds)
+          .select(col.as("n")).subquery("t"), cteId)
       val cteRef = CTERelationRef(
         cteId,
         _resolved = true,
@@ -91,6 +112,6 @@ class ResolveRecursiveCTESuite extends AnalysisTest {
       WithCTE(cteRef, Seq(cteDef))
     }

-    comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan())
+    comparePlans(analyzed, getAfterPlan())
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
@@ -1045,7 +1045,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         GlobalLimitExec(child = planLater(child), offset = offset) :: Nil
       case union: logical.Union =>
         execution.UnionExec(union.children.map(planLater)) :: Nil
-      case u @ logical.UnionLoop(id, anchor, recursion, limit, maxDepth) =>
+      case u @ logical.UnionLoop(id, anchor, recursion, _, limit, maxDepth) =>
         execution.UnionLoopExec(id, anchor, recursion, u.output, limit, maxDepth) :: Nil
       case g @ logical.Generate(generator, _, outer, _, _, child) =>
         execution.GenerateExec(

sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out

Lines changed: 66 additions & 0 deletions
@@ -1631,6 +1631,72 @@ WithCTE
 +- CTERelationRef xxxx, true, [n#x], false, false


+-- !query
+WITH RECURSIVE tmp(x) AS (
+  values (1), (2), (3), (4), (5)
+), rcte(x, y) AS (
+  SELECT x, x FROM tmp WHERE x = 1
+  UNION ALL
+  SELECT x + 1, x FROM rcte WHERE x < 5
+)
+SELECT * FROM rcte
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+:  +- SubqueryAlias tmp
+:     +- Project [col1#x AS x#x]
+:        +- LocalRelation [col1#x]
+:- CTERelationDef xxxx, false
+:  +- SubqueryAlias rcte
+:     +- Project [x#x AS x#x, x#x AS y#x]
+:        +- UnionLoop xxxx
+:           :- Project [x#x, x#x]
+:           :  +- Filter (x#x = 1)
+:           :     +- SubqueryAlias tmp
+:           :        +- CTERelationRef xxxx, true, [x#x], false, false, 5
+:           +- Project [(x#x + 1) AS (x + 1)#x, x#x]
+:              +- Filter (x#x < 5)
+:                 +- SubqueryAlias rcte
+:                    +- Project [x#x AS x#x, x#x AS y#x]
+:                       +- UnionLoopRef xxxx, [x#x, x#x], false
++- Project [x#x, y#x]
+   +- SubqueryAlias rcte
+      +- CTERelationRef xxxx, true, [x#x, y#x], false, false
+
+
+-- !query
+WITH RECURSIVE tmp(x) AS (
+  values (1), (2), (3), (4), (5)
+), rcte(x, y, z, t) AS (
+  SELECT x, x, x, x FROM tmp WHERE x = 1
+  UNION ALL
+  SELECT x + 1, x, y + 1, y FROM rcte WHERE x < 5
+)
+SELECT * FROM rcte
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+:  +- SubqueryAlias tmp
+:     +- Project [col1#x AS x#x]
+:        +- LocalRelation [col1#x]
+:- CTERelationDef xxxx, false
+:  +- SubqueryAlias rcte
+:     +- Project [x#x AS x#x, x#x AS y#x, x#x AS z#x, x#x AS t#x]
+:        +- UnionLoop xxxx
+:           :- Project [x#x, x#x, x#x, x#x]
+:           :  +- Filter (x#x = 1)
+:           :     +- SubqueryAlias tmp
+:           :        +- CTERelationRef xxxx, true, [x#x], false, false, 5
+:           +- Project [(x#x + 1) AS (x + 1)#x, x#x, (y#x + 1) AS (y + 1)#x, y#x]
+:              +- Filter (x#x < 5)
+:                 +- SubqueryAlias rcte
+:                    +- Project [x#x AS x#x, x#x AS y#x, x#x AS z#x, x#x AS t#x]
+:                       +- UnionLoopRef xxxx, [x#x, x#x, x#x, x#x], false
++- Project [x#x, y#x, z#x, t#x]
+   +- SubqueryAlias rcte
+      +- CTERelationRef xxxx, true, [x#x, y#x, z#x, t#x], false, false
+
+
 -- !query
 WITH RECURSIVE randoms(val) AS (
   SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)

sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql

Lines changed: 20 additions & 0 deletions
@@ -588,6 +588,26 @@ WITH RECURSIVE t1 AS (
   SELECT n+1 FROM t2 WHERE n < 5)
 SELECT * FROM t1;

+-- Recursive CTE with multiple of the same reference in the anchor, which get referenced differently subsequent iterations.
+WITH RECURSIVE tmp(x) AS (
+  values (1), (2), (3), (4), (5)
+), rcte(x, y) AS (
+  SELECT x, x FROM tmp WHERE x = 1
+  UNION ALL
+  SELECT x + 1, x FROM rcte WHERE x < 5
+)
+SELECT * FROM rcte;
+
+-- Recursive CTE with multiple of the same reference in the anchor, which get referenced as different variables in subsequent iterations.
+WITH RECURSIVE tmp(x) AS (
+  values (1), (2), (3), (4), (5)
+), rcte(x, y, z, t) AS (
+  SELECT x, x, x, x FROM tmp WHERE x = 1
+  UNION ALL
+  SELECT x + 1, x, y + 1, y FROM rcte WHERE x < 5
+)
+SELECT * FROM rcte;
+
 -- Non-deterministic query with rand with seed
 WITH RECURSIVE randoms(val) AS (
   SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)

sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out

Lines changed: 38 additions & 0 deletions
@@ -1477,6 +1477,44 @@ struct<n:int>
 5


+-- !query
+WITH RECURSIVE tmp(x) AS (
+  values (1), (2), (3), (4), (5)
+), rcte(x, y) AS (
+  SELECT x, x FROM tmp WHERE x = 1
+  UNION ALL
+  SELECT x + 1, x FROM rcte WHERE x < 5
+)
+SELECT * FROM rcte
+-- !query schema
+struct<x:int,y:int>
+-- !query output
+1 1
+2 1
+3 2
+4 3
+5 4
+
+
+-- !query
+WITH RECURSIVE tmp(x) AS (
+  values (1), (2), (3), (4), (5)
+), rcte(x, y, z, t) AS (
+  SELECT x, x, x, x FROM tmp WHERE x = 1
+  UNION ALL
+  SELECT x + 1, x, y + 1, y FROM rcte WHERE x < 5
+)
+SELECT * FROM rcte
+-- !query schema
+struct<x:int,y:int,z:int,t:int>
+-- !query output
+1 1 1 1
+2 1 2 1
+3 2 2 1
+4 3 3 2
+5 4 4 3
+
+
 -- !query
 WITH RECURSIVE randoms(val) AS (
   SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
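
The expected rows for the four-column query in the golden file above can be cross-checked by hand. The sketch below is a plain Scala model of the `UNION ALL` recursion for that one query. It is not how Spark executes the loop; `step` and `run` are hypothetical names that hard-code the recursive branch `SELECT x + 1, x, y + 1, y FROM rcte WHERE x < 5`.

```scala
import scala.annotation.tailrec

// Illustrative model of the second test query's semantics; `step` and `run`
// are hypothetical helpers, not part of Spark.
object RcteSemanticsSketch {
  type Row = (Int, Int, Int, Int) // (x, y, z, t)

  // Recursive branch: SELECT x + 1, x, y + 1, y FROM rcte
  def step(r: Row): Row = (r._1 + 1, r._1, r._2 + 1, r._2)

  // Repeatedly apply the recursive branch to the rows of the previous
  // iteration that satisfy `x < 5`, accumulating all rows (UNION ALL).
  def run(anchor: Row): List[Row] = {
    @tailrec
    def loop(frontier: List[Row], acc: List[Row]): List[Row] =
      if (frontier.isEmpty) acc
      else {
        val next = frontier.filter(_._1 < 5).map(step)
        loop(next, acc ++ next)
      }
    loop(List(anchor), List(anchor))
  }

  def main(args: Array[String]): Unit =
    // Anchor: SELECT x, x, x, x FROM tmp WHERE x = 1
    run((1, 1, 1, 1)).foreach(println)
}
```

Running the model from the anchor row `(1, 1, 1, 1)` produces the same five rows as the `-- !query output` section, which is the behaviour expected once each output column has its own expression ID.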
