package function.builtin.special

import function.ArgumentValidator.BadArgumentTypeException
import function.ArgumentValidator.TooManyArgumentsException
import function.builtin.special.Recur.NestedRecurException
import function.builtin.special.Recur.RecurNotInTailPositionException
import function.builtin.special.Recur.RecurOutsideOfFunctionException
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.fail
import testutil.LispTestInstance
import testutil.SymbolAndFunctionCleaner
import testutil.TestUtilities.assertIsErrorWithMessage
import testutil.TestUtilities.assertSExpressionsMatch
import testutil.TestUtilities.evaluateString
import testutil.TestUtilities.parseString

/**
 * Tests for the `recur` special form.
 *
 * Covers:
 * - the error cases: `recur` outside any function, `recur` nested inside another `recur`, and
 *   `recur` in non-tail position (inside an argument list, a non-final `begin` form, or an
 *   `apply` argument);
 * - argument validation of the recursive call;
 * - unevaluated-argument semantics for special functions and macros;
 * - successful tail calls, including through `apply`/`call` and from lambdas;
 * - recovery: a failed `recur` must not break later, valid uses.
 */
@LispTestInstance
class RecurTest : SymbolAndFunctionCleaner() {

    @Test
    fun recurOutsideOfFunction_ThrowsException() {
        assertThrows(RecurOutsideOfFunctionException::class.java) {
            evaluateString("(recur)")
        }
    }

    @Test
    fun recurOutsideOfFunction_AfterFunctionCall_ThrowsException() {
        evaluateString("(defun f (n) (if (= n 0) 'ZERO n))")
        evaluateString("(f 2)")

        // Returning from a function must leave no "current function" behind for a top-level recur.
        assertThrows(RecurOutsideOfFunctionException::class.java) {
            evaluateString("(recur)")
        }
    }

    @Test
    fun recurInSpecialFunction_DoesNotEvaluateArguments() {
        evaluateString("(define-special tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))")

        // NOTE(review): presumably `=` rejects the unevaluated list `(- n 1)` passed by recur.
        assertThrows(BadArgumentTypeException::class.java) {
            evaluateString("(tail-recursive 900)")
        }
    }

    @Test
    fun recurInMacro_DoesNotEvaluateArguments() {
        evaluateString("(defmacro tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))")

        assertThrows(BadArgumentTypeException::class.java) {
            evaluateString("(tail-recursive 900)")
        }
    }

    @Test
    fun nestedRecur_ThrowsException() {
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (recur (- n 1)))))")

        assertThrows(NestedRecurException::class.java) {
            evaluateString("(tail-recursive 900)")
        }
    }

    @Test
    fun functionCallValidatesRecurArguments() {
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1) 23)))")

        // The recursive call is checked against the function's lambda list like a normal call.
        assertThrows(TooManyArgumentsException::class.java) {
            evaluateString("(tail-recursive 900)")
        }
    }

    @Test
    fun recurInNonTailPositionInArgumentList() {
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))")

        assertThrows(RecurNotInTailPositionException::class.java) {
            evaluateString("(tail-recursive 900)")
        }
    }

    @Test
    fun recurInNonTailPositionInBegin() {
        evaluateString("(defun tail-recursive (n) (begin (recur) 2))")

        assertThrows(RecurNotInTailPositionException::class.java) {
            evaluateString("(tail-recursive 900)")
        }
    }

    @Test
    fun recurInNonTailPositionInApply() {
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (apply 'list (recur (- n 1)))))")

        assertThrows(RecurNotInTailPositionException::class.java) {
            evaluateString("(tail-recursive 900)")
        }
    }

    @Test
    fun recurCallsCurrentFunction() {
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))")
        assertEvaluatesToPass("(tail-recursive 900)")
    }

    @Test
    fun recurCallsCurrentFunction_InBegin() {
        evaluateString("(defun tail-recursive (n) (if (> n 1) (begin 1 2 (recur (- n 1))) 'PASS))")
        assertEvaluatesToPass("(tail-recursive 900)")
    }

    @Test
    fun recurInTailPositionWithApply() {
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (apply 'recur (list (- n 1)))))")
        assertEvaluatesToPass("(tail-recursive 900)")
    }

    @Test
    fun recurInTailPositionWithFuncall() {
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (call 'recur (- n 1))))")
        assertEvaluatesToPass("(tail-recursive 900)")
    }

    @Test
    fun recurCallsCurrentFunction_WithApply() {
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))")
        assertEvaluatesToPass("(apply 'tail-recursive '(900))")
    }

    @Test
    fun recurCallsCurrentFunction_WithFuncall() {
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))")
        assertEvaluatesToPass("(call 'tail-recursive '900)")
    }

    @Test
    fun recurWorksAfterFailure() {
        evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))")
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))")

        assertThrows(RecurNotInTailPositionException::class.java) {
            evaluateString("(bad-tail-recursive 900)")
        }

        // Any recur bookkeeping left over from the failed call must not affect this one.
        assertEvaluatesToPass("(tail-recursive 900)")
    }

    @Test
    fun recurWorksAfterFailure2() {
        evaluateString("(defun bad-tail-recursive (n) (begin (recur) 2))")
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))")

        assertThrows(RecurNotInTailPositionException::class.java) {
            evaluateString("(bad-tail-recursive 900)")
        }

        assertEvaluatesToPass("(tail-recursive 900)")
    }

    @Test
    fun recurWorksAfterNestedFailure() {
        evaluateString("(defun bad-tail-recursive (n) (if (= n 0) 'PASS (recur (recur (- n 1)))))")
        evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))")

        assertThrows(NestedRecurException::class.java) {
            evaluateString("(bad-tail-recursive 900)")
        }

        assertEvaluatesToPass("(tail-recursive 900)")
    }

    @Test
    fun recurWithNoArgs_AltersGlobalVariable() {
        // The loop counter lives in the global `n`, so recur carries no arguments.
        evaluateString("(defun tail-recursive () (if (= n 0) 'PASS (begin (setq n (- n 1)) (recur))))")
        evaluateString("(setq n 200)")

        assertEvaluatesToPass("(tail-recursive)")
    }

    @Test
    fun recurWithLambda() {
        assertEvaluatesToPass("((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020)")
    }

    @Test
    fun recurWithNestedLambda() {
        evaluateString("(defun nested-tail () ((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020))")
        assertEvaluatesToPass("(nested-tail)")
    }

    @Test
    fun recurWithNestedTailRecursiveFunction() {
        // `one` uses recur while `two`'s own recur is being evaluated; each must target its own function.
        evaluateString("(defun one (n) (if (= n 0) 0 (recur (- n 1))))")
        evaluateString("(defun two (n) (if (= n 0) 'PASS (recur (one n))))")

        assertEvaluatesToPass("(two 20)")
    }

    @Test
    fun nestedRecurException_HasCorrectAttributes() {
        assertIsErrorWithMessage(NestedRecurException())
    }

    @Test
    fun recurOutsideOfFunctionException_HasCorrectAttributes() {
        assertIsErrorWithMessage(RecurOutsideOfFunctionException())
    }

    @Test
    fun recurNotInTailPositionException_HasCorrectAttributes() {
        assertIsErrorWithMessage(RecurNotInTailPositionException())
    }

    /**
     * Evaluates [expression] and asserts that the result matches the symbol `PASS`.
     *
     * @param expression Lisp source text to evaluate.
     */
    private fun assertEvaluatesToPass(expression: String) {
        assertSExpressionsMatch(parseString("PASS"), evaluateString(expression))
    }
}