package function.builtin.special; import static org.junit.Assert.fail; import static testutil.TestUtilities.assertSExpressionsMatch; import static testutil.TestUtilities.evaluateString; import static testutil.TestUtilities.parseString; import org.junit.Test; import function.ArgumentValidator.BadArgumentTypeException; import function.ArgumentValidator.TooManyArgumentsException; import function.builtin.special.RECUR.NestedRecurException; import function.builtin.special.RECUR.RecurNotInTailPositionException; import function.builtin.special.RECUR.RecurOutsideOfFunctionException; import sexpression.SExpression; import testutil.SymbolAndFunctionCleaner; public class RECURTest extends SymbolAndFunctionCleaner { @Test(expected = RecurOutsideOfFunctionException.class) public void recurOutsideOfFunction_ThrowsException() { evaluateString("(recur)"); } @Test(expected = RecurOutsideOfFunctionException.class) public void recurOutsideOfFunction_AfterFunctionCall_ThrowsException() { evaluateString("(defun f (n) (if (= n 0) 'ZERO n))"); evaluateString("(f 2)"); evaluateString("(recur)"); } @Test(expected = BadArgumentTypeException.class) public void recurInSpecialFunction_DoesNotEvaluateArguments() { evaluateString("(define-special tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); evaluateString("(tail-recursive 900)"); } @Test(expected = BadArgumentTypeException.class) public void recurInMacro_DoesNotEvaluateArguments() { evaluateString("(defmacro tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); evaluateString("(tail-recursive 900)"); } @Test(expected = NestedRecurException.class) public void nestedRecur_ThrowsException() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (recur (- n 1)))))"); evaluateString("(tail-recursive 900)"); } @Test(expected = TooManyArgumentsException.class) public void functionCallValidatesRecurArguments() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1) 23)))"); evaluateString("(tail-recursive 900)"); } @Test(expected = RecurNotInTailPositionException.class) public void recurInNonTailPositionInArgumentList() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))"); evaluateString("(tail-recursive 900)"); } @Test(expected = RecurNotInTailPositionException.class) public void recurInNonTailPositionInBegin() { evaluateString("(defun tail-recursive (n) (begin (recur) 2))"); evaluateString("(tail-recursive 900)"); } @Test public void recurCallsCurrentFunction() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } @Test public void recurCallsCurrentFunction_InBegin() { evaluateString("(defun tail-recursive (n) (if (> n 1) (begin 1 2 (recur (- n 1))) 'PASS))"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } @Test public void recurInTailPositionWithApply() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (apply 'recur (list (- n 1)))))"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } @Test public void recurInTailPositionWithFuncall() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (call 'recur (- n 1))))"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } @Test public void recurCallsCurrentFunction_WithApply() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(apply 'tail-recursive '(900))")); } @Test public void recurCallsCurrentFunction_WithFuncall() { evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(call 'tail-recursive '900)")); } @Test public void recurWorksAfterFailure() { evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))"); evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); try { evaluateString("(bad-tail-recursive 900)"); fail("expectedException"); } catch (RecurNotInTailPositionException e) {} assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } @Test public void recurWorksAfterFailure2() { evaluateString("(defun bad-tail-recursive (n) (begin (recur) 2))"); evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); try { evaluateString("(bad-tail-recursive 900)"); fail("expectedException"); } catch (RecurNotInTailPositionException e) {} assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } @Test public void recurWorksAfterNestedFailure() { evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (recur (recur (- n 1)))))"); evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); try { evaluateString("(bad-tail-recursive 900)"); fail("expectedException"); } catch (NestedRecurException e) {} assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); } @Test public void recurWithNoArgs_AltersGlobalVariable() { evaluateString("(defun tail-recursive () (if (= n 0) 'PASS (begin (setq n (- n 1)) (recur))))"); evaluateString("(setq n 200)"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive)")); } @Test public void recurWithLambda() { SExpression lambdaTailCall = evaluateString("((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020)"); assertSExpressionsMatch(parseString("PASS"), lambdaTailCall); } @Test public void recurWithNestedLambda() { evaluateString("(defun nested-tail () ((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020))"); assertSExpressionsMatch(parseString("PASS"), evaluateString("(nested-tail)")); } }