Skip to content

Commit 96cf80f

Browse files
authored
Merge pull request #449 from mcarmonaa/feature/udf-uast-children
Feature/udf uast children
2 parents c642dce + a70d40b commit 96cf80f

File tree

4 files changed

+142
-26
lines changed

4 files changed

+142
-26
lines changed

docs/using-gitbase/functions.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ To make some common tasks easier for the user, there are some functions to inter
1212
|uast(blob, [lang, [xpath]])json_blob| returns an array of UAST nodes as blobs in semantic mode |
1313
|uast_mode(mode, blob, lang)json_blob| returns an array of UAST nodes as blobs specifying its language and mode (semantic, annotated or native) |
1414
|uast_xpath(json_blob, xpath)| performs an XPath query over the given UAST nodes |
15-
|uast_extract(json_blob, key)| extracts information identified by the given key from the uast nodes |
15+
|uast_extract(json_blob, key)| extracts information identified by the given key from the uast nodes |
16+
|uast_children(json_blob)| returns a flattened array of the children UAST nodes from each one of the UAST nodes in the given array |
1617

1718
## Standard functions
1819

internal/function/registry.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import "gopkg.in/src-d/go-mysql-server.v0/sql"
44

55
// Functions for gitbase queries.
66
var Functions = sql.Functions{
7-
"is_tag": sql.Function1(NewIsTag),
8-
"is_remote": sql.Function1(NewIsRemote),
9-
"language": sql.FunctionN(NewLanguage),
10-
"uast": sql.FunctionN(NewUAST),
11-
"uast_mode": sql.Function3(NewUASTMode),
12-
"uast_xpath": sql.Function2(NewUASTXPath),
13-
"uast_extract": sql.Function2(NewUASTExtract),
7+
"is_tag": sql.Function1(NewIsTag),
8+
"is_remote": sql.Function1(NewIsRemote),
9+
"language": sql.FunctionN(NewLanguage),
10+
"uast": sql.FunctionN(NewUAST),
11+
"uast_mode": sql.Function3(NewUASTMode),
12+
"uast_xpath": sql.Function2(NewUASTXPath),
13+
"uast_extract": sql.Function2(NewUASTExtract),
14+
"uast_children": sql.Function1(NewUASTChildren),
1415
}

internal/function/uast.go

Lines changed: 93 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,12 @@ func (f *UASTXPath) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err er
334334
return nil, err
335335
}
336336

337-
for _, n := range ns {
338-
data, err := n.Marshal()
339-
if err != nil {
340-
return nil, err
341-
}
342-
result = append(result, data)
337+
m, err := marshalNodes(ns)
338+
if err != nil {
339+
return nil, err
343340
}
341+
342+
result = append(result, m...)
344343
}
345344

346345
return result, nil
@@ -359,12 +358,29 @@ func nodesFromBlobArray(data interface{}) ([]*uast.Node, error) {
359358
if err := node.Unmarshal(n.([]byte)); err != nil {
360359
return nil, err
361360
}
361+
362362
nodes[i] = node
363363
}
364364

365365
return nodes, nil
366366
}
367367

368+
func marshalNodes(nodes []*uast.Node) ([]interface{}, error) {
369+
m := make([]interface{}, 0, len(nodes))
370+
for _, n := range nodes {
371+
if n != nil {
372+
data, err := n.Marshal()
373+
if err != nil {
374+
return nil, err
375+
}
376+
377+
m = append(m, data)
378+
}
379+
}
380+
381+
return m, nil
382+
}
383+
368384
func (f UASTXPath) String() string {
369385
return fmt.Sprintf("uast_xpath(%s, %s)", f.Left, f.Right)
370386
}
@@ -459,18 +475,7 @@ func getUAST(
459475
}
460476
}
461477

462-
var result = make([]interface{}, 0, len(nodes))
463-
for _, n := range nodes {
464-
if n != nil {
465-
node, err := n.Marshal()
466-
if err != nil {
467-
return nil, err
468-
}
469-
result = append(result, node)
470-
}
471-
}
472-
473-
return result, nil
478+
return marshalNodes(nodes)
474479
}
475480

476481
// UASTExtract extracts keys from an UAST.
@@ -586,3 +591,73 @@ func (u *UASTExtract) TransformUp(f sql.TransformExprFunc) (sql.Expression, erro
586591

587592
return f(NewUASTExtract(left, rigth))
588593
}
594+
595+
// UASTChildren returns children from UAST nodes.
596+
type UASTChildren struct {
597+
expression.UnaryExpression
598+
}
599+
600+
// NewUASTChildren creates a new UASTExtract UDF.
601+
func NewUASTChildren(uast sql.Expression) sql.Expression {
602+
return &UASTChildren{expression.UnaryExpression{Child: uast}}
603+
}
604+
605+
// String implements the fmt.Stringer interface.
606+
func (u *UASTChildren) String() string {
607+
return fmt.Sprintf("uast_children(%s)", u.Child)
608+
}
609+
610+
// Type implements the sql.Expression interface.
611+
func (u *UASTChildren) Type() sql.Type {
612+
return sql.Array(sql.Blob)
613+
}
614+
615+
// TransformUp implements the sql.Expression interface.
616+
func (u *UASTChildren) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
617+
child, err := u.Child.TransformUp(f)
618+
if err != nil {
619+
return nil, err
620+
}
621+
622+
return f(NewUASTChildren(child))
623+
}
624+
625+
// Eval implements the sql.Expression interface.
626+
func (u *UASTChildren) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err error) {
627+
defer func() {
628+
if r := recover(); r != nil {
629+
err = fmt.Errorf("uast: unknown error: %s", r)
630+
}
631+
}()
632+
633+
span, ctx := ctx.Span("gitbase.UASTChildren")
634+
defer span.Finish()
635+
636+
child, err := u.Child.Eval(ctx, row)
637+
if err != nil {
638+
return nil, err
639+
}
640+
641+
if child == nil {
642+
return nil, nil
643+
}
644+
645+
nodes, err := nodesFromBlobArray(child)
646+
if err != nil {
647+
return nil, err
648+
}
649+
650+
children := flattenChildren(nodes)
651+
return marshalNodes(children)
652+
}
653+
654+
func flattenChildren(nodes []*uast.Node) []*uast.Node {
655+
children := []*uast.Node{}
656+
for _, n := range nodes {
657+
if len(n.Children) > 0 {
658+
children = append(children, n.Children...)
659+
}
660+
}
661+
662+
return children
663+
}

internal/function/uast_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,45 @@ func TestUASTExtract(t *testing.T) {
281281
}
282282
}
283283

284+
func TestUASTChildren(t *testing.T) {
285+
var require = require.New(t)
286+
287+
ctx, cleanup := setup(t)
288+
defer cleanup()
289+
290+
uasts, _ := bblfshFixtures(t, ctx)
291+
modes := []string{"semantic", "annotated", "native"}
292+
293+
for _, mode := range modes {
294+
root, ok := uasts[mode]
295+
require.True(ok)
296+
297+
nodes, err := nodesFromBlobArray(root)
298+
require.NoError(err)
299+
require.Len(nodes, 1)
300+
expected := nodes[0].Children
301+
302+
row := sql.NewRow(root)
303+
304+
fn := NewUASTChildren(
305+
expression.NewGetField(0, sql.Array(sql.Blob), "", false),
306+
)
307+
308+
children, err := fn.Eval(ctx, row)
309+
require.NoError(err)
310+
311+
nodes, err = nodesFromBlobArray(children)
312+
require.NoError(err)
313+
require.Len(nodes, len(expected))
314+
for i, n := range nodes {
315+
require.Equal(
316+
n.InternalType,
317+
expected[i].InternalType,
318+
)
319+
}
320+
}
321+
}
322+
284323
func assertUASTBlobs(t *testing.T, a, b interface{}) {
285324
t.Helper()
286325
require := require.New(t)

0 commit comments

Comments
 (0)