diff --git a/lisp/finance/interest-compounder-test.lisp b/lisp/finance/interest-compounder-test.lisp index 0ed6f46..5d1d82a 100644 --- a/lisp/finance/interest-compounder-test.lisp +++ b/lisp/finance/interest-compounder-test.lisp @@ -17,7 +17,7 @@ (defun many-years-with-no-interest-rate () (setq compounder (interest-compounder 1000 0)) - (call compounder :move-forward-years 83) + (call compounder :move-forward-years 1803) (call assertions :assert= 1000 (call compounder :get-principal))) (defun no-years-with-positive-interest-rate () diff --git a/lisp/finance/interest-compounder.lisp b/lisp/finance/interest-compounder.lisp index 75d70ae..ccbb2ca 100644 --- a/lisp/finance/interest-compounder.lisp +++ b/lisp/finance/interest-compounder.lisp @@ -27,7 +27,7 @@ (setq principal (+ principal (call static :percent-of-number principal interest-rate))) - (call private :compound-interest (- years 1))))))) + (recur (- years 1))))))) (setq public (dlambda diff --git a/src/function/UserDefinedFunction.java b/src/function/UserDefinedFunction.java index dee63ea..22b5367 100644 --- a/src/function/UserDefinedFunction.java +++ b/src/function/UserDefinedFunction.java @@ -89,17 +89,16 @@ public class UserDefinedFunction extends LispFunction { @Override public SExpression call(Cons argumentList) { - return callTailRecursive(argumentList).invoke(); + executionContext.pushFunctionCall(this); + SExpression result = callTailRecursive(argumentList).invoke(); + executionContext.popFunctionCall(); + + return result; } private TailCall callTailRecursive(Cons argumentList) { argumentValidator.validate(argumentList); - // if recur - // clear recur indicator - // validate function list - // do not evaluate arguments - will have been done already if necessary - // tailCall(evaluatedArguments) SExpression result = evaluateInFunctionScope(argumentList); if (executionContext.isRecur()) { diff --git a/src/function/builtin/EVAL.java b/src/function/builtin/EVAL.java index 2d6490d..3e8a4a1 100644 --- a/src/function/builtin/EVAL.java +++ b/src/function/builtin/EVAL.java @@ -12,7 +12,6 @@ import error.LispException; import function.ArgumentValidator; import function.FunctionNames; import function.LispFunction; -import function.UserDefinedFunction; import sexpression.BackquoteExpression; import sexpression.Cons; import sexpression.LambdaExpression; @@ -29,7 +28,7 @@ public class EVAL extends LispFunction { return lookupFunction("EVAL").call(argumentList); } - public static Cons evalArgumentList(Cons argumentList) { + public static Cons evalRecurArgumentList(Cons argumentList) { return ((EVAL) lookupFunction("EVAL")).evaluateArgumentList(argumentList); } @@ -116,8 +115,11 @@ public class EVAL extends LispFunction { } private SExpression callFunction(LispFunction function, Cons argumentList) { - if (function instanceof UserDefinedFunction) - executionContext.pushFunctionCall(function); + + // if (executionContext.isRecur()) { + // executionContext.popFunctionCall(); // FIXME - clear this in repl? + // throw new RecurNotInTailPositionException(); + // } if (function.isArgumentListEvaluated()) argumentList = evaluateArgumentList(argumentList); @@ -127,9 +129,6 @@ public class EVAL extends LispFunction { if (function.isMacro()) result = eval(result); - if (function instanceof UserDefinedFunction) - executionContext.popFunctionCall(); - return result; } @@ -208,4 +207,14 @@ public class EVAL extends LispFunction { } } + public static class RecurNotInTailPositionException extends LispException { + + private static final long serialVersionUID = 1L; + + @Override + public String getMessage() { + return "recur not in tail position"; + } + } + } diff --git a/src/function/builtin/special/RECUR.java b/src/function/builtin/special/RECUR.java index 01de941..0dc302b 100644 --- a/src/function/builtin/special/RECUR.java +++ b/src/function/builtin/special/RECUR.java @@ -1,6 +1,6 @@ package function.builtin.special; -import static function.builtin.EVAL.evalArgumentList; +import static function.builtin.EVAL.evalRecurArgumentList; import error.LispException; import function.ArgumentValidator; @@ -24,19 +24,20 @@ public class RECUR extends LispSpecialFunction { @Override public SExpression call(Cons argumentList) { + if (executionContext.isRecur()) + throw new NestedRecurException(); + + if (!executionContext.isInFunctionCall()) + throw new RecurOutsideOfFunctionException(); + argumentValidator.validate(argumentList); - if (!executionContext.isInFunctionCall()) { - throw new RecurOutsideOfFunctionException(); - } - + executionContext.setRecur(); LispFunction currentFunction = executionContext.getCurrentFunction(); Cons recurArguments = argumentList; if (currentFunction.isArgumentListEvaluated()) - recurArguments = evalArgumentList(argumentList); - - executionContext.setRecur(); + recurArguments = evalRecurArgumentList(argumentList); return recurArguments; } @@ -51,4 +52,14 @@ public class RECUR extends LispSpecialFunction { } } + public static class NestedRecurException extends LispException { + + private static final long serialVersionUID = 1L; + + @Override + public String getMessage() { + return "nested call to recur"; + } + } + } diff --git a/test/function/builtin/special/RECURTest.java b/test/function/builtin/special/RECURTest.java index 6afc46d..024edd4 100644 --- a/test/function/builtin/special/RECURTest.java +++ b/test/function/builtin/special/RECURTest.java @@ -7,6 +7,9 @@ import static testutil.TestUtilities.parseString; import org.junit.Test; import function.ArgumentValidator.BadArgumentTypeException; +import function.ArgumentValidator.TooManyArgumentsException; +import function.builtin.EVAL.RecurNotInTailPositionException; +import function.builtin.special.RECUR.NestedRecurException; import function.builtin.special.RECUR.RecurOutsideOfFunctionException; import testutil.SymbolAndFunctionCleaner; @@ -24,16 +27,57 @@ public class RECURTest extends SymbolAndFunctionCleaner { evaluateString("(recur)"); } - @Test - public void recurCallsCurrentFunction() { - evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); - assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); - } - @Test(expected = BadArgumentTypeException.class) public void recurInSpecialFunction_DoesNotEvaluateArguments() { evaluateString("(define-special tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); evaluateString("(tail-recursive 900)"); } + @Test(expected = BadArgumentTypeException.class) + public void recurInMacro_DoesNotEvaluateArguments() { + evaluateString("(defmacro tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); + evaluateString("(tail-recursive 900)"); + } + + @Test(expected = NestedRecurException.class) + public void nestedRecur_ThrowsException() { + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (recur (- n 1)))))"); + evaluateString("(tail-recursive 900)"); + } + + @Test(expected = TooManyArgumentsException.class) + public void functionCallValidatesRecurArguments() { + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1) 23)))"); + evaluateString("(tail-recursive 900)"); + } + + @Test(expected = RecurNotInTailPositionException.class) + public void recurInNonTailPosition_ThrowsException() { + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))"); + evaluateString("(tail-recursive 900)"); + } + + @Test + public void recurCallsCurrentFunction() { + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); + assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); + } + + @Test + public void recurCallsCurrentFunction_WithApply() { + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); + assertSExpressionsMatch(parseString("PASS"), evaluateString("(apply 'tail-recursive '(900))")); + } + + + // recur with funcall + + // recur non-tail in apply call + + // recur with no args, alters global variable + + // recur with anonymous function + + // recur with nested anonymous function + }