Restore scope properly after errors

This commit is contained in:
Mike Cifelli 2017-11-18 09:24:45 -05:00
parent e2272fa976
commit aeb3074750
11 changed files with 154 additions and 49 deletions

View File

@ -2,14 +2,14 @@ package function;
import static function.builtin.EVAL.eval;
import static java.text.MessageFormat.format;
import static recursion.tail.TailCalls.done;
import static recursion.TailCalls.done;
import static sexpression.Nil.NIL;
import java.util.ArrayList;
import error.LispException;
import recursion.tail.TailCall;
import recursion.tail.TailCalls;
import recursion.TailCall;
import recursion.TailCalls;
import sexpression.Cons;
import sexpression.SExpression;
import sexpression.Symbol;

View File

@ -12,6 +12,7 @@ import error.LispException;
import function.ArgumentValidator;
import function.FunctionNames;
import function.LispFunction;
import function.builtin.special.RECUR.RecurNotInTailPositionException;
import sexpression.BackquoteExpression;
import sexpression.Cons;
import sexpression.LambdaExpression;
@ -22,10 +23,17 @@ import table.ExecutionContext;
@FunctionNames({ "EVAL" })
public class EVAL extends LispFunction {
private static ExecutionContext executionContext = ExecutionContext.getInstance();
public static SExpression eval(SExpression sExpression) {
Cons argumentList = makeList(sExpression);
return lookupFunction("EVAL").call(argumentList);
try {
return lookupFunction("EVAL").call(argumentList);
} catch (LispException e) {
executionContext.restoreGlobalScope();
throw e;
}
}
public static Cons evalRecurArgumentList(Cons argumentList) {
@ -62,27 +70,28 @@ public class EVAL extends LispFunction {
}
private ArgumentValidator argumentValidator;
private ExecutionContext executionContext;
public EVAL(String name) {
this.argumentValidator = new ArgumentValidator(name);
this.argumentValidator.setExactNumberOfArguments(1);
this.executionContext = ExecutionContext.getInstance();
}
@Override
public SExpression call(Cons argumentList) {
if (executionContext.isRecur()) {
executionContext.clearContext();
throw new RecurNotInTailPositionException();
}
verifyNotRecurring();
argumentValidator.validate(argumentList);
SExpression argument = argumentList.getFirst();
return evaluateExpression(argument);
}
private void verifyNotRecurring() {
if (executionContext.isRecur()) {
executionContext.clearRecur();
throw new RecurNotInTailPositionException();
}
}
private SExpression evaluateExpression(SExpression argument) {
if (argument.isList())
return evaluateList(argument);
@ -123,10 +132,7 @@ public class EVAL extends LispFunction {
if (function.isArgumentListEvaluated())
argumentList = evaluateArgumentList(argumentList);
if (executionContext.isRecur()) {
executionContext.clearContext();
throw new RecurNotInTailPositionException();
}
verifyNotRecurring();
SExpression result = function.call(argumentList);
@ -211,14 +217,4 @@ public class EVAL extends LispFunction {
}
}
public static class RecurNotInTailPositionException extends LispException {
private static final long serialVersionUID = 1L;
@Override
public String getMessage() {
return "recur not in tail position";
}
}
}

View File

@ -1,15 +1,15 @@
package function.builtin.cons;
import static function.builtin.cons.LIST.makeList;
import static recursion.tail.TailCalls.done;
import static recursion.tail.TailCalls.tailCall;
import static recursion.TailCalls.done;
import static recursion.TailCalls.tailCall;
import java.math.BigInteger;
import function.ArgumentValidator;
import function.FunctionNames;
import function.LispFunction;
import recursion.tail.TailCall;
import recursion.TailCall;
import sexpression.Cons;
import sexpression.LispNumber;

View File

@ -5,7 +5,6 @@ import static function.builtin.EVAL.evalRecurArgumentList;
import error.LispException;
import function.ArgumentValidator;
import function.FunctionNames;
import function.LispFunction;
import function.LispSpecialFunction;
import sexpression.Cons;
import sexpression.SExpression;
@ -31,23 +30,31 @@ public class RECUR extends LispSpecialFunction {
throw new RecurOutsideOfFunctionException();
argumentValidator.validate(argumentList);
LispFunction currentFunction = executionContext.getCurrentFunction();
Cons recurArguments = getRecurArguments(argumentList);
executionContext.setRecur();
return recurArguments;
}
private Cons getRecurArguments(Cons argumentList) {
Cons recurArguments = argumentList;
executionContext.setRecurInitializing();
try {
if (currentFunction.isArgumentListEvaluated())
executionContext.setRecurInitializing();
if (isRecurArgumentListEvaluated())
recurArguments = evalRecurArgumentList(argumentList);
} finally {
executionContext.clearRecurInitializing();
}
executionContext.setRecur();
return recurArguments;
}
private boolean isRecurArgumentListEvaluated() {
return executionContext.getCurrentFunction().isArgumentListEvaluated();
}
public static class RecurOutsideOfFunctionException extends LispException {
private static final long serialVersionUID = 1L;
@ -68,4 +75,14 @@ public class RECUR extends LispSpecialFunction {
}
}
public static class RecurNotInTailPositionException extends LispException {
private static final long serialVersionUID = 1L;
@Override
public String getMessage() {
return "recur not in tail position";
}
}
}

View File

@ -1,4 +1,4 @@
package recursion.tail;
package recursion;
import java.util.stream.Stream;

View File

@ -1,4 +1,4 @@
package recursion.tail;
package recursion;
public class TailCalls {

View File

@ -1,9 +1,9 @@
package sexpression;
import static recursion.tail.TailCalls.done;
import static recursion.tail.TailCalls.tailCall;
import static recursion.TailCalls.done;
import static recursion.TailCalls.tailCall;
import recursion.tail.TailCall;
import recursion.TailCall;
@DisplayName("list")
public class Cons extends SExpression {

View File

@ -33,6 +33,11 @@ public class ExecutionContext {
this.scope = scope;
}
public void restoreGlobalScope() {
while (!scope.isGlobal())
scope = scope.getParent();
}
public void clearContext() {
this.scope = new SymbolTable();
this.functionCalls = new Stack<>();

View File

@ -2,6 +2,7 @@ package function.builtin;
import static function.builtin.EVAL.lookupSymbol;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static sexpression.Nil.NIL;
import static testutil.TestUtilities.assertIsErrorWithMessage;
import static testutil.TestUtilities.assertSExpressionsMatch;
@ -18,6 +19,7 @@ import function.builtin.EVAL.UndefinedFunctionException;
import function.builtin.EVAL.UndefinedSymbolException;
import function.builtin.EVAL.UnmatchedAtSignException;
import function.builtin.EVAL.UnmatchedCommaException;
import function.builtin.special.RECUR.RecurNotInTailPositionException;
import testutil.SymbolAndFunctionCleaner;
public class EVALTest extends SymbolAndFunctionCleaner {
@ -178,4 +180,50 @@ public class EVALTest extends SymbolAndFunctionCleaner {
assertSExpressionsMatch(parseString("((1 2 3))"), evaluateString(input));
}
@Test
public void scopeRestoredAfterFailure_Let() {
evaluateString("(setq n 100)");
try {
evaluateString("(let ((n 200)) (begin 1 2 3 y))");
fail("expected exception");
} catch (UndefinedSymbolException e) {}
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
}
@Test
public void scopeRestoredAfterFailure_Defun() {
evaluateString("(setq n 100)");
try {
evaluateString("(defun test (n) (begin 1 2 3 y))");
evaluateString("(test 200)");
fail("expected exception");
} catch (UndefinedSymbolException e) {}
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
}
@Test
public void scopeRestoredAfterFailure_Lambda() {
evaluateString("(setq n 100)");
try {
evaluateString("((lambda (n) (begin 1 2 3 y)) 200)");
fail("expected exception");
} catch (UndefinedSymbolException e) {}
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
}
@Test
public void scopeRestoredAfterFailure_Recur() {
evaluateString("(setq n 100)");
try {
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
evaluateString("(tail-recursive 200)");
fail("expected exception");
} catch (RecurNotInTailPositionException e) {}
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
}
}

View File

@ -9,9 +9,10 @@ import org.junit.Test;
import function.ArgumentValidator.BadArgumentTypeException;
import function.ArgumentValidator.TooManyArgumentsException;
import function.builtin.EVAL.RecurNotInTailPositionException;
import function.builtin.special.RECUR.NestedRecurException;
import function.builtin.special.RECUR.RecurNotInTailPositionException;
import function.builtin.special.RECUR.RecurOutsideOfFunctionException;
import sexpression.SExpression;
import testutil.SymbolAndFunctionCleaner;
public class RECURTest extends SymbolAndFunctionCleaner {
@ -53,13 +54,13 @@ public class RECURTest extends SymbolAndFunctionCleaner {
}
@Test(expected = RecurNotInTailPositionException.class)
public void recurInNonTailPosition_ThrowsException() {
public void recurInNonTailPositionInArgumentList() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))");
evaluateString("(tail-recursive 900)");
}
@Test(expected = RecurNotInTailPositionException.class)
public void recurInNonTailPosition2_ThrowsException() {
public void recurInNonTailPositionInBegin() {
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
evaluateString("(tail-recursive 900)");
}
@ -76,6 +77,18 @@ public class RECURTest extends SymbolAndFunctionCleaner {
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test
public void recurInTailPositionWithApply() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (apply 'recur (list (- n 1)))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test
public void recurInTailPositionWithFuncall() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (call 'recur (- n 1))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test
public void recurCallsCurrentFunction_WithApply() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
@ -126,14 +139,24 @@ public class RECURTest extends SymbolAndFunctionCleaner {
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
// recur non-tail in apply call
// recur with no args, alters global variable
@Test
public void recurWithNoArgs_AltersGlobalVariable() {
evaluateString("(defun tail-recursive () (if (= n 0) 'PASS (begin (setq n (- n 1)) (recur))))");
evaluateString("(setq n 200)");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive)"));
}
// recur with anonymous function
@Test
public void recurWithLambda() {
SExpression lambdaTailCall = evaluateString("((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020)");
assertSExpressionsMatch(parseString("PASS"), lambdaTailCall);
}
// recur with nested anonymous function
// test scope after failure in function, is it global again?
@Test
public void recurWithNestedLambda() {
evaluateString("(defun nested-tail () ((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(nested-tail)"));
}
}

View File

@ -1,8 +1,10 @@
package table;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static sexpression.Nil.NIL;
import static sexpression.Symbol.T;
@ -85,4 +87,18 @@ public class ExecutionContextTest {
assertEquals(NIL, executionContext.lookupSymbolValue("shadowed"));
}
@Test
public void restoreGlobalContext() {
SymbolTable global = executionContext.getScope();
SymbolTable scope1 = new SymbolTable(global);
SymbolTable scope2 = new SymbolTable(scope1);
SymbolTable scope3 = new SymbolTable(scope2);
executionContext.setScope(scope3);
assertThat(executionContext.getScope().isGlobal(), is(false));
executionContext.restoreGlobalScope();
assertThat(executionContext.getScope().isGlobal(), is(true));
assertThat(executionContext.getScope(), is(global));
}
}