Skip to content

Commit 28c7acf

Browse files
authored
sql/ast: Resolve return values from functions (#964)
* sql/ast: Resolve return values from functions * Fix early return with err * Add endtoend test
1 parent eec0fe1 commit 28c7acf

File tree

10 files changed

+178
-19
lines changed

10 files changed

+178
-19
lines changed

internal/compiler/output_columns.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
229229
case *ast.SelectStmt:
230230
list = astutils.Search(n.FromClause, func(node ast.Node) bool {
231231
switch node.(type) {
232-
case *ast.RangeVar, *ast.RangeSubselect:
232+
case *ast.RangeVar, *ast.RangeSubselect, *ast.FuncName:
233233
return true
234234
default:
235235
return false
@@ -251,6 +251,20 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
251251
var tables []*Table
252252
for _, item := range list.Items {
253253
switch n := item.(type) {
254+
case *ast.FuncName:
255+
fn, err := qc.GetFunc(n)
256+
if err != nil {
257+
return nil, err
258+
}
259+
table, err := qc.GetTable(&ast.TableName{
260+
Catalog: fn.ReturnType.Catalog,
261+
Schema: fn.ReturnType.Schema,
262+
Name: fn.ReturnType.Name,
263+
})
264+
if err != nil {
265+
return nil, err
266+
}
267+
tables = append(tables, table)
254268
case *ast.RangeSubselect:
255269
cols, err := outputColumns(qc, n.Subquery)
256270
if err != nil {

internal/compiler/query.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ import (
44
"github.com/kyleconroy/sqlc/internal/sql/ast"
55
)
66

7+
type Function struct {
8+
Rel *ast.FuncName
9+
ReturnType *ast.TypeName
10+
}
11+
712
type Table struct {
813
Rel *ast.TableName
914
Columns []*Column

internal/compiler/query_catalog.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package compiler
22

33
import (
4+
"fmt"
5+
46
"github.com/kyleconroy/sqlc/internal/sql/ast"
57
"github.com/kyleconroy/sqlc/internal/sql/catalog"
68
)
@@ -71,3 +73,17 @@ func (qc QueryCatalog) GetTable(rel *ast.TableName) (*Table, error) {
7173
}
7274
return &Table{Rel: rel, Columns: cols}, nil
7375
}
76+
77+
func (qc QueryCatalog) GetFunc(rel *ast.FuncName) (*Function, error) {
78+
funcs, err := qc.catalog.ListFuncsByName(rel)
79+
if err != nil {
80+
return nil, err
81+
}
82+
if len(funcs) == 0 {
83+
return nil, fmt.Errorf("function not found: %s", rel.Name)
84+
}
85+
return &Function{
86+
Rel: rel,
87+
ReturnType: funcs[0].ReturnType,
88+
}, nil
89+
}

internal/compiler/resolve.go

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,23 @@ func resolveCatalogRefs(c *catalog.Catalog, rvs []*ast.RangeVar, args []paramRef
3131
return defaultName
3232
}
3333

34+
typeMap := map[string]map[string]map[string]*catalog.Column{}
35+
indexTable := func(table catalog.Table) error {
36+
tables = append(tables, table.Rel)
37+
if defaultTable == nil {
38+
defaultTable = table.Rel
39+
}
40+
if _, exists := typeMap[table.Rel.Schema]; !exists {
41+
typeMap[table.Rel.Schema] = map[string]map[string]*catalog.Column{}
42+
}
43+
typeMap[table.Rel.Schema][table.Rel.Name] = map[string]*catalog.Column{}
44+
for _, c := range table.Columns {
45+
cc := c
46+
typeMap[table.Rel.Schema][table.Rel.Name][c.Name] = cc
47+
}
48+
return nil
49+
}
50+
3451
for _, rv := range rvs {
3552
if rv.Relname == nil {
3653
continue
@@ -39,29 +56,16 @@ func resolveCatalogRefs(c *catalog.Catalog, rvs []*ast.RangeVar, args []paramRef
3956
if err != nil {
4057
return nil, err
4158
}
42-
tables = append(tables, fqn)
43-
if defaultTable == nil {
44-
defaultTable = fqn
45-
}
46-
if rv.Alias == nil {
47-
continue
48-
}
49-
aliasMap[*rv.Alias.Aliasname] = fqn
50-
}
51-
52-
typeMap := map[string]map[string]map[string]*catalog.Column{}
53-
for _, fqn := range tables {
5459
table, err := c.GetTable(fqn)
5560
if err != nil {
5661
continue
5762
}
58-
if _, exists := typeMap[fqn.Schema]; !exists {
59-
typeMap[fqn.Schema] = map[string]map[string]*catalog.Column{}
63+
err = indexTable(table)
64+
if err != nil {
65+
return nil, err
6066
}
61-
typeMap[fqn.Schema][fqn.Name] = map[string]*catalog.Column{}
62-
for _, c := range table.Columns {
63-
cc := c
64-
typeMap[fqn.Schema][fqn.Name][c.Name] = cc
67+
if rv.Alias != nil {
68+
aliasMap[*rv.Alias.Aliasname] = fqn
6569
}
6670
}
6771

@@ -270,6 +274,23 @@ func resolveCatalogRefs(c *catalog.Catalog, rvs []*ast.RangeVar, args []paramRef
270274
})
271275
}
272276

277+
if fun.ReturnType == nil {
278+
continue
279+
}
280+
281+
table, err := c.GetTable(&ast.TableName{
282+
Catalog: fun.ReturnType.Catalog,
283+
Schema: fun.ReturnType.Schema,
284+
Name: fun.ReturnType.Name,
285+
})
286+
if err != nil {
287+
// The return type wasn't a table.
288+
continue
289+
}
290+
err = indexTable(table)
291+
if err != nil {
292+
return nil, err
293+
}
273294
case *ast.ResTarget:
274295
if n.Name == nil {
275296
return nil, fmt.Errorf("*ast.ResTarget has nil name")

internal/endtoend/testdata/func_return/go/db.go

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/func_return/go/models.go

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/func_return/go/query.sql.go

Lines changed: 37 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
/* name: GetUsers :many */
2+
SELECT *
3+
FROM users_func()
4+
WHERE first_name != '';
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
CREATE TABLE users (
2+
id integer,
3+
first_name varchar(255) NOT NULL
4+
);
5+
6+
CREATE FUNCTION users_func() RETURNS SETOF users AS $func$ BEGIN QUERY
7+
SELECT *
8+
FROM users
9+
END $func$ LANGUAGE plpgsql;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"path": "go",
6+
"name": "querytest",
7+
"schema": "schema.sql",
8+
"queries": "query.sql",
9+
"engine": "postgresql"
10+
}
11+
]
12+
}

0 commit comments

Comments
 (0)