Narrow the behavior of RECUR

This commit is contained in:
Mike Cifelli 2017-11-17 17:57:14 -05:00
parent b4229c6ac1
commit 6cf017734f
6 changed files with 92 additions and 29 deletions

View File

@ -17,7 +17,7 @@
(defun many-years-with-no-interest-rate () (defun many-years-with-no-interest-rate ()
(setq compounder (interest-compounder 1000 0)) (setq compounder (interest-compounder 1000 0))
(call compounder :move-forward-years 83) (call compounder :move-forward-years 1803)
(call assertions :assert= 1000 (call compounder :get-principal))) (call assertions :assert= 1000 (call compounder :get-principal)))
(defun no-years-with-positive-interest-rate () (defun no-years-with-positive-interest-rate ()

View File

@ -27,7 +27,7 @@
(setq principal (setq principal
(+ principal (+ principal
(call static :percent-of-number principal interest-rate))) (call static :percent-of-number principal interest-rate)))
(call private :compound-interest (- years 1))))))) (recur (- years 1)))))))
(setq public (setq public
(dlambda (dlambda

View File

@ -89,17 +89,16 @@ public class UserDefinedFunction extends LispFunction {
@Override @Override
public SExpression call(Cons argumentList) { public SExpression call(Cons argumentList) {
return callTailRecursive(argumentList).invoke(); executionContext.pushFunctionCall(this);
SExpression result = callTailRecursive(argumentList).invoke();
executionContext.popFunctionCall();
return result;
} }
private TailCall<SExpression> callTailRecursive(Cons argumentList) { private TailCall<SExpression> callTailRecursive(Cons argumentList) {
argumentValidator.validate(argumentList); argumentValidator.validate(argumentList);
// if recur
// clear recur indicator
// validate function list
// do not evaluate arguments - will have been done already if necessary
// tailCall(evaluatedArguments)
SExpression result = evaluateInFunctionScope(argumentList); SExpression result = evaluateInFunctionScope(argumentList);
if (executionContext.isRecur()) { if (executionContext.isRecur()) {

View File

@ -12,7 +12,6 @@ import error.LispException;
import function.ArgumentValidator; import function.ArgumentValidator;
import function.FunctionNames; import function.FunctionNames;
import function.LispFunction; import function.LispFunction;
import function.UserDefinedFunction;
import sexpression.BackquoteExpression; import sexpression.BackquoteExpression;
import sexpression.Cons; import sexpression.Cons;
import sexpression.LambdaExpression; import sexpression.LambdaExpression;
@ -29,7 +28,7 @@ public class EVAL extends LispFunction {
return lookupFunction("EVAL").call(argumentList); return lookupFunction("EVAL").call(argumentList);
} }
public static Cons evalArgumentList(Cons argumentList) { public static Cons evalRecurArgumentList(Cons argumentList) {
return ((EVAL) lookupFunction("EVAL")).evaluateArgumentList(argumentList); return ((EVAL) lookupFunction("EVAL")).evaluateArgumentList(argumentList);
} }
@ -116,8 +115,11 @@ public class EVAL extends LispFunction {
} }
private SExpression callFunction(LispFunction function, Cons argumentList) { private SExpression callFunction(LispFunction function, Cons argumentList) {
if (function instanceof UserDefinedFunction)
executionContext.pushFunctionCall(function); // if (executionContext.isRecur()) {
// executionContext.popFunctionCall(); // FIXME - clear this in repl?
// throw new RecurNotInTailPositionException();
// }
if (function.isArgumentListEvaluated()) if (function.isArgumentListEvaluated())
argumentList = evaluateArgumentList(argumentList); argumentList = evaluateArgumentList(argumentList);
@ -127,9 +129,6 @@ public class EVAL extends LispFunction {
if (function.isMacro()) if (function.isMacro())
result = eval(result); result = eval(result);
if (function instanceof UserDefinedFunction)
executionContext.popFunctionCall();
return result; return result;
} }
@ -208,4 +207,14 @@ public class EVAL extends LispFunction {
} }
} }
public static class RecurNotInTailPositionException extends LispException {
private static final long serialVersionUID = 1L;
@Override
public String getMessage() {
return "recur not in tail position";
}
}
} }

View File

@ -1,6 +1,6 @@
package function.builtin.special; package function.builtin.special;
import static function.builtin.EVAL.evalArgumentList; import static function.builtin.EVAL.evalRecurArgumentList;
import error.LispException; import error.LispException;
import function.ArgumentValidator; import function.ArgumentValidator;
@ -24,19 +24,20 @@ public class RECUR extends LispSpecialFunction {
@Override @Override
public SExpression call(Cons argumentList) { public SExpression call(Cons argumentList) {
if (executionContext.isRecur())
throw new NestedRecurException();
if (!executionContext.isInFunctionCall())
throw new RecurOutsideOfFunctionException();
argumentValidator.validate(argumentList); argumentValidator.validate(argumentList);
if (!executionContext.isInFunctionCall()) { executionContext.setRecur();
throw new RecurOutsideOfFunctionException();
}
LispFunction currentFunction = executionContext.getCurrentFunction(); LispFunction currentFunction = executionContext.getCurrentFunction();
Cons recurArguments = argumentList; Cons recurArguments = argumentList;
if (currentFunction.isArgumentListEvaluated()) if (currentFunction.isArgumentListEvaluated())
recurArguments = evalArgumentList(argumentList); recurArguments = evalRecurArgumentList(argumentList);
executionContext.setRecur();
return recurArguments; return recurArguments;
} }
@ -51,4 +52,14 @@ public class RECUR extends LispSpecialFunction {
} }
} }
public static class NestedRecurException extends LispException {
private static final long serialVersionUID = 1L;
@Override
public String getMessage() {
return "nested call to recur";
}
}
} }

View File

@ -7,6 +7,9 @@ import static testutil.TestUtilities.parseString;
import org.junit.Test; import org.junit.Test;
import function.ArgumentValidator.BadArgumentTypeException; import function.ArgumentValidator.BadArgumentTypeException;
import function.ArgumentValidator.TooManyArgumentsException;
import function.builtin.EVAL.RecurNotInTailPositionException;
import function.builtin.special.RECUR.NestedRecurException;
import function.builtin.special.RECUR.RecurOutsideOfFunctionException; import function.builtin.special.RECUR.RecurOutsideOfFunctionException;
import testutil.SymbolAndFunctionCleaner; import testutil.SymbolAndFunctionCleaner;
@ -24,16 +27,57 @@ public class RECURTest extends SymbolAndFunctionCleaner {
evaluateString("(recur)"); evaluateString("(recur)");
} }
@Test
public void recurCallsCurrentFunction() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test(expected = BadArgumentTypeException.class) @Test(expected = BadArgumentTypeException.class)
public void recurInSpecialFunction_DoesNotEvaluateArguments() { public void recurInSpecialFunction_DoesNotEvaluateArguments() {
evaluateString("(define-special tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); evaluateString("(define-special tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
evaluateString("(tail-recursive 900)"); evaluateString("(tail-recursive 900)");
} }
@Test(expected = BadArgumentTypeException.class)
public void recurInMacro_DoesNotEvaluateArguments() {
evaluateString("(defmacro tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
evaluateString("(tail-recursive 900)");
}
@Test(expected = NestedRecurException.class)
public void nestedRecur_ThrowsException() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (recur (- n 1)))))");
evaluateString("(tail-recursive 900)");
}
@Test(expected = TooManyArgumentsException.class)
public void functionCallValidatesRecurArguments() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1) 23)))");
evaluateString("(tail-recursive 900)");
}
@Test(expected = RecurNotInTailPositionException.class)
public void recurInNonTailPosition_ThrowsException() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))");
evaluateString("(tail-recursive 900)");
}
@Test
public void recurCallsCurrentFunction() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test
public void recurCallsCurrentFunction_WithApply() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(apply 'tail-recursive '(900))"));
}
// recur with funcall
// recur non-tail in apply call
// recur with no args, alters global variable
// recur with anonymous function
// recur with nested anonymous function
} }