From e2272fa97631c57a1f07495f66563f952f55e8df Mon Sep 17 00:00:00 2001 From: Mike Cifelli Date: Fri, 17 Nov 2017 19:14:59 -0500 Subject: [PATCH] Improve recur behavior --- src/function/UserDefinedFunction.java | 16 +---- src/function/builtin/EVAL.java | 16 +++-- src/function/builtin/special/PROGN.java | 6 +- src/function/builtin/special/RECUR.java | 16 +++-- src/table/ExecutionContext.java | 14 +++++ test/function/builtin/special/RECURTest.java | 64 ++++++++++++++++++-- 6 files changed, 99 insertions(+), 33 deletions(-) diff --git a/src/function/UserDefinedFunction.java b/src/function/UserDefinedFunction.java index 22b5367..a7d0958 100644 --- a/src/function/UserDefinedFunction.java +++ b/src/function/UserDefinedFunction.java @@ -138,21 +138,7 @@ public class UserDefinedFunction extends LispFunction { private SExpression evaluateBody() { SExpression lastEvaluation = NIL; - // recur sets indicator in execution context - // eval can't be called when recur indicator is set - or error occurs (and recur indicator - // is cleared) - // recur checks to see if in function call, if not an error occurs (stack of objects - // (function call types) in execution context) - // if not, set recur indicator somewhere, (can't be in symbol table because LET creates - // scopes as well) - // - // on exception - clear function call counter and recur indicator in main loop - // recur evaluates its arguments (or not, according to function type) in its scope, then - // returns the list - // - // eval adds function call type to execution context stack of function calls when user - // defined function encountered - for (Cons expression = body; expression.isCons() /* and not recur indicator */; expression = (Cons) expression.getRest()) + for (Cons expression = body; expression.isCons(); expression = (Cons) expression.getRest()) lastEvaluation = eval(expression.getFirst()); return lastEvaluation; diff --git a/src/function/builtin/EVAL.java b/src/function/builtin/EVAL.java index 3e8a4a1..0869e0c 100644 --- a/src/function/builtin/EVAL.java +++ b/src/function/builtin/EVAL.java @@ -72,6 +72,11 @@ public class EVAL extends LispFunction { @Override public SExpression call(Cons argumentList) { + if (executionContext.isRecur()) { + executionContext.clearContext(); + throw new RecurNotInTailPositionException(); + } + argumentValidator.validate(argumentList); SExpression argument = argumentList.getFirst(); @@ -115,15 +120,14 @@ public class EVAL extends LispFunction { } private SExpression callFunction(LispFunction function, Cons argumentList) { - - // if (executionContext.isRecur()) { - // executionContext.popFunctionCall(); // FIXME - clear this in repl? - // throw new RecurNotInTailPositionException(); - // } - if (function.isArgumentListEvaluated()) argumentList = evaluateArgumentList(argumentList); + if (executionContext.isRecur()) { + executionContext.clearContext(); + throw new RecurNotInTailPositionException(); + } + SExpression result = function.call(argumentList); if (function.isMacro()) diff --git a/src/function/builtin/special/PROGN.java b/src/function/builtin/special/PROGN.java index 02f0af9..4759e27 100644 --- a/src/function/builtin/special/PROGN.java +++ b/src/function/builtin/special/PROGN.java @@ -26,12 +26,12 @@ public class PROGN extends LispSpecialFunction { } private SExpression callTailRecursive(Cons argumentList, SExpression lastValue) { - SExpression currentValue = eval(argumentList.getFirst()); - Cons remainingValues = (Cons) argumentList.getRest(); - if (argumentList.isNull()) return lastValue; + SExpression currentValue = eval(argumentList.getFirst()); + Cons remainingValues = (Cons) argumentList.getRest(); + return callTailRecursive(remainingValues, currentValue); } diff --git a/src/function/builtin/special/RECUR.java b/src/function/builtin/special/RECUR.java index 0dc302b..210f86d 100644 --- a/src/function/builtin/special/RECUR.java +++ b/src/function/builtin/special/RECUR.java @@ -24,20 +24,26 @@ public class RECUR extends LispSpecialFunction { @Override public SExpression call(Cons argumentList) { - if (executionContext.isRecur()) + if (executionContext.isRecurInitializing()) throw new NestedRecurException(); if (!executionContext.isInFunctionCall()) throw new RecurOutsideOfFunctionException(); argumentValidator.validate(argumentList); - - executionContext.setRecur(); LispFunction currentFunction = executionContext.getCurrentFunction(); Cons recurArguments = argumentList; - if (currentFunction.isArgumentListEvaluated()) - recurArguments = evalRecurArgumentList(argumentList); + executionContext.setRecurInitializing(); + + try { + if (currentFunction.isArgumentListEvaluated()) + recurArguments = evalRecurArgumentList(argumentList); + } finally { + executionContext.clearRecurInitializing(); + } + + executionContext.setRecur(); return recurArguments; } diff --git a/src/table/ExecutionContext.java b/src/table/ExecutionContext.java index 8303dab..5c529a9 100644 --- a/src/table/ExecutionContext.java +++ b/src/table/ExecutionContext.java @@ -15,12 +15,14 @@ public class ExecutionContext { private SymbolTable scope; private Stack functionCalls; + private boolean recurInitializing; private boolean recur; private ExecutionContext() { this.scope = new SymbolTable(); this.functionCalls = new Stack<>(); this.clearRecur(); + this.clearRecurInitializing(); } public SymbolTable getScope() { @@ -35,6 +37,7 @@ public class ExecutionContext { this.scope = new SymbolTable(); this.functionCalls = new Stack<>(); this.clearRecur(); + this.clearRecurInitializing(); } public SExpression lookupSymbolValue(String symbolName) { @@ -73,4 +76,15 @@ public class ExecutionContext { recur = false; } + public boolean isRecurInitializing() { + return recurInitializing; + } + + public void setRecurInitializing() { + recurInitializing = true; + } + + public void clearRecurInitializing() { + recurInitializing = false; + } } diff --git a/test/function/builtin/special/RECURTest.java b/test/function/builtin/special/RECURTest.java index 024edd4..16f88ea 100644 --- a/test/function/builtin/special/RECURTest.java +++ b/test/function/builtin/special/RECURTest.java @@ -1,5 +1,6 @@ package function.builtin.special; +import static org.junit.Assert.fail; import static testutil.TestUtilities.assertSExpressionsMatch; import static testutil.TestUtilities.evaluateString; import static testutil.TestUtilities.parseString; @@ -57,21 +58,74 @@ public class RECURTest extends SymbolAndFunctionCleaner { evaluateString("(tail-recursive 900)"); } + @Test(expected = RecurNotInTailPositionException.class) + public void recurInNonTailPosition2_ThrowsException() { + evaluateString("(defun tail-recursive (n) (begin (recur) 2))"); + evaluateString("(tail-recursive 900)"); + } + @Test public void recurCallsCurrentFunction() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } + @Test + public void recurCallsCurrentFunction_InBegin() { + evaluateString("(defun tail-recursive (n) (if (> n 1) (begin 1 2 (recur (- n 1))) 'PASS))"); + assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); + } + @Test public void recurCallsCurrentFunction_WithApply() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(apply 'tail-recursive '(900))")); } - - - // recur with funcall - + + @Test + public void recurCallsCurrentFunction_WithFuncall() { + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); + assertSExpressionsMatch(parseString("PASS"), evaluateString("(call 'tail-recursive '900)")); + } + + @Test + public void recurWorksAfterFailure() { + evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))"); + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); + + try { + evaluateString("(bad-tail-recursive 900)"); + fail("expectedException"); + } catch (RecurNotInTailPositionException e) {} + + assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); + } + + @Test + public void recurWorksAfterFailure2() { + evaluateString("(defun bad-tail-recursive (n) (begin (recur) 2))"); + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); + + try { + evaluateString("(bad-tail-recursive 900)"); + fail("expectedException"); + } catch (RecurNotInTailPositionException e) {} + + assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); + } + + @Test + public void recurWorksAfterNestedFailure() { + evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (recur (recur (- n 1)))))"); + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); + + try { + evaluateString("(bad-tail-recursive 900)"); + fail("expectedException"); + } catch (NestedRecurException e) {} + + assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); + } // recur non-tail in apply call // recur with no args, alters global variable @@ -80,4 +134,6 @@ public class RECURTest extends SymbolAndFunctionCleaner { // recur with nested anonymous function + // test scope after failure in function, is it global again? + }