diff --git a/src/function/UserDefinedFunction.java b/src/function/UserDefinedFunction.java index a7d0958..019fddb 100644 --- a/src/function/UserDefinedFunction.java +++ b/src/function/UserDefinedFunction.java @@ -2,14 +2,14 @@ package function; import static function.builtin.EVAL.eval; import static java.text.MessageFormat.format; -import static recursion.tail.TailCalls.done; +import static recursion.TailCalls.done; import static sexpression.Nil.NIL; import java.util.ArrayList; import error.LispException; -import recursion.tail.TailCall; -import recursion.tail.TailCalls; +import recursion.TailCall; +import recursion.TailCalls; import sexpression.Cons; import sexpression.SExpression; import sexpression.Symbol; diff --git a/src/function/builtin/EVAL.java b/src/function/builtin/EVAL.java index 0869e0c..5530956 100644 --- a/src/function/builtin/EVAL.java +++ b/src/function/builtin/EVAL.java @@ -12,6 +12,7 @@ import error.LispException; import function.ArgumentValidator; import function.FunctionNames; import function.LispFunction; +import function.builtin.special.RECUR.RecurNotInTailPositionException; import sexpression.BackquoteExpression; import sexpression.Cons; import sexpression.LambdaExpression; @@ -22,10 +23,17 @@ import table.ExecutionContext; @FunctionNames({ "EVAL" }) public class EVAL extends LispFunction { + private static ExecutionContext executionContext = ExecutionContext.getInstance(); + public static SExpression eval(SExpression sExpression) { Cons argumentList = makeList(sExpression); - return lookupFunction("EVAL").call(argumentList); + try { + return lookupFunction("EVAL").call(argumentList); + } catch (LispException e) { + executionContext.restoreGlobalScope(); + throw e; + } } public static Cons evalRecurArgumentList(Cons argumentList) { @@ -62,27 +70,28 @@ public class EVAL extends LispFunction { } private ArgumentValidator argumentValidator; - private ExecutionContext executionContext; public EVAL(String name) { this.argumentValidator = new ArgumentValidator(name); this.argumentValidator.setExactNumberOfArguments(1); - this.executionContext = ExecutionContext.getInstance(); } @Override public SExpression call(Cons argumentList) { - if (executionContext.isRecur()) { - executionContext.clearContext(); - throw new RecurNotInTailPositionException(); - } - + verifyNotRecurring(); argumentValidator.validate(argumentList); SExpression argument = argumentList.getFirst(); return evaluateExpression(argument); } + private void verifyNotRecurring() { + if (executionContext.isRecur()) { + executionContext.clearRecur(); + throw new RecurNotInTailPositionException(); + } + } + private SExpression evaluateExpression(SExpression argument) { if (argument.isList()) return evaluateList(argument); @@ -123,10 +132,7 @@ public class EVAL extends LispFunction { if (function.isArgumentListEvaluated()) argumentList = evaluateArgumentList(argumentList); - if (executionContext.isRecur()) { - executionContext.clearContext(); - throw new RecurNotInTailPositionException(); - } + verifyNotRecurring(); SExpression result = function.call(argumentList); @@ -211,14 +217,4 @@ public class EVAL extends LispFunction { } } - public static class RecurNotInTailPositionException extends LispException { - - private static final long serialVersionUID = 1L; - - @Override - public String getMessage() { - return "recur not in tail position"; - } - } - } diff --git a/src/function/builtin/cons/LENGTH.java b/src/function/builtin/cons/LENGTH.java index 3496bc9..ca1270e 100644 --- a/src/function/builtin/cons/LENGTH.java +++ b/src/function/builtin/cons/LENGTH.java @@ -1,15 +1,15 @@ package function.builtin.cons; import static function.builtin.cons.LIST.makeList; -import static recursion.tail.TailCalls.done; -import static recursion.tail.TailCalls.tailCall; +import static recursion.TailCalls.done; +import static recursion.TailCalls.tailCall; import java.math.BigInteger; import function.ArgumentValidator; import function.FunctionNames; import function.LispFunction; -import recursion.tail.TailCall; +import recursion.TailCall; import sexpression.Cons; import sexpression.LispNumber; diff --git a/src/function/builtin/special/RECUR.java b/src/function/builtin/special/RECUR.java index 210f86d..ecb0f52 100644 --- a/src/function/builtin/special/RECUR.java +++ b/src/function/builtin/special/RECUR.java @@ -5,7 +5,6 @@ import static function.builtin.EVAL.evalRecurArgumentList; import error.LispException; import function.ArgumentValidator; import function.FunctionNames; -import function.LispFunction; import function.LispSpecialFunction; import sexpression.Cons; import sexpression.SExpression; @@ -31,23 +30,31 @@ public class RECUR extends LispSpecialFunction { throw new RecurOutsideOfFunctionException(); argumentValidator.validate(argumentList); - LispFunction currentFunction = executionContext.getCurrentFunction(); + Cons recurArguments = getRecurArguments(argumentList); + executionContext.setRecur(); + + return recurArguments; + } + + private Cons getRecurArguments(Cons argumentList) { Cons recurArguments = argumentList; - executionContext.setRecurInitializing(); - try { - if (currentFunction.isArgumentListEvaluated()) + executionContext.setRecurInitializing(); + + if (isRecurArgumentListEvaluated()) recurArguments = evalRecurArgumentList(argumentList); } finally { executionContext.clearRecurInitializing(); } - executionContext.setRecur(); - return recurArguments; } + private boolean isRecurArgumentListEvaluated() { + return executionContext.getCurrentFunction().isArgumentListEvaluated(); + } + public static class RecurOutsideOfFunctionException extends LispException { private static final long serialVersionUID = 1L; @@ -68,4 +75,14 @@ public class RECUR extends LispSpecialFunction { } } + public static class RecurNotInTailPositionException extends LispException { + + private static final long serialVersionUID = 1L; + + @Override + public String getMessage() { + return "recur not in tail position"; + } + } + } diff --git a/src/recursion/tail/TailCall.java b/src/recursion/TailCall.java similarity index 95% rename from src/recursion/tail/TailCall.java rename to src/recursion/TailCall.java index 2c2edf7..260425f 100644 --- a/src/recursion/tail/TailCall.java +++ b/src/recursion/TailCall.java @@ -1,4 +1,4 @@ -package recursion.tail; +package recursion; import java.util.stream.Stream; diff --git a/src/recursion/tail/TailCalls.java b/src/recursion/TailCalls.java similarity index 95% rename from src/recursion/tail/TailCalls.java rename to src/recursion/TailCalls.java index 482043c..f5137b6 100644 --- a/src/recursion/tail/TailCalls.java +++ b/src/recursion/TailCalls.java @@ -1,4 +1,4 @@ -package recursion.tail; +package recursion; public class TailCalls { diff --git a/src/sexpression/Cons.java b/src/sexpression/Cons.java index da6a243..52fbf50 100644 --- a/src/sexpression/Cons.java +++ b/src/sexpression/Cons.java @@ -1,9 +1,9 @@ package sexpression; -import static recursion.tail.TailCalls.done; -import static recursion.tail.TailCalls.tailCall; +import static recursion.TailCalls.done; +import static recursion.TailCalls.tailCall; -import recursion.tail.TailCall; +import recursion.TailCall; @DisplayName("list") public class Cons extends SExpression { diff --git a/src/table/ExecutionContext.java b/src/table/ExecutionContext.java index 5c529a9..c39961b 100644 --- a/src/table/ExecutionContext.java +++ b/src/table/ExecutionContext.java @@ -33,6 +33,11 @@ public class ExecutionContext { this.scope = scope; } + public void restoreGlobalScope() { + while (!scope.isGlobal()) + scope = scope.getParent(); + } + public void clearContext() { this.scope = new SymbolTable(); this.functionCalls = new Stack<>(); diff --git a/test/function/builtin/EVALTest.java b/test/function/builtin/EVALTest.java index 7280f7e..66bb304 100644 --- a/test/function/builtin/EVALTest.java +++ b/test/function/builtin/EVALTest.java @@ -2,6 +2,7 @@ package function.builtin; import static function.builtin.EVAL.lookupSymbol; import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; import static sexpression.Nil.NIL; import static testutil.TestUtilities.assertIsErrorWithMessage; import static testutil.TestUtilities.assertSExpressionsMatch; @@ -18,6 +19,7 @@ import function.builtin.EVAL.UndefinedFunctionException; import function.builtin.EVAL.UndefinedSymbolException; import function.builtin.EVAL.UnmatchedAtSignException; import function.builtin.EVAL.UnmatchedCommaException; +import function.builtin.special.RECUR.RecurNotInTailPositionException; import testutil.SymbolAndFunctionCleaner; public class EVALTest extends SymbolAndFunctionCleaner { @@ -178,4 +180,50 @@ public class EVALTest extends SymbolAndFunctionCleaner { assertSExpressionsMatch(parseString("((1 2 3))"), evaluateString(input)); } + @Test + public void scopeRestoredAfterFailure_Let() { + evaluateString("(setq n 100)"); + try { + evaluateString("(let ((n 200)) (begin 1 2 3 y))"); + fail("expected exception"); + } catch (UndefinedSymbolException e) {} + + assertSExpressionsMatch(parseString("100"), evaluateString("n")); + } + + @Test + public void scopeRestoredAfterFailure_Defun() { + evaluateString("(setq n 100)"); + try { + evaluateString("(defun test (n) (begin 1 2 3 y))"); + evaluateString("(test 200)"); + fail("expected exception"); + } catch (UndefinedSymbolException e) {} + + assertSExpressionsMatch(parseString("100"), evaluateString("n")); + } + + @Test + public void scopeRestoredAfterFailure_Lambda() { + evaluateString("(setq n 100)"); + try { + evaluateString("((lambda (n) (begin 1 2 3 y)) 200)"); + fail("expected exception"); + } catch (UndefinedSymbolException e) {} + + assertSExpressionsMatch(parseString("100"), evaluateString("n")); + } + + @Test + public void scopeRestoredAfterFailure_Recur() { + evaluateString("(setq n 100)"); + try { + evaluateString("(defun tail-recursive (n) (begin (recur) 2))"); + evaluateString("(tail-recursive 200)"); + fail("expected exception"); + } catch (RecurNotInTailPositionException e) {} + + assertSExpressionsMatch(parseString("100"), evaluateString("n")); + } + } diff --git a/test/function/builtin/special/RECURTest.java b/test/function/builtin/special/RECURTest.java index 16f88ea..0a910fe 100644 --- a/test/function/builtin/special/RECURTest.java +++ b/test/function/builtin/special/RECURTest.java @@ -9,9 +9,10 @@ import org.junit.Test; import function.ArgumentValidator.BadArgumentTypeException; import function.ArgumentValidator.TooManyArgumentsException; -import function.builtin.EVAL.RecurNotInTailPositionException; import function.builtin.special.RECUR.NestedRecurException; +import function.builtin.special.RECUR.RecurNotInTailPositionException; import function.builtin.special.RECUR.RecurOutsideOfFunctionException; +import sexpression.SExpression; import testutil.SymbolAndFunctionCleaner; public class RECURTest extends SymbolAndFunctionCleaner { @@ -53,13 +54,13 @@ public class RECURTest extends SymbolAndFunctionCleaner { } @Test(expected = RecurNotInTailPositionException.class) - public void recurInNonTailPosition_ThrowsException() { + public void recurInNonTailPositionInArgumentList() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))"); evaluateString("(tail-recursive 900)"); } @Test(expected = RecurNotInTailPositionException.class) - public void recurInNonTailPosition2_ThrowsException() { + public void recurInNonTailPositionInBegin() { evaluateString("(defun tail-recursive (n) (begin (recur) 2))"); evaluateString("(tail-recursive 900)"); } @@ -76,6 +77,18 @@ public class RECURTest extends SymbolAndFunctionCleaner { assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } + @Test + public void recurInTailPositionWithApply() { + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (apply 'recur (list (- n 1)))))"); + assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); + } + + @Test + public void recurInTailPositionWithFuncall() { + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (call 'recur (- n 1))))"); + assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); + } + @Test public void recurCallsCurrentFunction_WithApply() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); @@ -126,14 +139,24 @@ public class RECURTest extends SymbolAndFunctionCleaner { assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } - // recur non-tail in apply call - // recur with no args, alters global variable + @Test + public void recurWithNoArgs_AltersGlobalVariable() { + evaluateString("(defun tail-recursive () (if (= n 0) 'PASS (begin (setq n (- n 1)) (recur))))"); + evaluateString("(setq n 200)"); + assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive)")); + } - // recur with anonymous function + @Test + public void recurWithLambda() { + SExpression lambdaTailCall = evaluateString("((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020)"); + assertSExpressionsMatch(parseString("PASS"), lambdaTailCall); + } - // recur with nested anonymous function - - // test scope after failure in function, is it global again? + @Test + public void recurWithNestedLambda() { + evaluateString("(defun nested-tail () ((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020))"); + assertSExpressionsMatch(parseString("PASS"), evaluateString("(nested-tail)")); + } } diff --git a/test/table/ExecutionContextTest.java b/test/table/ExecutionContextTest.java index 3a374e0..a003e83 100644 --- a/test/table/ExecutionContextTest.java +++ b/test/table/ExecutionContextTest.java @@ -1,8 +1,10 @@ package table; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; import static sexpression.Nil.NIL; import static sexpression.Symbol.T; @@ -85,4 +87,18 @@ public class ExecutionContextTest { assertEquals(NIL, executionContext.lookupSymbolValue("shadowed")); } + @Test + public void restoreGlobalContext() { + SymbolTable global = executionContext.getScope(); + SymbolTable scope1 = new SymbolTable(global); + SymbolTable scope2 = new SymbolTable(scope1); + SymbolTable scope3 = new SymbolTable(scope2); + executionContext.setScope(scope3); + + assertThat(executionContext.getScope().isGlobal(), is(false)); + executionContext.restoreGlobalScope(); + assertThat(executionContext.getScope().isGlobal(), is(true)); + assertThat(executionContext.getScope(), is(global)); + } + }