Improve recur behavior
This commit is contained in:
parent
6cf017734f
commit
e2272fa976
|
@ -138,21 +138,7 @@ public class UserDefinedFunction extends LispFunction {
|
||||||
private SExpression evaluateBody() {
|
private SExpression evaluateBody() {
|
||||||
SExpression lastEvaluation = NIL;
|
SExpression lastEvaluation = NIL;
|
||||||
|
|
||||||
// recur sets indicator in execution context
|
for (Cons expression = body; expression.isCons(); expression = (Cons) expression.getRest())
|
||||||
// eval can't be called when recur indicator is set - or error occurs (and recur indicator
|
|
||||||
// is cleared)
|
|
||||||
// recur checks to see if in function call, if not an error occurs (stack of objects
|
|
||||||
// (function call types) in execution context)
|
|
||||||
// if not, set recur indicator somewhere, (can't be in symbol table because LET creates
|
|
||||||
// scopes as well)
|
|
||||||
//
|
|
||||||
// on exception - clear function call counter and recur indicator in main loop
|
|
||||||
// recur evaluates its arguments (or not, according to function type) in its scope, then
|
|
||||||
// returns the list
|
|
||||||
//
|
|
||||||
// eval adds function call type to execution context stack of function calls when user
|
|
||||||
// defined function encountered
|
|
||||||
for (Cons expression = body; expression.isCons() /* and not recur indicator */; expression = (Cons) expression.getRest())
|
|
||||||
lastEvaluation = eval(expression.getFirst());
|
lastEvaluation = eval(expression.getFirst());
|
||||||
|
|
||||||
return lastEvaluation;
|
return lastEvaluation;
|
||||||
|
|
|
@ -72,6 +72,11 @@ public class EVAL extends LispFunction {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SExpression call(Cons argumentList) {
|
public SExpression call(Cons argumentList) {
|
||||||
|
if (executionContext.isRecur()) {
|
||||||
|
executionContext.clearContext();
|
||||||
|
throw new RecurNotInTailPositionException();
|
||||||
|
}
|
||||||
|
|
||||||
argumentValidator.validate(argumentList);
|
argumentValidator.validate(argumentList);
|
||||||
SExpression argument = argumentList.getFirst();
|
SExpression argument = argumentList.getFirst();
|
||||||
|
|
||||||
|
@ -115,15 +120,14 @@ public class EVAL extends LispFunction {
|
||||||
}
|
}
|
||||||
|
|
||||||
private SExpression callFunction(LispFunction function, Cons argumentList) {
|
private SExpression callFunction(LispFunction function, Cons argumentList) {
|
||||||
|
|
||||||
// if (executionContext.isRecur()) {
|
|
||||||
// executionContext.popFunctionCall(); // FIXME - clear this in repl?
|
|
||||||
// throw new RecurNotInTailPositionException();
|
|
||||||
// }
|
|
||||||
|
|
||||||
if (function.isArgumentListEvaluated())
|
if (function.isArgumentListEvaluated())
|
||||||
argumentList = evaluateArgumentList(argumentList);
|
argumentList = evaluateArgumentList(argumentList);
|
||||||
|
|
||||||
|
if (executionContext.isRecur()) {
|
||||||
|
executionContext.clearContext();
|
||||||
|
throw new RecurNotInTailPositionException();
|
||||||
|
}
|
||||||
|
|
||||||
SExpression result = function.call(argumentList);
|
SExpression result = function.call(argumentList);
|
||||||
|
|
||||||
if (function.isMacro())
|
if (function.isMacro())
|
||||||
|
|
|
@ -26,12 +26,12 @@ public class PROGN extends LispSpecialFunction {
|
||||||
}
|
}
|
||||||
|
|
||||||
private SExpression callTailRecursive(Cons argumentList, SExpression lastValue) {
|
private SExpression callTailRecursive(Cons argumentList, SExpression lastValue) {
|
||||||
SExpression currentValue = eval(argumentList.getFirst());
|
|
||||||
Cons remainingValues = (Cons) argumentList.getRest();
|
|
||||||
|
|
||||||
if (argumentList.isNull())
|
if (argumentList.isNull())
|
||||||
return lastValue;
|
return lastValue;
|
||||||
|
|
||||||
|
SExpression currentValue = eval(argumentList.getFirst());
|
||||||
|
Cons remainingValues = (Cons) argumentList.getRest();
|
||||||
|
|
||||||
return callTailRecursive(remainingValues, currentValue);
|
return callTailRecursive(remainingValues, currentValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,20 +24,26 @@ public class RECUR extends LispSpecialFunction {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SExpression call(Cons argumentList) {
|
public SExpression call(Cons argumentList) {
|
||||||
if (executionContext.isRecur())
|
if (executionContext.isRecurInitializing())
|
||||||
throw new NestedRecurException();
|
throw new NestedRecurException();
|
||||||
|
|
||||||
if (!executionContext.isInFunctionCall())
|
if (!executionContext.isInFunctionCall())
|
||||||
throw new RecurOutsideOfFunctionException();
|
throw new RecurOutsideOfFunctionException();
|
||||||
|
|
||||||
argumentValidator.validate(argumentList);
|
argumentValidator.validate(argumentList);
|
||||||
|
|
||||||
executionContext.setRecur();
|
|
||||||
LispFunction currentFunction = executionContext.getCurrentFunction();
|
LispFunction currentFunction = executionContext.getCurrentFunction();
|
||||||
Cons recurArguments = argumentList;
|
Cons recurArguments = argumentList;
|
||||||
|
|
||||||
if (currentFunction.isArgumentListEvaluated())
|
executionContext.setRecurInitializing();
|
||||||
recurArguments = evalRecurArgumentList(argumentList);
|
|
||||||
|
try {
|
||||||
|
if (currentFunction.isArgumentListEvaluated())
|
||||||
|
recurArguments = evalRecurArgumentList(argumentList);
|
||||||
|
} finally {
|
||||||
|
executionContext.clearRecurInitializing();
|
||||||
|
}
|
||||||
|
|
||||||
|
executionContext.setRecur();
|
||||||
|
|
||||||
return recurArguments;
|
return recurArguments;
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,12 +15,14 @@ public class ExecutionContext {
|
||||||
|
|
||||||
private SymbolTable scope;
|
private SymbolTable scope;
|
||||||
private Stack<LispFunction> functionCalls;
|
private Stack<LispFunction> functionCalls;
|
||||||
|
private boolean recurInitializing;
|
||||||
private boolean recur;
|
private boolean recur;
|
||||||
|
|
||||||
private ExecutionContext() {
|
private ExecutionContext() {
|
||||||
this.scope = new SymbolTable();
|
this.scope = new SymbolTable();
|
||||||
this.functionCalls = new Stack<>();
|
this.functionCalls = new Stack<>();
|
||||||
this.clearRecur();
|
this.clearRecur();
|
||||||
|
this.clearRecurInitializing();
|
||||||
}
|
}
|
||||||
|
|
||||||
public SymbolTable getScope() {
|
public SymbolTable getScope() {
|
||||||
|
@ -35,6 +37,7 @@ public class ExecutionContext {
|
||||||
this.scope = new SymbolTable();
|
this.scope = new SymbolTable();
|
||||||
this.functionCalls = new Stack<>();
|
this.functionCalls = new Stack<>();
|
||||||
this.clearRecur();
|
this.clearRecur();
|
||||||
|
this.clearRecurInitializing();
|
||||||
}
|
}
|
||||||
|
|
||||||
public SExpression lookupSymbolValue(String symbolName) {
|
public SExpression lookupSymbolValue(String symbolName) {
|
||||||
|
@ -73,4 +76,15 @@ public class ExecutionContext {
|
||||||
recur = false;
|
recur = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean isRecurInitializing() {
|
||||||
|
return recurInitializing;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setRecurInitializing() {
|
||||||
|
recurInitializing = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void clearRecurInitializing() {
|
||||||
|
recurInitializing = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package function.builtin.special;
|
package function.builtin.special;
|
||||||
|
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
import static testutil.TestUtilities.assertSExpressionsMatch;
|
import static testutil.TestUtilities.assertSExpressionsMatch;
|
||||||
import static testutil.TestUtilities.evaluateString;
|
import static testutil.TestUtilities.evaluateString;
|
||||||
import static testutil.TestUtilities.parseString;
|
import static testutil.TestUtilities.parseString;
|
||||||
|
@ -57,21 +58,74 @@ public class RECURTest extends SymbolAndFunctionCleaner {
|
||||||
evaluateString("(tail-recursive 900)");
|
evaluateString("(tail-recursive 900)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test(expected = RecurNotInTailPositionException.class)
|
||||||
|
public void recurInNonTailPosition2_ThrowsException() {
|
||||||
|
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
|
||||||
|
evaluateString("(tail-recursive 900)");
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void recurCallsCurrentFunction() {
|
public void recurCallsCurrentFunction() {
|
||||||
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
||||||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recurCallsCurrentFunction_InBegin() {
|
||||||
|
evaluateString("(defun tail-recursive (n) (if (> n 1) (begin 1 2 (recur (- n 1))) 'PASS))");
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void recurCallsCurrentFunction_WithApply() {
|
public void recurCallsCurrentFunction_WithApply() {
|
||||||
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
||||||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(apply 'tail-recursive '(900))"));
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(apply 'tail-recursive '(900))"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
// recur with funcall
|
public void recurCallsCurrentFunction_WithFuncall() {
|
||||||
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(call 'tail-recursive '900)"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recurWorksAfterFailure() {
|
||||||
|
evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))");
|
||||||
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
||||||
|
|
||||||
|
try {
|
||||||
|
evaluateString("(bad-tail-recursive 900)");
|
||||||
|
fail("expectedException");
|
||||||
|
} catch (RecurNotInTailPositionException e) {}
|
||||||
|
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recurWorksAfterFailure2() {
|
||||||
|
evaluateString("(defun bad-tail-recursive (n) (begin (recur) 2))");
|
||||||
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
||||||
|
|
||||||
|
try {
|
||||||
|
evaluateString("(bad-tail-recursive 900)");
|
||||||
|
fail("expectedException");
|
||||||
|
} catch (RecurNotInTailPositionException e) {}
|
||||||
|
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recurWorksAfterNestedFailure() {
|
||||||
|
evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (recur (recur (- n 1)))))");
|
||||||
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
||||||
|
|
||||||
|
try {
|
||||||
|
evaluateString("(bad-tail-recursive 900)");
|
||||||
|
fail("expectedException");
|
||||||
|
} catch (NestedRecurException e) {}
|
||||||
|
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||||
|
}
|
||||||
// recur non-tail in apply call
|
// recur non-tail in apply call
|
||||||
|
|
||||||
// recur with no args, alters global variable
|
// recur with no args, alters global variable
|
||||||
|
@ -80,4 +134,6 @@ public class RECURTest extends SymbolAndFunctionCleaner {
|
||||||
|
|
||||||
// recur with nested anonymous function
|
// recur with nested anonymous function
|
||||||
|
|
||||||
|
// test scope after failure in function, is it global again?
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue