Restore scope properly after errors
This commit is contained in:
parent
e2272fa976
commit
aeb3074750
|
@ -2,14 +2,14 @@ package function;
|
|||
|
||||
import static function.builtin.EVAL.eval;
|
||||
import static java.text.MessageFormat.format;
|
||||
import static recursion.tail.TailCalls.done;
|
||||
import static recursion.TailCalls.done;
|
||||
import static sexpression.Nil.NIL;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
import error.LispException;
|
||||
import recursion.tail.TailCall;
|
||||
import recursion.tail.TailCalls;
|
||||
import recursion.TailCall;
|
||||
import recursion.TailCalls;
|
||||
import sexpression.Cons;
|
||||
import sexpression.SExpression;
|
||||
import sexpression.Symbol;
|
||||
|
|
|
@ -12,6 +12,7 @@ import error.LispException;
|
|||
import function.ArgumentValidator;
|
||||
import function.FunctionNames;
|
||||
import function.LispFunction;
|
||||
import function.builtin.special.RECUR.RecurNotInTailPositionException;
|
||||
import sexpression.BackquoteExpression;
|
||||
import sexpression.Cons;
|
||||
import sexpression.LambdaExpression;
|
||||
|
@ -22,10 +23,17 @@ import table.ExecutionContext;
|
|||
@FunctionNames({ "EVAL" })
|
||||
public class EVAL extends LispFunction {
|
||||
|
||||
private static ExecutionContext executionContext = ExecutionContext.getInstance();
|
||||
|
||||
public static SExpression eval(SExpression sExpression) {
|
||||
Cons argumentList = makeList(sExpression);
|
||||
|
||||
return lookupFunction("EVAL").call(argumentList);
|
||||
try {
|
||||
return lookupFunction("EVAL").call(argumentList);
|
||||
} catch (LispException e) {
|
||||
executionContext.restoreGlobalScope();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
public static Cons evalRecurArgumentList(Cons argumentList) {
|
||||
|
@ -62,27 +70,28 @@ public class EVAL extends LispFunction {
|
|||
}
|
||||
|
||||
private ArgumentValidator argumentValidator;
|
||||
private ExecutionContext executionContext;
|
||||
|
||||
public EVAL(String name) {
|
||||
this.argumentValidator = new ArgumentValidator(name);
|
||||
this.argumentValidator.setExactNumberOfArguments(1);
|
||||
this.executionContext = ExecutionContext.getInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SExpression call(Cons argumentList) {
|
||||
if (executionContext.isRecur()) {
|
||||
executionContext.clearContext();
|
||||
throw new RecurNotInTailPositionException();
|
||||
}
|
||||
|
||||
verifyNotRecurring();
|
||||
argumentValidator.validate(argumentList);
|
||||
SExpression argument = argumentList.getFirst();
|
||||
|
||||
return evaluateExpression(argument);
|
||||
}
|
||||
|
||||
private void verifyNotRecurring() {
|
||||
if (executionContext.isRecur()) {
|
||||
executionContext.clearRecur();
|
||||
throw new RecurNotInTailPositionException();
|
||||
}
|
||||
}
|
||||
|
||||
private SExpression evaluateExpression(SExpression argument) {
|
||||
if (argument.isList())
|
||||
return evaluateList(argument);
|
||||
|
@ -123,10 +132,7 @@ public class EVAL extends LispFunction {
|
|||
if (function.isArgumentListEvaluated())
|
||||
argumentList = evaluateArgumentList(argumentList);
|
||||
|
||||
if (executionContext.isRecur()) {
|
||||
executionContext.clearContext();
|
||||
throw new RecurNotInTailPositionException();
|
||||
}
|
||||
verifyNotRecurring();
|
||||
|
||||
SExpression result = function.call(argumentList);
|
||||
|
||||
|
@ -211,14 +217,4 @@ public class EVAL extends LispFunction {
|
|||
}
|
||||
}
|
||||
|
||||
public static class RecurNotInTailPositionException extends LispException {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@Override
|
||||
public String getMessage() {
|
||||
return "recur not in tail position";
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
package function.builtin.cons;
|
||||
|
||||
import static function.builtin.cons.LIST.makeList;
|
||||
import static recursion.tail.TailCalls.done;
|
||||
import static recursion.tail.TailCalls.tailCall;
|
||||
import static recursion.TailCalls.done;
|
||||
import static recursion.TailCalls.tailCall;
|
||||
|
||||
import java.math.BigInteger;
|
||||
|
||||
import function.ArgumentValidator;
|
||||
import function.FunctionNames;
|
||||
import function.LispFunction;
|
||||
import recursion.tail.TailCall;
|
||||
import recursion.TailCall;
|
||||
import sexpression.Cons;
|
||||
import sexpression.LispNumber;
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ import static function.builtin.EVAL.evalRecurArgumentList;
|
|||
import error.LispException;
|
||||
import function.ArgumentValidator;
|
||||
import function.FunctionNames;
|
||||
import function.LispFunction;
|
||||
import function.LispSpecialFunction;
|
||||
import sexpression.Cons;
|
||||
import sexpression.SExpression;
|
||||
|
@ -31,23 +30,31 @@ public class RECUR extends LispSpecialFunction {
|
|||
throw new RecurOutsideOfFunctionException();
|
||||
|
||||
argumentValidator.validate(argumentList);
|
||||
LispFunction currentFunction = executionContext.getCurrentFunction();
|
||||
Cons recurArguments = getRecurArguments(argumentList);
|
||||
executionContext.setRecur();
|
||||
|
||||
return recurArguments;
|
||||
}
|
||||
|
||||
private Cons getRecurArguments(Cons argumentList) {
|
||||
Cons recurArguments = argumentList;
|
||||
|
||||
executionContext.setRecurInitializing();
|
||||
|
||||
try {
|
||||
if (currentFunction.isArgumentListEvaluated())
|
||||
executionContext.setRecurInitializing();
|
||||
|
||||
if (isRecurArgumentListEvaluated())
|
||||
recurArguments = evalRecurArgumentList(argumentList);
|
||||
} finally {
|
||||
executionContext.clearRecurInitializing();
|
||||
}
|
||||
|
||||
executionContext.setRecur();
|
||||
|
||||
return recurArguments;
|
||||
}
|
||||
|
||||
private boolean isRecurArgumentListEvaluated() {
|
||||
return executionContext.getCurrentFunction().isArgumentListEvaluated();
|
||||
}
|
||||
|
||||
public static class RecurOutsideOfFunctionException extends LispException {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
@ -68,4 +75,14 @@ public class RECUR extends LispSpecialFunction {
|
|||
}
|
||||
}
|
||||
|
||||
public static class RecurNotInTailPositionException extends LispException {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@Override
|
||||
public String getMessage() {
|
||||
return "recur not in tail position";
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package recursion.tail;
|
||||
package recursion;
|
||||
|
||||
import java.util.stream.Stream;
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package recursion.tail;
|
||||
package recursion;
|
||||
|
||||
public class TailCalls {
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
package sexpression;
|
||||
|
||||
import static recursion.tail.TailCalls.done;
|
||||
import static recursion.tail.TailCalls.tailCall;
|
||||
import static recursion.TailCalls.done;
|
||||
import static recursion.TailCalls.tailCall;
|
||||
|
||||
import recursion.tail.TailCall;
|
||||
import recursion.TailCall;
|
||||
|
||||
@DisplayName("list")
|
||||
public class Cons extends SExpression {
|
||||
|
|
|
@ -33,6 +33,11 @@ public class ExecutionContext {
|
|||
this.scope = scope;
|
||||
}
|
||||
|
||||
public void restoreGlobalScope() {
|
||||
while (!scope.isGlobal())
|
||||
scope = scope.getParent();
|
||||
}
|
||||
|
||||
public void clearContext() {
|
||||
this.scope = new SymbolTable();
|
||||
this.functionCalls = new Stack<>();
|
||||
|
|
|
@ -2,6 +2,7 @@ package function.builtin;
|
|||
|
||||
import static function.builtin.EVAL.lookupSymbol;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.fail;
|
||||
import static sexpression.Nil.NIL;
|
||||
import static testutil.TestUtilities.assertIsErrorWithMessage;
|
||||
import static testutil.TestUtilities.assertSExpressionsMatch;
|
||||
|
@ -18,6 +19,7 @@ import function.builtin.EVAL.UndefinedFunctionException;
|
|||
import function.builtin.EVAL.UndefinedSymbolException;
|
||||
import function.builtin.EVAL.UnmatchedAtSignException;
|
||||
import function.builtin.EVAL.UnmatchedCommaException;
|
||||
import function.builtin.special.RECUR.RecurNotInTailPositionException;
|
||||
import testutil.SymbolAndFunctionCleaner;
|
||||
|
||||
public class EVALTest extends SymbolAndFunctionCleaner {
|
||||
|
@ -178,4 +180,50 @@ public class EVALTest extends SymbolAndFunctionCleaner {
|
|||
assertSExpressionsMatch(parseString("((1 2 3))"), evaluateString(input));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void scopeRestoredAfterFailure_Let() {
|
||||
evaluateString("(setq n 100)");
|
||||
try {
|
||||
evaluateString("(let ((n 200)) (begin 1 2 3 y))");
|
||||
fail("expected exception");
|
||||
} catch (UndefinedSymbolException e) {}
|
||||
|
||||
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void scopeRestoredAfterFailure_Defun() {
|
||||
evaluateString("(setq n 100)");
|
||||
try {
|
||||
evaluateString("(defun test (n) (begin 1 2 3 y))");
|
||||
evaluateString("(test 200)");
|
||||
fail("expected exception");
|
||||
} catch (UndefinedSymbolException e) {}
|
||||
|
||||
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void scopeRestoredAfterFailure_Lambda() {
|
||||
evaluateString("(setq n 100)");
|
||||
try {
|
||||
evaluateString("((lambda (n) (begin 1 2 3 y)) 200)");
|
||||
fail("expected exception");
|
||||
} catch (UndefinedSymbolException e) {}
|
||||
|
||||
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void scopeRestoredAfterFailure_Recur() {
|
||||
evaluateString("(setq n 100)");
|
||||
try {
|
||||
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
|
||||
evaluateString("(tail-recursive 200)");
|
||||
fail("expected exception");
|
||||
} catch (RecurNotInTailPositionException e) {}
|
||||
|
||||
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -9,9 +9,10 @@ import org.junit.Test;
|
|||
|
||||
import function.ArgumentValidator.BadArgumentTypeException;
|
||||
import function.ArgumentValidator.TooManyArgumentsException;
|
||||
import function.builtin.EVAL.RecurNotInTailPositionException;
|
||||
import function.builtin.special.RECUR.NestedRecurException;
|
||||
import function.builtin.special.RECUR.RecurNotInTailPositionException;
|
||||
import function.builtin.special.RECUR.RecurOutsideOfFunctionException;
|
||||
import sexpression.SExpression;
|
||||
import testutil.SymbolAndFunctionCleaner;
|
||||
|
||||
public class RECURTest extends SymbolAndFunctionCleaner {
|
||||
|
@ -53,13 +54,13 @@ public class RECURTest extends SymbolAndFunctionCleaner {
|
|||
}
|
||||
|
||||
@Test(expected = RecurNotInTailPositionException.class)
|
||||
public void recurInNonTailPosition_ThrowsException() {
|
||||
public void recurInNonTailPositionInArgumentList() {
|
||||
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))");
|
||||
evaluateString("(tail-recursive 900)");
|
||||
}
|
||||
|
||||
@Test(expected = RecurNotInTailPositionException.class)
|
||||
public void recurInNonTailPosition2_ThrowsException() {
|
||||
public void recurInNonTailPositionInBegin() {
|
||||
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
|
||||
evaluateString("(tail-recursive 900)");
|
||||
}
|
||||
|
@ -76,6 +77,18 @@ public class RECURTest extends SymbolAndFunctionCleaner {
|
|||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recurInTailPositionWithApply() {
|
||||
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (apply 'recur (list (- n 1)))))");
|
||||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recurInTailPositionWithFuncall() {
|
||||
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (call 'recur (- n 1))))");
|
||||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recurCallsCurrentFunction_WithApply() {
|
||||
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
||||
|
@ -126,14 +139,24 @@ public class RECURTest extends SymbolAndFunctionCleaner {
|
|||
|
||||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||
}
|
||||
// recur non-tail in apply call
|
||||
|
||||
// recur with no args, alters global variable
|
||||
@Test
|
||||
public void recurWithNoArgs_AltersGlobalVariable() {
|
||||
evaluateString("(defun tail-recursive () (if (= n 0) 'PASS (begin (setq n (- n 1)) (recur))))");
|
||||
evaluateString("(setq n 200)");
|
||||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive)"));
|
||||
}
|
||||
|
||||
// recur with anonymous function
|
||||
@Test
|
||||
public void recurWithLambda() {
|
||||
SExpression lambdaTailCall = evaluateString("((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020)");
|
||||
assertSExpressionsMatch(parseString("PASS"), lambdaTailCall);
|
||||
}
|
||||
|
||||
// recur with nested anonymous function
|
||||
|
||||
// test scope after failure in function, is it global again?
|
||||
@Test
|
||||
public void recurWithNestedLambda() {
|
||||
evaluateString("(defun nested-tail () ((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020))");
|
||||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(nested-tail)"));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
package table;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotEquals;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertThat;
|
||||
import static sexpression.Nil.NIL;
|
||||
import static sexpression.Symbol.T;
|
||||
|
||||
|
@ -85,4 +87,18 @@ public class ExecutionContextTest {
|
|||
assertEquals(NIL, executionContext.lookupSymbolValue("shadowed"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void restoreGlobalContext() {
|
||||
SymbolTable global = executionContext.getScope();
|
||||
SymbolTable scope1 = new SymbolTable(global);
|
||||
SymbolTable scope2 = new SymbolTable(scope1);
|
||||
SymbolTable scope3 = new SymbolTable(scope2);
|
||||
executionContext.setScope(scope3);
|
||||
|
||||
assertThat(executionContext.getScope().isGlobal(), is(false));
|
||||
executionContext.restoreGlobalScope();
|
||||
assertThat(executionContext.getScope().isGlobal(), is(true));
|
||||
assertThat(executionContext.getScope(), is(global));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue