Skip to content

Commit 079c735

Browse files
committed
added functions
1 parent 606fb47 commit 079c735

File tree

7 files changed

+255
-29
lines changed

7 files changed

+255
-29
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ https://craftinginterpreters.com/
99
- [x] Parsing Expressions
1010
- [x] Evaluating Expressions
1111
- [x] Statements and State
12-
- [x] Control Flow
12+
- [x] Control Flow
13+
- [x] Functions

pkg/ast/expr.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,14 @@ type LogicalExpr struct {
6767

6868
func (l *LogicalExpr) Accept(v ExprVisitor) interface{} {
6969
return v.VisitLogicalExpr(l)
70+
}
71+
72+
type CallExpr struct {
73+
Callee Expr
74+
Paren scanner.Token
75+
Arguments []Expr
76+
}
77+
78+
func (c *CallExpr) Accept(v ExprVisitor) interface{} {
79+
return v.VisitCallExpr(c)
7080
}

pkg/ast/parser.go

Lines changed: 122 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ import (
1010
type parser struct {
1111
tokens []scanner.Token
1212
current int
13+
fdepth int
1314
err error
1415
}
1516

1617
func NewParser(tokens []scanner.Token) *parser {
17-
return &parser{tokens, 0, nil}
18+
return &parser{tokens, 0, 0, nil}
1819
}
1920

2021
func (p *parser) Parse() ([]Stmt, error) {
@@ -34,10 +35,14 @@ func (p *parser) declaration() Stmt {
3435
return p.varDeclaration()
3536
}
3637

38+
if p.match(scanner.FUN) {
39+
return p.funDeclaration("function")
40+
}
41+
3742
return p.statement()
3843
}
3944

40-
func (p *parser) varDeclaration() Stmt {
45+
func (p *parser) varDeclaration() *VarStmt {
4146
if !p.match(scanner.IDENTIFIER) {
4247
panic(fault.NewFault(p.tokens[p.current].Line, "expected variable name"))
4348
}
@@ -55,6 +60,51 @@ func (p *parser) varDeclaration() Stmt {
5560
return &VarStmt{&name, initializer}
5661
}
5762

63+
func (p *parser) funDeclaration(kind string) *FunStmt {
64+
if !p.match(scanner.IDENTIFIER) {
65+
message := fmt.Sprintf("expected %s name", kind)
66+
panic(fault.NewFault(p.tokens[p.current].Line, message))
67+
}
68+
name := p.tokens[p.current-1]
69+
70+
if !p.match(scanner.LEFT_PAREN) {
71+
message := fmt.Sprintf("expected '(' after %s name", kind)
72+
panic(fault.NewFault(p.tokens[p.current].Line, message))
73+
}
74+
75+
params := []*scanner.Token{}
76+
if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN && p.tokens[p.current].TokenType != scanner.EOF {
77+
if !p.match(scanner.IDENTIFIER) {
78+
message := fmt.Sprintf("expected parameter name at %s", p.tokens[p.current].Lexeme)
79+
panic(fault.NewFault(p.tokens[p.current].Line, message))
80+
}
81+
params = append(params, &p.tokens[p.current-1])
82+
for p.match(scanner.COMMA) {
83+
if !p.match(scanner.IDENTIFIER) {
84+
message := fmt.Sprintf("expected parameter name at %s", p.tokens[p.current].Lexeme)
85+
panic(fault.NewFault(p.tokens[p.current].Line, message))
86+
}
87+
params = append(params, &p.tokens[p.current-1])
88+
if len(params) > 255 {
89+
panic(fault.NewFault(p.tokens[p.current].Line, "cannot have more than 255 parameters"))
90+
}
91+
}
92+
}
93+
94+
if !p.match(scanner.RIGHT_PAREN) {
95+
panic(fault.NewFault(p.tokens[p.current].Line, "expected ')' after parameter list"))
96+
}
97+
98+
if !p.match(scanner.LEFT_BRACE) {
99+
message := fmt.Sprintf("expected '{' before %s body", kind)
100+
panic(fault.NewFault(p.tokens[p.current].Line, message))
101+
}
102+
103+
p.fdepth++
104+
defer func() { p.fdepth-- }()
105+
return &FunStmt{&name, params, p.blockStatement()}
106+
}
107+
58108
func (p *parser) statement() Stmt {
59109
if p.match(scanner.PRINT) {
60110
return p.printStatement()
@@ -73,13 +123,17 @@ func (p *parser) statement() Stmt {
73123
}
74124

75125
if p.match(scanner.LEFT_BRACE) {
76-
return &BlockStmt{p.block()}
126+
return p.blockStatement()
127+
}
128+
129+
if p.match(scanner.RETURN) {
130+
return p.returnStatement()
77131
}
78132

79133
return p.exprStatement()
80134
}
81135

82-
func (p *parser) printStatement() Stmt {
136+
func (p *parser) printStatement() *PrintStmt {
83137
expr := p.expression()
84138

85139
if !p.match(scanner.SEMICOLON) {
@@ -89,7 +143,7 @@ func (p *parser) printStatement() Stmt {
89143
return &PrintStmt{expr}
90144
}
91145

92-
func (p *parser) ifStatement() Stmt {
146+
func (p *parser) ifStatement() *IfStmt {
93147
if !p.match(scanner.LEFT_PAREN) {
94148
panic(fault.NewFault(p.tokens[p.current].Line, "expected '(' after if"))
95149
}
@@ -123,15 +177,15 @@ func (p *parser) forStatement() Stmt {
123177
}
124178

125179
var condition Expr
126-
if p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.SEMICOLON {
180+
if p.tokens[p.current].TokenType != scanner.SEMICOLON && p.tokens[p.current].TokenType != scanner.EOF {
127181
condition = p.expression()
128182
}
129183
if !p.match(scanner.SEMICOLON) {
130184
panic(fault.NewFault(p.tokens[p.current].Line, "expected ';' after conditional expression"))
131185
}
132186

133187
var increment Expr
134-
if p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.RIGHT_PAREN {
188+
if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN && p.tokens[p.current].TokenType != scanner.EOF {
135189
increment = p.expression()
136190
}
137191
if !p.match(scanner.RIGHT_PAREN) {
@@ -156,7 +210,7 @@ func (p *parser) forStatement() Stmt {
156210
return body
157211
}
158212

159-
func (p *parser) whileStatement() Stmt {
213+
func (p *parser) whileStatement() *WhileStmt {
160214
if !p.match(scanner.LEFT_PAREN) {
161215
panic(fault.NewFault(p.tokens[p.current].Line, "expected '(' after while"))
162216
}
@@ -169,21 +223,21 @@ func (p *parser) whileStatement() Stmt {
169223
return &WhileStmt{condition, p.statement()}
170224
}
171225

172-
func (p *parser) block() []Stmt {
226+
func (p *parser) blockStatement() *BlockStmt {
173227
stmts := []Stmt{}
174228

175-
for p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.RIGHT_BRACE {
229+
for p.tokens[p.current].TokenType != scanner.RIGHT_BRACE && p.tokens[p.current].TokenType != scanner.EOF {
176230
stmts = append(stmts, p.declaration())
177231
}
178232

179233
if !p.match(scanner.RIGHT_BRACE) {
180234
panic(fault.NewFault(p.tokens[p.current].Line, "expected '}' after block"))
181235
}
182236

183-
return stmts
237+
return &BlockStmt{stmts}
184238
}
185239

186-
func (p *parser) exprStatement() Stmt {
240+
func (p *parser) exprStatement() *ExprStmt {
187241
expr := p.expression()
188242

189243
if !p.match(scanner.SEMICOLON) {
@@ -193,6 +247,24 @@ func (p *parser) exprStatement() Stmt {
193247
return &ExprStmt{expr}
194248
}
195249

250+
func (p *parser) returnStatement() *ReturnStmt {
251+
keyword := p.tokens[p.current-1]
252+
if p.fdepth == 0 {
253+
panic(fault.NewFault(keyword.Line, "cannot return outside of a function"))
254+
}
255+
256+
var value Expr
257+
if p.tokens[p.current].TokenType != scanner.SEMICOLON && p.tokens[p.current].TokenType != scanner.EOF {
258+
value = p.expression()
259+
}
260+
261+
if !p.match(scanner.SEMICOLON) {
262+
panic(fault.NewFault(p.tokens[p.current].Line, "expected ';' after return statement"))
263+
}
264+
265+
return &ReturnStmt{&keyword, value}
266+
}
267+
196268
func (p *parser) expression() Expr {
197269
return p.assignment()
198270
}
@@ -293,7 +365,37 @@ func (p *parser) unary() Expr {
293365
return &UnaryExpr{&operator, right}
294366
}
295367

296-
return p.primary()
368+
return p.call()
369+
}
370+
371+
func (p *parser) call() Expr {
372+
expr := p.primary()
373+
374+
for p.match(scanner.LEFT_PAREN) {
375+
args, paren := p.arguments()
376+
expr = &CallExpr{expr, paren, args}
377+
}
378+
379+
return expr
380+
}
381+
382+
func (p *parser) arguments() ([]Expr, scanner.Token) {
383+
args := []Expr{}
384+
if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN && p.tokens[p.current].TokenType != scanner.EOF {
385+
args = append(args, p.expression())
386+
for p.match(scanner.COMMA) {
387+
args = append(args, p.expression())
388+
if len(args) > 255 {
389+
panic(fault.NewFault(p.tokens[p.current].Line, "cannot have more than 255 arguments"))
390+
}
391+
}
392+
}
393+
394+
if !p.match(scanner.RIGHT_PAREN) {
395+
panic(fault.NewFault(p.tokens[p.current].Line, "expected ')' after argument list"))
396+
}
397+
398+
return args, p.tokens[p.current-1]
297399
}
298400

299401
func (p *parser) primary() Expr {
@@ -321,26 +423,25 @@ func (p *parser) primary() Expr {
321423

322424
if p.match(scanner.LEFT_PAREN) {
323425
e := p.expression()
324-
if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN {
325-
message := fmt.Sprintf("expected ')' after \"%s\"", p.tokens[p.current].Lexeme)
426+
if !p.match(scanner.RIGHT_PAREN) {
427+
message := fmt.Sprintf("expected ')' after '%s'", p.tokens[p.current-1].Lexeme)
326428
panic(fault.NewFault(p.tokens[p.current].Line, message))
327429
}
328-
p.current++
329430
return &GroupingExpr{e}
330431
}
331432

332-
message := fmt.Sprintf("expected expression at \"%s\"", p.tokens[p.current].Lexeme)
433+
message := fmt.Sprintf("expected expression at '%s'", p.tokens[p.current].Lexeme)
333434
panic(fault.NewFault(p.tokens[p.current].Line, message))
334435
}
335436

336437
func (p *parser) match(types ...int) bool {
337-
if p.tokens[p.current].TokenType == scanner.EOF {
438+
currentType := p.tokens[p.current].TokenType
439+
if currentType == scanner.EOF {
338440
return false
339441
}
340442

341-
actualType := p.tokens[p.current].TokenType
342443
for _, tokenType := range types {
343-
if actualType == tokenType {
444+
if currentType == tokenType {
344445
p.current++
345446
return true
346447
}
@@ -386,4 +487,4 @@ func (p *parser) synchronize() {
386487
p.current++
387488
}
388489
}
389-
}
490+
}

pkg/ast/stmt.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,23 @@ type WhileStmt struct {
5656

5757
func (w *WhileStmt) Accept(v StmtVisitor) interface{} {
5858
return v.VisitWhileStmt(w)
59+
}
60+
61+
type FunStmt struct {
62+
Name *scanner.Token
63+
Params []*scanner.Token
64+
Body *BlockStmt
65+
}
66+
67+
func (f *FunStmt) Accept(v StmtVisitor) interface{} {
68+
return v.VisitFunStmt(f)
69+
}
70+
71+
type ReturnStmt struct {
72+
Keyword *scanner.Token
73+
Value Expr
74+
}
75+
76+
func (r *ReturnStmt) Accept(v StmtVisitor) interface{} {
77+
return v.VisitReturnStmt(r)
5978
}

pkg/ast/visitor.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ type ExprVisitor interface {
88
VisitVariableExpr(v *VariableExpr) interface{}
99
VisitAssignExpr(a *AssignExpr) interface{}
1010
VisitLogicalExpr(l *LogicalExpr) interface{}
11+
VisitCallExpr(c *CallExpr) interface{}
1112
}
1213

1314
type StmtVisitor interface {
@@ -17,4 +18,6 @@ type StmtVisitor interface {
1718
VisitBlockStmt(b *BlockStmt) interface{}
1819
VisitIfStmt(i *IfStmt) interface{}
1920
VisitWhileStmt(w *WhileStmt) interface{}
21+
VisitFunStmt(f *FunStmt) interface{}
22+
VisitReturnStmt(r *ReturnStmt) interface{}
2023
}

pkg/interpreter/callable.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package interpreter
2+
3+
import (
4+
"fmt"
5+
"time"
6+
7+
"golox/pkg/ast"
8+
)
9+
10+
type callable interface {
11+
arity() int
12+
call(i *interpreter, args []interface{}) interface{}
13+
}
14+
15+
type clock struct{}
16+
17+
func (c *clock) arity() int { return 0 }
18+
19+
func (c *clock) call(i *interpreter, args []interface{}) interface{} {
20+
return float64(time.Now().UnixMilli() / 1000)
21+
}
22+
23+
func (c clock) String() string {
24+
return "<native function clock>"
25+
}
26+
27+
type function struct {
28+
declaration *ast.FunStmt
29+
}
30+
31+
func (f *function) arity() int { return len(f.declaration.Params) }
32+
33+
func (f *function) call(i *interpreter, args []interface{}) (value interface{}) {
34+
env := &environment{i.global, make(map[string]interface{})}
35+
for i := 0; i < f.arity(); i++ {
36+
env.define(f.declaration.Params[i].Lexeme, args[i])
37+
}
38+
39+
prev := i.env
40+
defer func() {
41+
i.env = prev
42+
r := recover()
43+
if err, ok := r.(error); ok {
44+
panic(err)
45+
} else {
46+
value = r
47+
}
48+
}()
49+
50+
i.env = env
51+
for _, stmt := range f.declaration.Body.Statements {
52+
stmt.Accept(i)
53+
}
54+
55+
return value
56+
}
57+
58+
func (f function) String() string {
59+
return fmt.Sprintf("<function %s >", f.declaration.Name.Lexeme)
60+
}

0 commit comments

Comments
 (0)