Skip to content

Commit 8ddb338

Browse files
authored
Merge pull request #746 from dolthub/max/window-refactor
Window exec uses framing iterator and support agg functions in windows
2 parents 62a6525 + 5d7a516 commit 8ddb338

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+3109
-2043
lines changed

enginetest/memory_engine_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ func TestNaturalJoin(t *testing.T) {
542542
enginetest.TestNaturalJoin(t, enginetest.NewDefaultMemoryHarness())
543543
}
544544

545-
func TestTestWindowFunctions(t *testing.T) {
545+
func TestWindowFunctions(t *testing.T) {
546546
enginetest.TestWindowFunctions(t, enginetest.NewDefaultMemoryHarness())
547547
}
548548

enginetest/queries.go

Lines changed: 111 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -726,9 +726,9 @@ var QueryTests = []QueryTest{
726726
Query: `select mt.i,
727727
((
728728
select count(*) from mytable
729-
where i in (
730-
select mt2.i from mytable mt2 where mt2.i > mt.i
731-
)
729+
where i in (
730+
select mt2.i from mytable mt2 where mt2.i > mt.i
731+
)
732732
)) as greater_count
733733
from mytable mt order by 1`,
734734
Expected: []sql.Row{{1, 2}, {2, 1}, {3, 0}},
@@ -737,9 +737,9 @@ var QueryTests = []QueryTest{
737737
Query: `select mt.i,
738738
((
739739
select count(*) from mytable
740-
where i in (
741-
select mt2.i from mytable mt2 where mt2.i = mt.i
742-
)
740+
where i in (
741+
select mt2.i from mytable mt2 where mt2.i = mt.i
742+
)
743743
)) as eq_count
744744
from mytable mt order by 1`,
745745
Expected: []sql.Row{{1, 1}, {2, 1}, {3, 1}},
@@ -2755,6 +2755,42 @@ var QueryTests = []QueryTest{
27552755
{1, 3},
27562756
},
27572757
},
2758+
{
2759+
Query: `select pk,
2760+
row_number() over (order by pk desc),
2761+
sum(v1) over (partition by v2 order by pk),
2762+
percent_rank() over(partition by v2 order by pk)
2763+
from one_pk_three_idx order by pk`,
2764+
Expected: []sql.Row{
2765+
{0, 8, float64(0), float64(0)},
2766+
{1, 7, float64(0), float64(1) / float64(3)},
2767+
{2, 6, float64(0), float64(0)},
2768+
{3, 5, float64(0), float64(0)},
2769+
{4, 4, float64(1), float64(2) / float64(3)},
2770+
{5, 3, float64(3), float64(1)},
2771+
{6, 2, float64(3), float64(0)},
2772+
{7, 1, float64(4), float64(0)},
2773+
},
2774+
},
2775+
{
2776+
Query: `select pk,
2777+
first_value(pk) over (order by pk desc),
2778+
lag(pk, 1) over (order by pk desc),
2779+
count(pk) over(partition by v1 order by pk),
2780+
max(pk) over(partition by v1 order by pk desc),
2781+
avg(v2) over (partition by v1 order by pk)
2782+
from one_pk_three_idx order by pk`,
2783+
Expected: []sql.Row{
2784+
{0, 7, 1, 1, 3, float64(0)},
2785+
{1, 7, 2, 2, 3, float64(0)},
2786+
{2, 7, 3, 3, 3, float64(1) / float64(3)},
2787+
{3, 7, 4, 4, 3, float64(3) / float64(4)},
2788+
{4, 7, 5, 1, 4, float64(0)},
2789+
{5, 7, 6, 1, 5, float64(0)},
2790+
{6, 7, 7, 1, 6, float64(3)},
2791+
{7, 7, nil, 1, 7, float64(4)},
2792+
},
2793+
},
27582794
{
27592795
Query: "SELECT t1.i FROM mytable t1 JOIN mytable t2 on t1.i = t2.i + 1 where t1.i = 2 and t2.i = 3",
27602796
Expected: []sql.Row{},
@@ -5720,7 +5756,7 @@ var QueryTests = []QueryTest{
57205756
},
57215757
},
57225758
{
5723-
Query: `select i, row_number() over (order by i desc),
5759+
Query: `select i, row_number() over (order by i desc),
57245760
row_number() over (order by length(s),i) from mytable order by 1;`,
57255761
Expected: []sql.Row{
57265762
{1, 3, 1},
@@ -5743,7 +5779,7 @@ var QueryTests = []QueryTest{
57435779
},
57445780
},
57455781
{
5746-
Query: `select row_number() over (order by i desc),
5782+
Query: `select row_number() over (order by i desc),
57475783
row_number() over (order by length(s),i) from mytable order by i;`,
57485784
Expected: []sql.Row{
57495785
{3, 1},
@@ -5752,7 +5788,7 @@ var QueryTests = []QueryTest{
57525788
},
57535789
},
57545790
{
5755-
Query: `select *, row_number() over (order by i desc),
5791+
Query: `select *, row_number() over (order by i desc),
57565792
row_number() over (order by length(s),i) from mytable order by i;`,
57575793
Expected: []sql.Row{
57585794
{1, "first row", 3, 1},
@@ -5761,10 +5797,10 @@ var QueryTests = []QueryTest{
57615797
},
57625798
},
57635799
{
5764-
Query: `select row_number() over (order by i desc),
5765-
row_number() over (order by length(s),i)
5766-
from mytable mt join othertable ot
5767-
on mt.i = ot.i2
5800+
Query: `select row_number() over (order by i desc),
5801+
row_number() over (order by length(s),i)
5802+
from mytable mt join othertable ot
5803+
on mt.i = ot.i2
57685804
order by mt.i;`,
57695805
Expected: []sql.Row{
57705806
{3, 1},
@@ -5773,7 +5809,7 @@ var QueryTests = []QueryTest{
57735809
},
57745810
},
57755811
{
5776-
Query: `select i, row_number() over (order by i desc),
5812+
Query: `select i, row_number() over (order by i desc),
57775813
row_number() over (order by length(s),i) from mytable order by 1 desc;`,
57785814
Expected: []sql.Row{
57795815
{3, 1, 2},
@@ -5792,8 +5828,8 @@ var QueryTests = []QueryTest{
57925828
},
57935829
{
57945830
Query: `select i, row_number() over (order by i desc) + 3,
5795-
row_number() over (order by length(s),i) as s_asc,
5796-
row_number() over (order by length(s) desc,i desc) as s_desc
5831+
row_number() over (order by length(s),i) as s_asc,
5832+
row_number() over (order by length(s) desc,i desc) as s_desc
57975833
from mytable order by 1;`,
57985834
Expected: []sql.Row{
57995835
{1, 6, 1, 3},
@@ -5803,7 +5839,7 @@ var QueryTests = []QueryTest{
58035839
},
58045840
{
58055841
Query: `select i, row_number() over (order by i desc) + 3,
5806-
row_number() over (order by length(s),i) + 0.0 / row_number() over (order by length(s) desc,i desc) + 0.0
5842+
row_number() over (order by length(s),i) + 0.0 / row_number() over (order by length(s) desc,i desc) + 0.0
58075843
from mytable order by 1;`,
58085844
Expected: []sql.Row{
58095845
{1, 6, 1.0},
@@ -5821,8 +5857,8 @@ var QueryTests = []QueryTest{
58215857
},
58225858
},
58235859
{
5824-
Query: `select pk1, pk2,
5825-
row_number() over (partition by pk1 order by c1 desc)
5860+
Query: `select pk1, pk2,
5861+
row_number() over (partition by pk1 order by c1 desc)
58265862
from two_pk order by 1,2;`,
58275863
Expected: []sql.Row{
58285864
{0, 0, 2},
@@ -5832,8 +5868,8 @@ var QueryTests = []QueryTest{
58325868
},
58335869
},
58345870
{
5835-
Query: `select pk1, pk2,
5836-
row_number() over (partition by pk1 order by c1 desc),
5871+
Query: `select pk1, pk2,
5872+
row_number() over (partition by pk1 order by c1 desc),
58375873
row_number() over (partition by pk2 order by 10 - c1)
58385874
from two_pk order by 1,2;`,
58395875
Expected: []sql.Row{
@@ -5844,8 +5880,8 @@ var QueryTests = []QueryTest{
58445880
},
58455881
},
58465882
{
5847-
Query: `select pk1, pk2,
5848-
row_number() over (partition by pk1 order by c1 desc),
5883+
Query: `select pk1, pk2,
5884+
row_number() over (partition by pk1 order by c1 desc),
58495885
row_number() over (partition by pk2 order by 10 - c1),
58505886
max(c4) over ()
58515887
from two_pk order by 1,2;`,
@@ -5856,6 +5892,58 @@ var QueryTests = []QueryTest{
58565892
{1, 1, 1, 1, 33},
58575893
},
58585894
},
5895+
{
5896+
Query: "SELECT pk, row_number() over (partition by v2 order by pk ), max(v3) over (partition by v2 order by pk) FROM one_pk_three_idx ORDER BY pk",
5897+
Expected: []sql.Row{
5898+
{0, 1, 3},
5899+
{1, 2, 3},
5900+
{2, 1, 0},
5901+
{3, 1, 2},
5902+
{4, 3, 3},
5903+
{5, 4, 3},
5904+
{6, 1, 0},
5905+
{7, 1, 4},
5906+
},
5907+
},
5908+
{
5909+
Query: "SELECT pk, count(*) over (order by v2) FROM one_pk_three_idx ORDER BY pk",
5910+
Expected: []sql.Row{
5911+
{0, 4},
5912+
{1, 4},
5913+
{2, 5},
5914+
{3, 6},
5915+
{4, 4},
5916+
{5, 4},
5917+
{6, 7},
5918+
{7, 8},
5919+
},
5920+
},
5921+
{
5922+
Query: "SELECT pk, count(*) over (partition by v2) FROM one_pk_three_idx ORDER BY pk",
5923+
Expected: []sql.Row{
5924+
{0, 4},
5925+
{1, 4},
5926+
{2, 1},
5927+
{3, 1},
5928+
{4, 4},
5929+
{5, 4},
5930+
{6, 1},
5931+
{7, 1},
5932+
},
5933+
},
5934+
{
5935+
Query: "SELECT pk, row_number() over (order by v2, pk), max(pk) over () from one_pk_three_idx ORDER BY pk",
5936+
Expected: []sql.Row{
5937+
{0, 1, 7},
5938+
{1, 2, 7},
5939+
{2, 5, 7},
5940+
{3, 6, 7},
5941+
{4, 3, 7},
5942+
{5, 4, 7},
5943+
{6, 7, 7},
5944+
{7, 8, 7},
5945+
},
5946+
},
58595947
{
58605948
Query: `select i,
58615949
row_number() over (partition by case when i > 2 then "under two" else "over two" end order by i desc) as s_asc

optgen/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Optgen
2+
3+
Optgen is a small setup for generating analyzer code from templates derived
4+
from [cockroachdb's optgen](https://github.com/cockroachdb/cockroach/tree/master/pkg/sql/opt/optgen/cmd/optgen).
5+
6+
Usage:
7+
```bash
8+
$ go install ./optgen/cmd/optgen
9+
$ go generate ./...
10+
```
11+
12+
The bulk of analyzer logic is normalization rules, join ordering transforms,
13+
and execution-specific code.
14+
Analyzer expressions are mostly boilerplate, and specific normalization
15+
rules and join transforms only need types, fields, and literal values to
16+
manipulate logical query plans.
17+
18+
Leaning into templates and this harness can reduce code footprint and standardize optimizer nodes
19+
when the opportunities arise.
20+
21+
If we end up using more of CockroachDB's optimizer, note that they codegen
22+
their expressions, normalization rules, exploration rules, and execution
23+
code from this same general setup.

optgen/cmd/optgen/main.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
"flag"
7+
"fmt"
8+
"go/format"
9+
"io"
10+
"os"
11+
12+
"github.com/dolthub/go-mysql-server/optgen/cmd/support"
13+
"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
14+
)
15+
16+
var (
17+
errInvalidArgCount = errors.New("invalid number of arguments")
18+
errUnrecognizedCommand = errors.New("unrecognized command")
19+
)
20+
21+
var (
22+
pkg = flag.String("pkg", "aggregation", "package name used in generated files")
23+
out = flag.String("out", "", "output file name of generated code")
24+
)
25+
26+
const useGoFmt = true
27+
28+
func main() {
29+
flag.Usage = usage
30+
flag.Parse()
31+
32+
args := flag.Args()
33+
if len(args) < 2 {
34+
flag.Usage()
35+
exit(errInvalidArgCount)
36+
}
37+
38+
cmd := args[0]
39+
switch cmd {
40+
case "aggs":
41+
42+
default:
43+
flag.Usage()
44+
exit(errUnrecognizedCommand)
45+
}
46+
47+
sources := flag.Args()[1:]
48+
readers := make([]io.Reader, len(sources))
49+
for i, name := range sources {
50+
file, err := os.Open(name)
51+
if err != nil {
52+
exit(err)
53+
}
54+
55+
defer file.Close()
56+
readers[i] = file
57+
}
58+
59+
var writer io.Writer
60+
if *out != "" {
61+
file, err := os.Create(*out)
62+
if err != nil {
63+
exit(err)
64+
}
65+
66+
defer file.Close()
67+
writer = file
68+
} else {
69+
writer = os.Stderr
70+
}
71+
72+
var err error
73+
switch cmd {
74+
case "aggs":
75+
err = generateAggs(aggregation.UnaryAggDefs, writer)
76+
}
77+
78+
if err != nil {
79+
exit(err)
80+
}
81+
}
82+
83+
// usage writes optgen's help text to stderr: a one-line summary, the
// invocation syntax, the list of commands, and then the registered flags as
// printed by flag.PrintDefaults. main installs it as flag.Usage.
func usage() {
	intro := []string{
		"Optgen is a tool for generating optimizer code.\n\n",
		"Usage:\n",
		"\toptgen command [flags] sources...\n\n",
		"The commands are:\n\n",
		"\taggs generates aggregation definitions and functions\n",
		"\n",
		"Flags:\n",
	}
	for _, text := range intro {
		fmt.Fprint(os.Stderr, text)
	}

	flag.PrintDefaults()

	fmt.Fprint(os.Stderr, "\n")
}
100+
101+
// exit prints err to stderr as "ERROR: <err>" followed by a newline and then
// terminates the process with status 2. It never returns.
func exit(err error) {
	const status = 2

	msg := fmt.Sprintf("ERROR: %v\n", err)
	fmt.Fprint(os.Stderr, msg)
	os.Exit(status)
}
105+
106+
func generateAggs(defines []support.AggDef, w io.Writer) error {
107+
var gen support.AggGen
108+
return generate(defines, w, gen.Generate)
109+
}
110+
111+
func generate(defines []support.AggDef, w io.Writer, genFunc func(defines []support.AggDef, w io.Writer)) error {
112+
var buf bytes.Buffer
113+
114+
buf.WriteString("// Code generated by optgen; DO NOT EDIT.\n\n")
115+
fmt.Fprintf(&buf, " package %s\n\n", *pkg)
116+
117+
genFunc(defines, &buf)
118+
119+
var b []byte
120+
var err error
121+
122+
if useGoFmt {
123+
b, err = format.Source(buf.Bytes())
124+
if err != nil {
125+
// Write out incorrect source for easier debugging.
126+
b = buf.Bytes()
127+
}
128+
} else {
129+
b = buf.Bytes()
130+
}
131+
132+
w.Write(b)
133+
return err
134+
}

0 commit comments

Comments
 (0)