Last active
March 20, 2019 10:30
-
-
Save mrnugget/a2bc0794b7d1a249de77a19ea0807389 to your computer and use it in GitHub Desktop.
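Fix for recursive closures that are defined inside other functions, which break in version 1.0 of "Writing A Compiler In Go". With this fix, the parser records the name a function literal is bound to with `let`. The compiler defines that name in the function's own scope and emits a new opcode, OpGetSelf, wherever the body refers to the currently executing function by that name. The VM handles OpGetSelf by pushing the current closure onto the stack.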
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/ast/ast.go b/ast/ast.go | |
index 8db3b39..f0420f4 100644 | |
--- a/ast/ast.go | |
+++ b/ast/ast.go | |
@@ -2,6 +2,7 @@ package ast | |
import ( | |
"bytes" | |
+ "fmt" | |
"strings" | |
"github.com/mrnugget/monkey/token" | |
@@ -223,6 +224,7 @@ type FunctionLiteral struct { | |
Token token.Token // The 'fn' token | |
Parameters []*Identifier | |
Body *BlockStatement | |
+ Name string // Name bound by `let name = fn...`; "" for anonymous literals | |
} | |
func (fl *FunctionLiteral) expressionNode() {} | |
@@ -236,6 +238,9 @@ func (fl *FunctionLiteral) String() string { | |
} | |
out.WriteString(fl.TokenLiteral()) | |
+ if fl.Name != "" { // named literals print as fn<name>(params) | |
+ fmt.Fprintf(&out, "<%s>", fl.Name) | |
+ } | |
out.WriteString("(") | |
out.WriteString(strings.Join(params, ", ")) | |
out.WriteString(") ") | |
diff --git a/code/code.go b/code/code.go | |
index 8931973..443482b 100644 | |
--- a/code/code.go | |
+++ b/code/code.go | |
@@ -36,6 +36,7 @@ const ( | |
OpSetLocal // [idx] -- Store local variable | |
OpGetFree // [idx] -- Load a free variable from the current closure's Free store | |
- OpGetBuiltin // [idx] -- Load a free variable from the current closure's Free store | |
+ OpGetBuiltin // [idx] -- Load the builtin at [idx] onto the stack | |
+ OpGetSelf // [] -- Load the current closure onto the stack | |
OpCall // [n] -- Call function that sits on top of stack with n arguments | |
OpReturnValue // [] -- Returns from the function -- Value sits on stack | |
@@ -90,6 +91,7 @@ var definitions = map[Opcode]*Definition{ | |
OpSetLocal: {"OpSetLocal", []int{1}}, | |
OpGetFree: {"OpGetFree", []int{1}}, | |
OpGetBuiltin: {"OpGetBuiltin", []int{1}}, | |
+ OpGetSelf: {"OpGetSelf", []int{}}, | |
- // [idx, lenFree, numLocals] -- Turn the CONSTANT at [idx] into a Closure and put on stack | |
+ // [idx, lenFree] -- Turn the CONSTANT at [idx] into a Closure with lenFree free variables and put on stack | |
OpClosure: {"OpClosure", []int{2, 1}}, | |
diff --git a/compiler/compiler.go b/compiler/compiler.go | |
index 478147d..0021f31 100644 | |
--- a/compiler/compiler.go | |
+++ b/compiler/compiler.go | |
@@ -254,6 +254,10 @@ func (c *Compiler) Compile(node ast.Node) error { | |
case *ast.FunctionLiteral: | |
c.enterScope() | |
+ // A function bound by `let` can refer to itself by name: define that | |
+ // name in the new scope as SelfScope so references compile to OpGetSelf. | |
+ // It is defined before the parameters, so a parameter (or local) with | |
+ // the same name shadows it. | |
+ if node.Name != "" { | |
+ c.symbolTable.DefineSelf(node.Name) | |
+ } | |
+ | |
for _, p := range node.Parameters { | |
c.symbolTable.Define(p.Value) | |
} | |
@@ -431,6 +435,8 @@ func (c *Compiler) loadSymbol(s Symbol) { | |
c.emit(code.OpGetBuiltin, s.Index) | |
case FreeScope: | |
c.emit(code.OpGetFree, s.Index) | |
+ case SelfScope: | |
+ // SelfScope symbols need no index: OpGetSelf takes no operands. | |
+ c.emit(code.OpGetSelf) | |
} | |
} | |
diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go | |
index 45b168a..c3c4edb 100644 | |
--- a/compiler/compiler_test.go | |
+++ b/compiler/compiler_test.go | |
@@ -395,6 +395,106 @@ func TestFunctions(t *testing.T) { | |
runCompilerTests(t, tests) | |
} | |
+// TestRecursiveFunctions checks that a function referring to itself by its | |
+// let-bound name compiles that reference to OpGetSelf, both for a top-level | |
+// function and for one defined inside another function. | |
+func TestRecursiveFunctions(t *testing.T) { | |
+ tests := []compilerTestCase{ | |
+ { | |
+ input: ` | |
+ let inner = fn(x) { | |
+ if (x == 0) { | |
+ return 0; | |
+ } else { | |
+ inner(x - 1); | |
+ } | |
+ }; | |
+ inner(1); | |
+ `, | |
+ expectedConstants: []interface{}{ | |
+ 0, | |
+ 0, | |
+ 1, | |
+ []code.Instructions{ // inner: the recursive call loads itself via OpGetSelf | |
+ code.Make(code.OpGetLocal, 0), | |
+ code.Make(code.OpConstant, 0), | |
+ code.Make(code.OpEqual), | |
+ code.Make(code.OpJumpNotTruthy, 16), | |
+ code.Make(code.OpConstant, 1), | |
+ code.Make(code.OpReturnValue), | |
+ code.Make(code.OpJump, 25), | |
+ code.Make(code.OpGetSelf), | |
+ code.Make(code.OpGetLocal, 0), | |
+ code.Make(code.OpConstant, 2), | |
+ code.Make(code.OpSub), | |
+ code.Make(code.OpCall, 1), | |
+ code.Make(code.OpReturnValue), | |
+ }, | |
+ 1, | |
+ }, | |
+ expectedInstructions: []code.Instructions{ | |
+ code.Make(code.OpClosure, 3, 0), | |
+ code.Make(code.OpSetGlobal, 0), | |
+ code.Make(code.OpGetGlobal, 0), | |
+ code.Make(code.OpConstant, 4), | |
+ code.Make(code.OpCall, 1), | |
+ code.Make(code.OpPop), | |
+ }, | |
+ }, | |
+ { | |
+ input: ` | |
+ let wrapper = fn() { | |
+ let inner = fn(x) { | |
+ if (x == 0) { | |
+ return 0; | |
+ } else { | |
+ inner(x - 1); | |
+ } | |
+ }; | |
+ inner(1); | |
+ }; | |
+ wrapper(); | |
+ `, | |
+ expectedConstants: []interface{}{ | |
+ 0, | |
+ 0, | |
+ 1, | |
+ []code.Instructions{ // inner: compiled exactly as in the top-level case | |
+ code.Make(code.OpGetLocal, 0), | |
+ code.Make(code.OpConstant, 0), | |
+ code.Make(code.OpEqual), | |
+ code.Make(code.OpJumpNotTruthy, 16), | |
+ code.Make(code.OpConstant, 1), | |
+ code.Make(code.OpReturnValue), | |
+ code.Make(code.OpJump, 25), | |
+ code.Make(code.OpGetSelf), | |
+ code.Make(code.OpGetLocal, 0), | |
+ code.Make(code.OpConstant, 2), | |
+ code.Make(code.OpSub), | |
+ code.Make(code.OpCall, 1), | |
+ code.Make(code.OpReturnValue), | |
+ }, | |
+ 1, | |
+ []code.Instructions{ // wrapper: stores inner in local 0, then calls it | |

+ code.Make(code.OpClosure, 3, 0), | |
+ code.Make(code.OpSetLocal, 0), | |
+ code.Make(code.OpGetLocal, 0), | |
+ code.Make(code.OpConstant, 4), | |
+ code.Make(code.OpCall, 1), | |
+ code.Make(code.OpReturnValue), | |
+ }, | |
+ }, | |
+ expectedInstructions: []code.Instructions{ | |
+ code.Make(code.OpClosure, 5, 0), | |
+ code.Make(code.OpSetGlobal, 0), | |
+ code.Make(code.OpGetGlobal, 0), | |
+ code.Make(code.OpCall, 0), | |
+ code.Make(code.OpPop), | |
+ }, | |
+ }, | |
+ } | |
+ | |
+ runCompilerTests(t, tests) | |
+} | |
+ | |
type compilerTestCase struct { | |
input string | |
expectedConstants []interface{} | |
diff --git a/compiler/symbol_table.go b/compiler/symbol_table.go | |
index 70d0280..0d6b4eb 100644 | |
--- a/compiler/symbol_table.go | |
+++ b/compiler/symbol_table.go | |
@@ -7,6 +7,7 @@ const ( | |
GlobalScope SymbolScope = "GLOBAL" | |
BuiltinScope SymbolScope = "BUILTIN" | |
FreeScope SymbolScope = "FREE" | |
+ SelfScope SymbolScope = "SELF" // the name of the function being compiled | |
) | |
type Symbol struct { | |
@@ -73,6 +74,12 @@ func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol { | |
return symbol | |
} | |
+// DefineSelf defines name in s with SelfScope and returns the symbol. The | |
+// compiler calls it right after entering the scope of a let-bound function | |
+// so that references to the function's own name in its body resolve here. | |
+// It overwrites any existing entry for name, and a later Define of the same | |
+// name replaces it. Index is always 0; loading a SelfScope symbol ignores it. | |
+func (s *SymbolTable) DefineSelf(name string) Symbol { | |
+ symbol := Symbol{Name: name, Index: 0, Scope: SelfScope} | |
+ s.store[name] = symbol | |
+ return symbol | |
+} | |
+ | |
func (s *SymbolTable) defineFree(original Symbol) Symbol { | |
s.FreeSymbols = append(s.FreeSymbols, original) | |
diff --git a/compiler/symbol_table_test.go b/compiler/symbol_table_test.go | |
index 8ceac0e..3b0315e 100644 | |
--- a/compiler/symbol_table_test.go | |
+++ b/compiler/symbol_table_test.go | |
@@ -300,3 +300,38 @@ func TestResolveUnresolvableFree(t *testing.T) { | |
} | |
} | |
} | |
+ | |
+// TestDefineAndResolveSelf checks that a name defined with DefineSelf | |
+// resolves to a SelfScope symbol with index 0. | |
+func TestDefineAndResolveSelf(t *testing.T) { | |
+ expected := Symbol{Name: "a", Scope: SelfScope, Index: 0} | |
+ | |
+ global := NewSymbolTable() | |
+ global.DefineSelf("a") | |
+ | |
+ result, ok := global.Resolve(expected.Name) | |
+ if !ok { | |
+ t.Fatalf("self name %s not resolvable", expected.Name) | |
+ } | |
+ | |
+ if result != expected { | |
+ t.Errorf("expected %s to resolve to %+v, got=%+v", | |
+ expected.Name, expected, result) | |
+ } | |
+} | |
+ | |
+// TestShadowingSelf checks that a later Define of the same name shadows | |
+// the SelfScope binding created by DefineSelf. | |
+func TestShadowingSelf(t *testing.T) { | |
+ expected := Symbol{Name: "a", Scope: GlobalScope, Index: 0} | |
+ | |
+ global := NewSymbolTable() | |
+ global.DefineSelf(expected.Name) | |
+ global.Define(expected.Name) | |
+ | |
+ result, ok := global.Resolve(expected.Name) | |
+ if !ok { | |
+ t.Fatalf("self name %s not resolvable", expected.Name) | |
+ } | |
+ | |
+ if result != expected { | |
+ t.Errorf("expected %s to resolve to %+v, got=%+v", | |
+ expected.Name, expected, result) | |
+ } | |
+} | |
diff --git a/parser/parser.go b/parser/parser.go | |
index 94635b9..dbd581d 100644 | |
--- a/parser/parser.go | |
+++ b/parser/parser.go | |
@@ -172,6 +172,10 @@ func (p *Parser) parseLetStatement() *ast.LetStatement { | |
stmt.Value = p.parseExpression(LOWEST) | |
+ // Remember the name a function literal is bound to; the compiler uses it | |
+ // to compile references to that name inside the body as OpGetSelf. | |
+ if fl, ok := stmt.Value.(*ast.FunctionLiteral); ok { | |
+ fl.Name = stmt.Name.Value | |
+ } | |
+ | |
if p.peekTokenIs(token.SEMICOLON) { | |
p.nextToken() | |
} | |
diff --git a/parser/parser_test.go b/parser/parser_test.go | |
index 9fccfce..98229ac 100644 | |
--- a/parser/parser_test.go | |
+++ b/parser/parser_test.go | |
@@ -587,6 +587,37 @@ func TestFunctionParameterParsing(t *testing.T) { | |
} | |
} | |
+// TestFunctionDefinitionParsing checks that a function literal on the | |
+// right-hand side of a let statement gets the bound name in its Name field. | |
+func TestFunctionDefinitionParsing(t *testing.T) { | |
+ input := `let myFunction = fn() { };` | |
+ | |
+ l := lexer.New(input) | |
+ p := New(l) | |
+ program := p.ParseProgram() | |
+ checkParserErrors(t, p) | |
+ | |
+ if len(program.Statements) != 1 { | |
+ t.Fatalf("program.Body does not contain %d statements. got=%d\n", | |
+ 1, len(program.Statements)) | |
+ } | |
+ | |
+ stmt, ok := program.Statements[0].(*ast.LetStatement) | |
+ if !ok { | |
+ t.Fatalf("program.Statements[0] is not ast.LetStatement. got=%T", | |
+ program.Statements[0]) | |
+ } | |
+ | |
+ function, ok := stmt.Value.(*ast.FunctionLiteral) | |
+ if !ok { | |
+ t.Fatalf("stmt.Value is not ast.FunctionLiteral. got=%T", | |
+ stmt.Value) | |
+ } | |
+ | |
+ if function.Name != "myFunction" { | |
+ t.Fatalf("function literal name wrong. want 'myFunction', got=%q\n", | |
+ function.Name) | |
+ } | |
+} | |
+ | |
func TestCallExpressionParsing(t *testing.T) { | |
input := "add(1, 2 * 3, 4 + 5);" | |
diff --git a/vm/vm.go b/vm/vm.go | |
index 31b1699..8490912 100644 | |
--- a/vm/vm.go | |
+++ b/vm/vm.go | |
@@ -278,6 +278,13 @@ func (vm *VM) Run() error { | |
if err != nil { | |
return err | |
} | |
+ | |
+ case code.OpGetSelf: | |
+ // A function referred to its own name: push the closure of the | |
+ // frame that is currently executing. | |
+ currentClosure := vm.currentFrame().cl | |
+ err := vm.push(currentClosure) | |
+ if err != nil { | |
+ return err | |
+ } | |
} | |
if vm.trace { | |
diff --git a/vm/vm_test.go b/vm/vm_test.go | |
index aace5d6..10bcd3e 100644 | |
--- a/vm/vm_test.go | |
+++ b/vm/vm_test.go | |
@@ -564,6 +564,83 @@ func TestRecursiveFibonacci(t *testing.T) { | |
runVmTests(t, tests) | |
} | |
+// TestRecursiveFunctions runs let-bound functions that call themselves, | |
+// including a recursive closure defined inside another function, and checks | |
+// that locals with the same name still shadow the function's own name. | |
+func TestRecursiveFunctions(t *testing.T) { | |
+ tests := []vmTestCase{ | |
+ { | |
+ // Top-level recursive function (this already worked) | |
+ input: ` | |
+ let inner = fn(x) { | |
+ if (x == 0) { | |
+ return 0; | |
+ } else { | |
+ inner(x - 1); | |
+ } | |
+ }; | |
+ inner(1); | |
+ `, | |
+ expected: 0, | |
+ }, | |
+ { | |
+ // Top-level recursive function called from another function (this also worked) | |
+ input: ` | |
+ let inner = fn(x) { | |
+ if (x == 0) { | |
+ return 0; | |
+ } else { | |
+ inner(x - 1); | |
+ } | |
+ }; | |
+ let wrapper = fn() { | |
+ inner(1); | |
+ }; | |
+ wrapper(); | |
+ `, | |
+ expected: 0, | |
+ }, | |
+ { | |
+ // Recursive closure defined inside another function: this did _NOT_ work | |
+ input: ` | |
+ let wrapper = fn() { | |
+ let inner = fn(x) { | |
+ if (x == 0) { | |
+ return 0; | |
+ } else { | |
+ inner(x - 1); | |
+ } | |
+ }; | |
+ inner(1); | |
+ }; | |
+ wrapper(); | |
+ `, | |
+ expected: 0, | |
+ }, | |
+ { | |
+ // A local in a top-level function shadows the function's own name | |
+ input: ` | |
+ let one = fn() { let one = 1; return one }; | |
+ one(); | |
+ `, | |
+ expected: 1, | |
+ }, | |
+ { | |
+ // A local in a nested function shadows that function's own name | |
+ input: ` | |
+ let wrapper = fn() { | |
+ let inner = fn(x) { | |
+ let inner = 2; | |
+ x + inner | |
+ }; | |
+ inner(1); | |
+ }; | |
+ wrapper(); | |
+ `, | |
+ expected: 3, | |
+ }, | |
+ } | |
+ | |
+ runVmTests(t, tests) | |
+} | |
+ | |
type vmTestCase struct { | |
input string | |
expected interface{} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment