Restore scope properly after errors
This commit is contained in:
parent
e2272fa976
commit
aeb3074750
|
@ -2,14 +2,14 @@ package function;
|
||||||
|
|
||||||
import static function.builtin.EVAL.eval;
|
import static function.builtin.EVAL.eval;
|
||||||
import static java.text.MessageFormat.format;
|
import static java.text.MessageFormat.format;
|
||||||
import static recursion.tail.TailCalls.done;
|
import static recursion.TailCalls.done;
|
||||||
import static sexpression.Nil.NIL;
|
import static sexpression.Nil.NIL;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
|
||||||
import error.LispException;
|
import error.LispException;
|
||||||
import recursion.tail.TailCall;
|
import recursion.TailCall;
|
||||||
import recursion.tail.TailCalls;
|
import recursion.TailCalls;
|
||||||
import sexpression.Cons;
|
import sexpression.Cons;
|
||||||
import sexpression.SExpression;
|
import sexpression.SExpression;
|
||||||
import sexpression.Symbol;
|
import sexpression.Symbol;
|
||||||
|
|
|
@ -12,6 +12,7 @@ import error.LispException;
|
||||||
import function.ArgumentValidator;
|
import function.ArgumentValidator;
|
||||||
import function.FunctionNames;
|
import function.FunctionNames;
|
||||||
import function.LispFunction;
|
import function.LispFunction;
|
||||||
|
import function.builtin.special.RECUR.RecurNotInTailPositionException;
|
||||||
import sexpression.BackquoteExpression;
|
import sexpression.BackquoteExpression;
|
||||||
import sexpression.Cons;
|
import sexpression.Cons;
|
||||||
import sexpression.LambdaExpression;
|
import sexpression.LambdaExpression;
|
||||||
|
@ -22,10 +23,17 @@ import table.ExecutionContext;
|
||||||
@FunctionNames({ "EVAL" })
|
@FunctionNames({ "EVAL" })
|
||||||
public class EVAL extends LispFunction {
|
public class EVAL extends LispFunction {
|
||||||
|
|
||||||
|
private static ExecutionContext executionContext = ExecutionContext.getInstance();
|
||||||
|
|
||||||
public static SExpression eval(SExpression sExpression) {
|
public static SExpression eval(SExpression sExpression) {
|
||||||
Cons argumentList = makeList(sExpression);
|
Cons argumentList = makeList(sExpression);
|
||||||
|
|
||||||
|
try {
|
||||||
return lookupFunction("EVAL").call(argumentList);
|
return lookupFunction("EVAL").call(argumentList);
|
||||||
|
} catch (LispException e) {
|
||||||
|
executionContext.restoreGlobalScope();
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Cons evalRecurArgumentList(Cons argumentList) {
|
public static Cons evalRecurArgumentList(Cons argumentList) {
|
||||||
|
@ -62,27 +70,28 @@ public class EVAL extends LispFunction {
|
||||||
}
|
}
|
||||||
|
|
||||||
private ArgumentValidator argumentValidator;
|
private ArgumentValidator argumentValidator;
|
||||||
private ExecutionContext executionContext;
|
|
||||||
|
|
||||||
public EVAL(String name) {
|
public EVAL(String name) {
|
||||||
this.argumentValidator = new ArgumentValidator(name);
|
this.argumentValidator = new ArgumentValidator(name);
|
||||||
this.argumentValidator.setExactNumberOfArguments(1);
|
this.argumentValidator.setExactNumberOfArguments(1);
|
||||||
this.executionContext = ExecutionContext.getInstance();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SExpression call(Cons argumentList) {
|
public SExpression call(Cons argumentList) {
|
||||||
if (executionContext.isRecur()) {
|
verifyNotRecurring();
|
||||||
executionContext.clearContext();
|
|
||||||
throw new RecurNotInTailPositionException();
|
|
||||||
}
|
|
||||||
|
|
||||||
argumentValidator.validate(argumentList);
|
argumentValidator.validate(argumentList);
|
||||||
SExpression argument = argumentList.getFirst();
|
SExpression argument = argumentList.getFirst();
|
||||||
|
|
||||||
return evaluateExpression(argument);
|
return evaluateExpression(argument);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void verifyNotRecurring() {
|
||||||
|
if (executionContext.isRecur()) {
|
||||||
|
executionContext.clearRecur();
|
||||||
|
throw new RecurNotInTailPositionException();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private SExpression evaluateExpression(SExpression argument) {
|
private SExpression evaluateExpression(SExpression argument) {
|
||||||
if (argument.isList())
|
if (argument.isList())
|
||||||
return evaluateList(argument);
|
return evaluateList(argument);
|
||||||
|
@ -123,10 +132,7 @@ public class EVAL extends LispFunction {
|
||||||
if (function.isArgumentListEvaluated())
|
if (function.isArgumentListEvaluated())
|
||||||
argumentList = evaluateArgumentList(argumentList);
|
argumentList = evaluateArgumentList(argumentList);
|
||||||
|
|
||||||
if (executionContext.isRecur()) {
|
verifyNotRecurring();
|
||||||
executionContext.clearContext();
|
|
||||||
throw new RecurNotInTailPositionException();
|
|
||||||
}
|
|
||||||
|
|
||||||
SExpression result = function.call(argumentList);
|
SExpression result = function.call(argumentList);
|
||||||
|
|
||||||
|
@ -211,14 +217,4 @@ public class EVAL extends LispFunction {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class RecurNotInTailPositionException extends LispException {
|
|
||||||
|
|
||||||
private static final long serialVersionUID = 1L;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getMessage() {
|
|
||||||
return "recur not in tail position";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
package function.builtin.cons;
|
package function.builtin.cons;
|
||||||
|
|
||||||
import static function.builtin.cons.LIST.makeList;
|
import static function.builtin.cons.LIST.makeList;
|
||||||
import static recursion.tail.TailCalls.done;
|
import static recursion.TailCalls.done;
|
||||||
import static recursion.tail.TailCalls.tailCall;
|
import static recursion.TailCalls.tailCall;
|
||||||
|
|
||||||
import java.math.BigInteger;
|
import java.math.BigInteger;
|
||||||
|
|
||||||
import function.ArgumentValidator;
|
import function.ArgumentValidator;
|
||||||
import function.FunctionNames;
|
import function.FunctionNames;
|
||||||
import function.LispFunction;
|
import function.LispFunction;
|
||||||
import recursion.tail.TailCall;
|
import recursion.TailCall;
|
||||||
import sexpression.Cons;
|
import sexpression.Cons;
|
||||||
import sexpression.LispNumber;
|
import sexpression.LispNumber;
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ import static function.builtin.EVAL.evalRecurArgumentList;
|
||||||
import error.LispException;
|
import error.LispException;
|
||||||
import function.ArgumentValidator;
|
import function.ArgumentValidator;
|
||||||
import function.FunctionNames;
|
import function.FunctionNames;
|
||||||
import function.LispFunction;
|
|
||||||
import function.LispSpecialFunction;
|
import function.LispSpecialFunction;
|
||||||
import sexpression.Cons;
|
import sexpression.Cons;
|
||||||
import sexpression.SExpression;
|
import sexpression.SExpression;
|
||||||
|
@ -31,23 +30,31 @@ public class RECUR extends LispSpecialFunction {
|
||||||
throw new RecurOutsideOfFunctionException();
|
throw new RecurOutsideOfFunctionException();
|
||||||
|
|
||||||
argumentValidator.validate(argumentList);
|
argumentValidator.validate(argumentList);
|
||||||
LispFunction currentFunction = executionContext.getCurrentFunction();
|
Cons recurArguments = getRecurArguments(argumentList);
|
||||||
|
executionContext.setRecur();
|
||||||
|
|
||||||
|
return recurArguments;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Cons getRecurArguments(Cons argumentList) {
|
||||||
Cons recurArguments = argumentList;
|
Cons recurArguments = argumentList;
|
||||||
|
|
||||||
|
try {
|
||||||
executionContext.setRecurInitializing();
|
executionContext.setRecurInitializing();
|
||||||
|
|
||||||
try {
|
if (isRecurArgumentListEvaluated())
|
||||||
if (currentFunction.isArgumentListEvaluated())
|
|
||||||
recurArguments = evalRecurArgumentList(argumentList);
|
recurArguments = evalRecurArgumentList(argumentList);
|
||||||
} finally {
|
} finally {
|
||||||
executionContext.clearRecurInitializing();
|
executionContext.clearRecurInitializing();
|
||||||
}
|
}
|
||||||
|
|
||||||
executionContext.setRecur();
|
|
||||||
|
|
||||||
return recurArguments;
|
return recurArguments;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean isRecurArgumentListEvaluated() {
|
||||||
|
return executionContext.getCurrentFunction().isArgumentListEvaluated();
|
||||||
|
}
|
||||||
|
|
||||||
public static class RecurOutsideOfFunctionException extends LispException {
|
public static class RecurOutsideOfFunctionException extends LispException {
|
||||||
|
|
||||||
private static final long serialVersionUID = 1L;
|
private static final long serialVersionUID = 1L;
|
||||||
|
@ -68,4 +75,14 @@ public class RECUR extends LispSpecialFunction {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static class RecurNotInTailPositionException extends LispException {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMessage() {
|
||||||
|
return "recur not in tail position";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package recursion.tail;
|
package recursion;
|
||||||
|
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package recursion.tail;
|
package recursion;
|
||||||
|
|
||||||
public class TailCalls {
|
public class TailCalls {
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
package sexpression;
|
package sexpression;
|
||||||
|
|
||||||
import static recursion.tail.TailCalls.done;
|
import static recursion.TailCalls.done;
|
||||||
import static recursion.tail.TailCalls.tailCall;
|
import static recursion.TailCalls.tailCall;
|
||||||
|
|
||||||
import recursion.tail.TailCall;
|
import recursion.TailCall;
|
||||||
|
|
||||||
@DisplayName("list")
|
@DisplayName("list")
|
||||||
public class Cons extends SExpression {
|
public class Cons extends SExpression {
|
||||||
|
|
|
@ -33,6 +33,11 @@ public class ExecutionContext {
|
||||||
this.scope = scope;
|
this.scope = scope;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void restoreGlobalScope() {
|
||||||
|
while (!scope.isGlobal())
|
||||||
|
scope = scope.getParent();
|
||||||
|
}
|
||||||
|
|
||||||
public void clearContext() {
|
public void clearContext() {
|
||||||
this.scope = new SymbolTable();
|
this.scope = new SymbolTable();
|
||||||
this.functionCalls = new Stack<>();
|
this.functionCalls = new Stack<>();
|
||||||
|
|
|
@ -2,6 +2,7 @@ package function.builtin;
|
||||||
|
|
||||||
import static function.builtin.EVAL.lookupSymbol;
|
import static function.builtin.EVAL.lookupSymbol;
|
||||||
import static org.junit.Assert.assertNull;
|
import static org.junit.Assert.assertNull;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
import static sexpression.Nil.NIL;
|
import static sexpression.Nil.NIL;
|
||||||
import static testutil.TestUtilities.assertIsErrorWithMessage;
|
import static testutil.TestUtilities.assertIsErrorWithMessage;
|
||||||
import static testutil.TestUtilities.assertSExpressionsMatch;
|
import static testutil.TestUtilities.assertSExpressionsMatch;
|
||||||
|
@ -18,6 +19,7 @@ import function.builtin.EVAL.UndefinedFunctionException;
|
||||||
import function.builtin.EVAL.UndefinedSymbolException;
|
import function.builtin.EVAL.UndefinedSymbolException;
|
||||||
import function.builtin.EVAL.UnmatchedAtSignException;
|
import function.builtin.EVAL.UnmatchedAtSignException;
|
||||||
import function.builtin.EVAL.UnmatchedCommaException;
|
import function.builtin.EVAL.UnmatchedCommaException;
|
||||||
|
import function.builtin.special.RECUR.RecurNotInTailPositionException;
|
||||||
import testutil.SymbolAndFunctionCleaner;
|
import testutil.SymbolAndFunctionCleaner;
|
||||||
|
|
||||||
public class EVALTest extends SymbolAndFunctionCleaner {
|
public class EVALTest extends SymbolAndFunctionCleaner {
|
||||||
|
@ -178,4 +180,50 @@ public class EVALTest extends SymbolAndFunctionCleaner {
|
||||||
assertSExpressionsMatch(parseString("((1 2 3))"), evaluateString(input));
|
assertSExpressionsMatch(parseString("((1 2 3))"), evaluateString(input));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void scopeRestoredAfterFailure_Let() {
|
||||||
|
evaluateString("(setq n 100)");
|
||||||
|
try {
|
||||||
|
evaluateString("(let ((n 200)) (begin 1 2 3 y))");
|
||||||
|
fail("expected exception");
|
||||||
|
} catch (UndefinedSymbolException e) {}
|
||||||
|
|
||||||
|
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void scopeRestoredAfterFailure_Defun() {
|
||||||
|
evaluateString("(setq n 100)");
|
||||||
|
try {
|
||||||
|
evaluateString("(defun test (n) (begin 1 2 3 y))");
|
||||||
|
evaluateString("(test 200)");
|
||||||
|
fail("expected exception");
|
||||||
|
} catch (UndefinedSymbolException e) {}
|
||||||
|
|
||||||
|
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void scopeRestoredAfterFailure_Lambda() {
|
||||||
|
evaluateString("(setq n 100)");
|
||||||
|
try {
|
||||||
|
evaluateString("((lambda (n) (begin 1 2 3 y)) 200)");
|
||||||
|
fail("expected exception");
|
||||||
|
} catch (UndefinedSymbolException e) {}
|
||||||
|
|
||||||
|
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void scopeRestoredAfterFailure_Recur() {
|
||||||
|
evaluateString("(setq n 100)");
|
||||||
|
try {
|
||||||
|
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
|
||||||
|
evaluateString("(tail-recursive 200)");
|
||||||
|
fail("expected exception");
|
||||||
|
} catch (RecurNotInTailPositionException e) {}
|
||||||
|
|
||||||
|
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,9 +9,10 @@ import org.junit.Test;
|
||||||
|
|
||||||
import function.ArgumentValidator.BadArgumentTypeException;
|
import function.ArgumentValidator.BadArgumentTypeException;
|
||||||
import function.ArgumentValidator.TooManyArgumentsException;
|
import function.ArgumentValidator.TooManyArgumentsException;
|
||||||
import function.builtin.EVAL.RecurNotInTailPositionException;
|
|
||||||
import function.builtin.special.RECUR.NestedRecurException;
|
import function.builtin.special.RECUR.NestedRecurException;
|
||||||
|
import function.builtin.special.RECUR.RecurNotInTailPositionException;
|
||||||
import function.builtin.special.RECUR.RecurOutsideOfFunctionException;
|
import function.builtin.special.RECUR.RecurOutsideOfFunctionException;
|
||||||
|
import sexpression.SExpression;
|
||||||
import testutil.SymbolAndFunctionCleaner;
|
import testutil.SymbolAndFunctionCleaner;
|
||||||
|
|
||||||
public class RECURTest extends SymbolAndFunctionCleaner {
|
public class RECURTest extends SymbolAndFunctionCleaner {
|
||||||
|
@ -53,13 +54,13 @@ public class RECURTest extends SymbolAndFunctionCleaner {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = RecurNotInTailPositionException.class)
|
@Test(expected = RecurNotInTailPositionException.class)
|
||||||
public void recurInNonTailPosition_ThrowsException() {
|
public void recurInNonTailPositionInArgumentList() {
|
||||||
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))");
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))");
|
||||||
evaluateString("(tail-recursive 900)");
|
evaluateString("(tail-recursive 900)");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = RecurNotInTailPositionException.class)
|
@Test(expected = RecurNotInTailPositionException.class)
|
||||||
public void recurInNonTailPosition2_ThrowsException() {
|
public void recurInNonTailPositionInBegin() {
|
||||||
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
|
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
|
||||||
evaluateString("(tail-recursive 900)");
|
evaluateString("(tail-recursive 900)");
|
||||||
}
|
}
|
||||||
|
@ -76,6 +77,18 @@ public class RECURTest extends SymbolAndFunctionCleaner {
|
||||||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recurInTailPositionWithApply() {
|
||||||
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (apply 'recur (list (- n 1)))))");
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recurInTailPositionWithFuncall() {
|
||||||
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (call 'recur (- n 1))))");
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void recurCallsCurrentFunction_WithApply() {
|
public void recurCallsCurrentFunction_WithApply() {
|
||||||
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
|
||||||
|
@ -126,14 +139,24 @@ public class RECURTest extends SymbolAndFunctionCleaner {
|
||||||
|
|
||||||
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
|
||||||
}
|
}
|
||||||
// recur non-tail in apply call
|
|
||||||
|
|
||||||
// recur with no args, alters global variable
|
@Test
|
||||||
|
public void recurWithNoArgs_AltersGlobalVariable() {
|
||||||
|
evaluateString("(defun tail-recursive () (if (= n 0) 'PASS (begin (setq n (- n 1)) (recur))))");
|
||||||
|
evaluateString("(setq n 200)");
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive)"));
|
||||||
|
}
|
||||||
|
|
||||||
// recur with anonymous function
|
@Test
|
||||||
|
public void recurWithLambda() {
|
||||||
|
SExpression lambdaTailCall = evaluateString("((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020)");
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), lambdaTailCall);
|
||||||
|
}
|
||||||
|
|
||||||
// recur with nested anonymous function
|
@Test
|
||||||
|
public void recurWithNestedLambda() {
|
||||||
// test scope after failure in function, is it global again?
|
evaluateString("(defun nested-tail () ((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020))");
|
||||||
|
assertSExpressionsMatch(parseString("PASS"), evaluateString("(nested-tail)"));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
package table;
|
package table;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNotEquals;
|
import static org.junit.Assert.assertNotEquals;
|
||||||
import static org.junit.Assert.assertNull;
|
import static org.junit.Assert.assertNull;
|
||||||
|
import static org.junit.Assert.assertThat;
|
||||||
import static sexpression.Nil.NIL;
|
import static sexpression.Nil.NIL;
|
||||||
import static sexpression.Symbol.T;
|
import static sexpression.Symbol.T;
|
||||||
|
|
||||||
|
@ -85,4 +87,18 @@ public class ExecutionContextTest {
|
||||||
assertEquals(NIL, executionContext.lookupSymbolValue("shadowed"));
|
assertEquals(NIL, executionContext.lookupSymbolValue("shadowed"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void restoreGlobalContext() {
|
||||||
|
SymbolTable global = executionContext.getScope();
|
||||||
|
SymbolTable scope1 = new SymbolTable(global);
|
||||||
|
SymbolTable scope2 = new SymbolTable(scope1);
|
||||||
|
SymbolTable scope3 = new SymbolTable(scope2);
|
||||||
|
executionContext.setScope(scope3);
|
||||||
|
|
||||||
|
assertThat(executionContext.getScope().isGlobal(), is(false));
|
||||||
|
executionContext.restoreGlobalScope();
|
||||||
|
assertThat(executionContext.getScope().isGlobal(), is(true));
|
||||||
|
assertThat(executionContext.getScope(), is(global));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue