Skip to content

Commit d2e4b47

Browse files
authored
Merge pull request #749 from dolthub/max/groupby-flatten-bug
GroupBy normalization rule maintains parent Projection dependencies
2 parents 91a637b + 0b8719c commit d2e4b47

File tree

3 files changed

+73
-2
lines changed

3 files changed

+73
-2
lines changed

enginetest/queries.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5779,6 +5779,14 @@ var QueryTests = []QueryTest{
57795779
{3, 3},
57805780
},
57815781
},
5782+
{
5783+
Query: "select sum(x.i) + y.i from mytable as x, mytable as y where x.i = y.i GROUP BY x.i",
5784+
Expected: []sql.Row{
5785+
{float64(2)},
5786+
{float64(4)},
5787+
{float64(6)},
5788+
},
5789+
},
57825790
{
57835791
Query: "SELECT 2.0 + CAST(5 AS DECIMAL)",
57845792
Expected: []sql.Row{{float64(7)}},

sql/analyzer/aggregations.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,18 @@ func flattenedGroupBy(ctx *sql.Context, projection, grouping []sql.Expression, c
7575
func replaceAggregatesWithGetFieldProjections(ctx *sql.Context, projection []sql.Expression) (projections, aggregations []sql.Expression, err error) {
7676
var newProjection = make([]sql.Expression, len(projection))
7777
var newAggregates []sql.Expression
78-
78+
allGetFields := make(map[int]sql.Expression)
79+
projDeps := make(map[int]struct{})
7980
for i, p := range projection {
8081
var transformed bool
8182
e, err := expression.TransformUp(p, func(e sql.Expression) (sql.Expression, error) {
8283
switch e := e.(type) {
8384
case sql.Aggregation, sql.WindowAggregation:
84-
// continue on
85+
// continue on
86+
case *expression.GetField:
87+
allGetFields[e.Index()] = e
88+
projDeps[e.Index()] = struct{}{}
89+
return e, nil
8590
default:
8691
return e, nil
8792
}
@@ -107,6 +112,24 @@ func replaceAggregatesWithGetFieldProjections(ctx *sql.Context, projection []sql
107112
}
108113
}
109114

115+
// find subset of allGetFields not covered by newAggregates
116+
newAggDeps := make(map[int]struct{}, 0)
117+
for _, agg := range newAggregates {
118+
_, _ = expression.TransformUp(agg, func(e sql.Expression) (sql.Expression, error) {
119+
switch e := e.(type) {
120+
case *expression.GetField:
121+
newAggDeps[e.Index()] = struct{}{}
122+
}
123+
return e, nil
124+
})
125+
}
126+
for i, _ := range projDeps {
127+
if _, ok := newAggDeps[i]; !ok {
128+
// add pass-through dependency
129+
newAggregates = append(newAggregates, allGetFields[i])
130+
}
131+
}
132+
110133
return newProjection, newAggregates, nil
111134
}
112135

sql/analyzer/aggregations_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,46 @@ func TestFlattenAggregationExprs(t *testing.T) {
168168
),
169169
),
170170
},
171+
{
172+
name: "aggregate with pass through column dependency",
173+
node: plan.NewGroupBy(
174+
[]sql.Expression{
175+
expression.NewArithmetic(
176+
aggregation.NewSum(
177+
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
178+
),
179+
expression.NewGetFieldWithTable(1, sql.Int64, "bar", "a", false),
180+
"+",
181+
),
182+
},
183+
[]sql.Expression{
184+
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
185+
},
186+
plan.NewResolvedTable(table, nil, nil),
187+
),
188+
189+
expected: plan.NewProject(
190+
[]sql.Expression{
191+
expression.NewArithmetic(
192+
expression.NewGetField(0, sql.Float64, "SUM(foo.a)", false),
193+
expression.NewGetFieldWithTable(1, sql.Int64, "bar", "a", false),
194+
"+",
195+
),
196+
},
197+
plan.NewGroupBy(
198+
[]sql.Expression{
199+
aggregation.NewSum(
200+
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
201+
),
202+
expression.NewGetFieldWithTable(1, sql.Int64, "bar", "a", false),
203+
},
204+
[]sql.Expression{
205+
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
206+
},
207+
plan.NewResolvedTable(table, nil, nil),
208+
),
209+
),
210+
},
171211
}
172212

173213
for _, test := range tests {

0 commit comments

Comments
 (0)