Improve recur behavior

This commit is contained in:
Mike Cifelli 2017-11-17 19:14:59 -05:00
parent 6cf017734f
commit e2272fa976
6 changed files with 99 additions and 33 deletions

View File

@ -138,21 +138,7 @@ public class UserDefinedFunction extends LispFunction {
private SExpression evaluateBody() { private SExpression evaluateBody() {
SExpression lastEvaluation = NIL; SExpression lastEvaluation = NIL;
// recur sets indicator in execution context for (Cons expression = body; expression.isCons(); expression = (Cons) expression.getRest())
// eval can't be called when recur indicator is set - or error occurs (and recur indicator
// is cleared)
// recur checks to see if in function call, if not an error occurs (stack of objects
// (function call types) in execution context)
// if not, set recur indicator somewhere, (can't be in symbol table because LET creates
// scopes as well)
//
// on exception - clear function call counter and recur indicator in main loop
// recur evaluates its arguments (or not, according to function type) in its scope, then
// returns the list
//
// eval adds function call type to execution context stack of function calls when user
// defined function encountered
for (Cons expression = body; expression.isCons() /* and not recur indicator */; expression = (Cons) expression.getRest())
lastEvaluation = eval(expression.getFirst()); lastEvaluation = eval(expression.getFirst());
return lastEvaluation; return lastEvaluation;

View File

@ -72,6 +72,11 @@ public class EVAL extends LispFunction {
@Override @Override
public SExpression call(Cons argumentList) { public SExpression call(Cons argumentList) {
if (executionContext.isRecur()) {
executionContext.clearContext();
throw new RecurNotInTailPositionException();
}
argumentValidator.validate(argumentList); argumentValidator.validate(argumentList);
SExpression argument = argumentList.getFirst(); SExpression argument = argumentList.getFirst();
@ -115,15 +120,14 @@ public class EVAL extends LispFunction {
} }
private SExpression callFunction(LispFunction function, Cons argumentList) { private SExpression callFunction(LispFunction function, Cons argumentList) {
// if (executionContext.isRecur()) {
// executionContext.popFunctionCall(); // FIXME - clear this in repl?
// throw new RecurNotInTailPositionException();
// }
if (function.isArgumentListEvaluated()) if (function.isArgumentListEvaluated())
argumentList = evaluateArgumentList(argumentList); argumentList = evaluateArgumentList(argumentList);
if (executionContext.isRecur()) {
executionContext.clearContext();
throw new RecurNotInTailPositionException();
}
SExpression result = function.call(argumentList); SExpression result = function.call(argumentList);
if (function.isMacro()) if (function.isMacro())

View File

@ -26,12 +26,12 @@ public class PROGN extends LispSpecialFunction {
} }
private SExpression callTailRecursive(Cons argumentList, SExpression lastValue) { private SExpression callTailRecursive(Cons argumentList, SExpression lastValue) {
SExpression currentValue = eval(argumentList.getFirst());
Cons remainingValues = (Cons) argumentList.getRest();
if (argumentList.isNull()) if (argumentList.isNull())
return lastValue; return lastValue;
SExpression currentValue = eval(argumentList.getFirst());
Cons remainingValues = (Cons) argumentList.getRest();
return callTailRecursive(remainingValues, currentValue); return callTailRecursive(remainingValues, currentValue);
} }

View File

@ -24,20 +24,26 @@ public class RECUR extends LispSpecialFunction {
@Override @Override
public SExpression call(Cons argumentList) { public SExpression call(Cons argumentList) {
if (executionContext.isRecur()) if (executionContext.isRecurInitializing())
throw new NestedRecurException(); throw new NestedRecurException();
if (!executionContext.isInFunctionCall()) if (!executionContext.isInFunctionCall())
throw new RecurOutsideOfFunctionException(); throw new RecurOutsideOfFunctionException();
argumentValidator.validate(argumentList); argumentValidator.validate(argumentList);
executionContext.setRecur();
LispFunction currentFunction = executionContext.getCurrentFunction(); LispFunction currentFunction = executionContext.getCurrentFunction();
Cons recurArguments = argumentList; Cons recurArguments = argumentList;
if (currentFunction.isArgumentListEvaluated()) executionContext.setRecurInitializing();
recurArguments = evalRecurArgumentList(argumentList);
try {
if (currentFunction.isArgumentListEvaluated())
recurArguments = evalRecurArgumentList(argumentList);
} finally {
executionContext.clearRecurInitializing();
}
executionContext.setRecur();
return recurArguments; return recurArguments;
} }

View File

@ -15,12 +15,14 @@ public class ExecutionContext {
private SymbolTable scope; private SymbolTable scope;
private Stack<LispFunction> functionCalls; private Stack<LispFunction> functionCalls;
private boolean recurInitializing;
private boolean recur; private boolean recur;
private ExecutionContext() { private ExecutionContext() {
this.scope = new SymbolTable(); this.scope = new SymbolTable();
this.functionCalls = new Stack<>(); this.functionCalls = new Stack<>();
this.clearRecur(); this.clearRecur();
this.clearRecurInitializing();
} }
public SymbolTable getScope() { public SymbolTable getScope() {
@ -35,6 +37,7 @@ public class ExecutionContext {
this.scope = new SymbolTable(); this.scope = new SymbolTable();
this.functionCalls = new Stack<>(); this.functionCalls = new Stack<>();
this.clearRecur(); this.clearRecur();
this.clearRecurInitializing();
} }
public SExpression lookupSymbolValue(String symbolName) { public SExpression lookupSymbolValue(String symbolName) {
@ -73,4 +76,15 @@ public class ExecutionContext {
recur = false; recur = false;
} }
public boolean isRecurInitializing() {
return recurInitializing;
}
public void setRecurInitializing() {
recurInitializing = true;
}
public void clearRecurInitializing() {
recurInitializing = false;
}
} }

View File

@ -1,5 +1,6 @@
package function.builtin.special; package function.builtin.special;
import static org.junit.Assert.fail;
import static testutil.TestUtilities.assertSExpressionsMatch; import static testutil.TestUtilities.assertSExpressionsMatch;
import static testutil.TestUtilities.evaluateString; import static testutil.TestUtilities.evaluateString;
import static testutil.TestUtilities.parseString; import static testutil.TestUtilities.parseString;
@ -57,21 +58,74 @@ public class RECURTest extends SymbolAndFunctionCleaner {
evaluateString("(tail-recursive 900)"); evaluateString("(tail-recursive 900)");
} }
@Test(expected = RecurNotInTailPositionException.class)
public void recurInNonTailPosition2_ThrowsException() {
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
evaluateString("(tail-recursive 900)");
}
@Test @Test
public void recurCallsCurrentFunction() { public void recurCallsCurrentFunction() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
} }
@Test
public void recurCallsCurrentFunction_InBegin() {
evaluateString("(defun tail-recursive (n) (if (> n 1) (begin 1 2 (recur (- n 1))) 'PASS))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test @Test
public void recurCallsCurrentFunction_WithApply() { public void recurCallsCurrentFunction_WithApply() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(apply 'tail-recursive '(900))")); assertSExpressionsMatch(parseString("PASS"), evaluateString("(apply 'tail-recursive '(900))"));
} }
@Test
// recur with funcall public void recurCallsCurrentFunction_WithFuncall() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(call 'tail-recursive '900)"));
}
@Test
public void recurWorksAfterFailure() {
evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))");
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
try {
evaluateString("(bad-tail-recursive 900)");
fail("expectedException");
} catch (RecurNotInTailPositionException e) {}
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test
public void recurWorksAfterFailure2() {
evaluateString("(defun bad-tail-recursive (n) (begin (recur) 2))");
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
try {
evaluateString("(bad-tail-recursive 900)");
fail("expectedException");
} catch (RecurNotInTailPositionException e) {}
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test
public void recurWorksAfterNestedFailure() {
evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (recur (recur (- n 1)))))");
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
try {
evaluateString("(bad-tail-recursive 900)");
fail("expectedException");
} catch (NestedRecurException e) {}
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
// recur non-tail in apply call // recur non-tail in apply call
// recur with no args, alters global variable // recur with no args, alters global variable
@ -80,4 +134,6 @@ public class RECURTest extends SymbolAndFunctionCleaner {
// recur with nested anonymous function // recur with nested anonymous function
// test scope after failure in function, is it global again?
} }