Restore scope properly after errors

This commit is contained in:
Mike Cifelli 2017-11-18 09:24:45 -05:00
parent e2272fa976
commit aeb3074750
11 changed files with 154 additions and 49 deletions

View File

@ -2,14 +2,14 @@ package function;
import static function.builtin.EVAL.eval; import static function.builtin.EVAL.eval;
import static java.text.MessageFormat.format; import static java.text.MessageFormat.format;
import static recursion.tail.TailCalls.done; import static recursion.TailCalls.done;
import static sexpression.Nil.NIL; import static sexpression.Nil.NIL;
import java.util.ArrayList; import java.util.ArrayList;
import error.LispException; import error.LispException;
import recursion.tail.TailCall; import recursion.TailCall;
import recursion.tail.TailCalls; import recursion.TailCalls;
import sexpression.Cons; import sexpression.Cons;
import sexpression.SExpression; import sexpression.SExpression;
import sexpression.Symbol; import sexpression.Symbol;

View File

@ -12,6 +12,7 @@ import error.LispException;
import function.ArgumentValidator; import function.ArgumentValidator;
import function.FunctionNames; import function.FunctionNames;
import function.LispFunction; import function.LispFunction;
import function.builtin.special.RECUR.RecurNotInTailPositionException;
import sexpression.BackquoteExpression; import sexpression.BackquoteExpression;
import sexpression.Cons; import sexpression.Cons;
import sexpression.LambdaExpression; import sexpression.LambdaExpression;
@ -22,10 +23,17 @@ import table.ExecutionContext;
@FunctionNames({ "EVAL" }) @FunctionNames({ "EVAL" })
public class EVAL extends LispFunction { public class EVAL extends LispFunction {
private static ExecutionContext executionContext = ExecutionContext.getInstance();
public static SExpression eval(SExpression sExpression) { public static SExpression eval(SExpression sExpression) {
Cons argumentList = makeList(sExpression); Cons argumentList = makeList(sExpression);
return lookupFunction("EVAL").call(argumentList); try {
return lookupFunction("EVAL").call(argumentList);
} catch (LispException e) {
executionContext.restoreGlobalScope();
throw e;
}
} }
public static Cons evalRecurArgumentList(Cons argumentList) { public static Cons evalRecurArgumentList(Cons argumentList) {
@ -62,27 +70,28 @@ public class EVAL extends LispFunction {
} }
private ArgumentValidator argumentValidator; private ArgumentValidator argumentValidator;
private ExecutionContext executionContext;
public EVAL(String name) { public EVAL(String name) {
this.argumentValidator = new ArgumentValidator(name); this.argumentValidator = new ArgumentValidator(name);
this.argumentValidator.setExactNumberOfArguments(1); this.argumentValidator.setExactNumberOfArguments(1);
this.executionContext = ExecutionContext.getInstance();
} }
@Override @Override
public SExpression call(Cons argumentList) { public SExpression call(Cons argumentList) {
if (executionContext.isRecur()) { verifyNotRecurring();
executionContext.clearContext();
throw new RecurNotInTailPositionException();
}
argumentValidator.validate(argumentList); argumentValidator.validate(argumentList);
SExpression argument = argumentList.getFirst(); SExpression argument = argumentList.getFirst();
return evaluateExpression(argument); return evaluateExpression(argument);
} }
private void verifyNotRecurring() {
if (executionContext.isRecur()) {
executionContext.clearRecur();
throw new RecurNotInTailPositionException();
}
}
private SExpression evaluateExpression(SExpression argument) { private SExpression evaluateExpression(SExpression argument) {
if (argument.isList()) if (argument.isList())
return evaluateList(argument); return evaluateList(argument);
@ -123,10 +132,7 @@ public class EVAL extends LispFunction {
if (function.isArgumentListEvaluated()) if (function.isArgumentListEvaluated())
argumentList = evaluateArgumentList(argumentList); argumentList = evaluateArgumentList(argumentList);
if (executionContext.isRecur()) { verifyNotRecurring();
executionContext.clearContext();
throw new RecurNotInTailPositionException();
}
SExpression result = function.call(argumentList); SExpression result = function.call(argumentList);
@ -211,14 +217,4 @@ public class EVAL extends LispFunction {
} }
} }
public static class RecurNotInTailPositionException extends LispException {
private static final long serialVersionUID = 1L;
@Override
public String getMessage() {
return "recur not in tail position";
}
}
} }

View File

@ -1,15 +1,15 @@
package function.builtin.cons; package function.builtin.cons;
import static function.builtin.cons.LIST.makeList; import static function.builtin.cons.LIST.makeList;
import static recursion.tail.TailCalls.done; import static recursion.TailCalls.done;
import static recursion.tail.TailCalls.tailCall; import static recursion.TailCalls.tailCall;
import java.math.BigInteger; import java.math.BigInteger;
import function.ArgumentValidator; import function.ArgumentValidator;
import function.FunctionNames; import function.FunctionNames;
import function.LispFunction; import function.LispFunction;
import recursion.tail.TailCall; import recursion.TailCall;
import sexpression.Cons; import sexpression.Cons;
import sexpression.LispNumber; import sexpression.LispNumber;

View File

@ -5,7 +5,6 @@ import static function.builtin.EVAL.evalRecurArgumentList;
import error.LispException; import error.LispException;
import function.ArgumentValidator; import function.ArgumentValidator;
import function.FunctionNames; import function.FunctionNames;
import function.LispFunction;
import function.LispSpecialFunction; import function.LispSpecialFunction;
import sexpression.Cons; import sexpression.Cons;
import sexpression.SExpression; import sexpression.SExpression;
@ -31,23 +30,31 @@ public class RECUR extends LispSpecialFunction {
throw new RecurOutsideOfFunctionException(); throw new RecurOutsideOfFunctionException();
argumentValidator.validate(argumentList); argumentValidator.validate(argumentList);
LispFunction currentFunction = executionContext.getCurrentFunction(); Cons recurArguments = getRecurArguments(argumentList);
executionContext.setRecur();
return recurArguments;
}
private Cons getRecurArguments(Cons argumentList) {
Cons recurArguments = argumentList; Cons recurArguments = argumentList;
executionContext.setRecurInitializing();
try { try {
if (currentFunction.isArgumentListEvaluated()) executionContext.setRecurInitializing();
if (isRecurArgumentListEvaluated())
recurArguments = evalRecurArgumentList(argumentList); recurArguments = evalRecurArgumentList(argumentList);
} finally { } finally {
executionContext.clearRecurInitializing(); executionContext.clearRecurInitializing();
} }
executionContext.setRecur();
return recurArguments; return recurArguments;
} }
private boolean isRecurArgumentListEvaluated() {
return executionContext.getCurrentFunction().isArgumentListEvaluated();
}
public static class RecurOutsideOfFunctionException extends LispException { public static class RecurOutsideOfFunctionException extends LispException {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
@ -68,4 +75,14 @@ public class RECUR extends LispSpecialFunction {
} }
} }
public static class RecurNotInTailPositionException extends LispException {
private static final long serialVersionUID = 1L;
@Override
public String getMessage() {
return "recur not in tail position";
}
}
} }

View File

@ -1,4 +1,4 @@
package recursion.tail; package recursion;
import java.util.stream.Stream; import java.util.stream.Stream;

View File

@ -1,4 +1,4 @@
package recursion.tail; package recursion;
public class TailCalls { public class TailCalls {

View File

@ -1,9 +1,9 @@
package sexpression; package sexpression;
import static recursion.tail.TailCalls.done; import static recursion.TailCalls.done;
import static recursion.tail.TailCalls.tailCall; import static recursion.TailCalls.tailCall;
import recursion.tail.TailCall; import recursion.TailCall;
@DisplayName("list") @DisplayName("list")
public class Cons extends SExpression { public class Cons extends SExpression {

View File

@ -33,6 +33,11 @@ public class ExecutionContext {
this.scope = scope; this.scope = scope;
} }
public void restoreGlobalScope() {
while (!scope.isGlobal())
scope = scope.getParent();
}
public void clearContext() { public void clearContext() {
this.scope = new SymbolTable(); this.scope = new SymbolTable();
this.functionCalls = new Stack<>(); this.functionCalls = new Stack<>();

View File

@ -2,6 +2,7 @@ package function.builtin;
import static function.builtin.EVAL.lookupSymbol; import static function.builtin.EVAL.lookupSymbol;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static sexpression.Nil.NIL; import static sexpression.Nil.NIL;
import static testutil.TestUtilities.assertIsErrorWithMessage; import static testutil.TestUtilities.assertIsErrorWithMessage;
import static testutil.TestUtilities.assertSExpressionsMatch; import static testutil.TestUtilities.assertSExpressionsMatch;
@ -18,6 +19,7 @@ import function.builtin.EVAL.UndefinedFunctionException;
import function.builtin.EVAL.UndefinedSymbolException; import function.builtin.EVAL.UndefinedSymbolException;
import function.builtin.EVAL.UnmatchedAtSignException; import function.builtin.EVAL.UnmatchedAtSignException;
import function.builtin.EVAL.UnmatchedCommaException; import function.builtin.EVAL.UnmatchedCommaException;
import function.builtin.special.RECUR.RecurNotInTailPositionException;
import testutil.SymbolAndFunctionCleaner; import testutil.SymbolAndFunctionCleaner;
public class EVALTest extends SymbolAndFunctionCleaner { public class EVALTest extends SymbolAndFunctionCleaner {
@ -178,4 +180,50 @@ public class EVALTest extends SymbolAndFunctionCleaner {
assertSExpressionsMatch(parseString("((1 2 3))"), evaluateString(input)); assertSExpressionsMatch(parseString("((1 2 3))"), evaluateString(input));
} }
@Test
public void scopeRestoredAfterFailure_Let() {
evaluateString("(setq n 100)");
try {
evaluateString("(let ((n 200)) (begin 1 2 3 y))");
fail("expected exception");
} catch (UndefinedSymbolException e) {}
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
}
@Test
public void scopeRestoredAfterFailure_Defun() {
evaluateString("(setq n 100)");
try {
evaluateString("(defun test (n) (begin 1 2 3 y))");
evaluateString("(test 200)");
fail("expected exception");
} catch (UndefinedSymbolException e) {}
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
}
@Test
public void scopeRestoredAfterFailure_Lambda() {
evaluateString("(setq n 100)");
try {
evaluateString("((lambda (n) (begin 1 2 3 y)) 200)");
fail("expected exception");
} catch (UndefinedSymbolException e) {}
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
}
@Test
public void scopeRestoredAfterFailure_Recur() {
evaluateString("(setq n 100)");
try {
evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
evaluateString("(tail-recursive 200)");
fail("expected exception");
} catch (RecurNotInTailPositionException e) {}
assertSExpressionsMatch(parseString("100"), evaluateString("n"));
}
} }

View File

@ -9,9 +9,10 @@ import org.junit.Test;
import function.ArgumentValidator.BadArgumentTypeException; import function.ArgumentValidator.BadArgumentTypeException;
import function.ArgumentValidator.TooManyArgumentsException; import function.ArgumentValidator.TooManyArgumentsException;
import function.builtin.EVAL.RecurNotInTailPositionException;
import function.builtin.special.RECUR.NestedRecurException; import function.builtin.special.RECUR.NestedRecurException;
import function.builtin.special.RECUR.RecurNotInTailPositionException;
import function.builtin.special.RECUR.RecurOutsideOfFunctionException; import function.builtin.special.RECUR.RecurOutsideOfFunctionException;
import sexpression.SExpression;
import testutil.SymbolAndFunctionCleaner; import testutil.SymbolAndFunctionCleaner;
public class RECURTest extends SymbolAndFunctionCleaner { public class RECURTest extends SymbolAndFunctionCleaner {
@ -53,13 +54,13 @@ public class RECURTest extends SymbolAndFunctionCleaner {
} }
@Test(expected = RecurNotInTailPositionException.class) @Test(expected = RecurNotInTailPositionException.class)
public void recurInNonTailPosition_ThrowsException() { public void recurInNonTailPositionInArgumentList() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))"); evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (list (recur (- n 1)))))");
evaluateString("(tail-recursive 900)"); evaluateString("(tail-recursive 900)");
} }
@Test(expected = RecurNotInTailPositionException.class) @Test(expected = RecurNotInTailPositionException.class)
public void recurInNonTailPosition2_ThrowsException() { public void recurInNonTailPositionInBegin() {
evaluateString("(defun tail-recursive (n) (begin (recur) 2))"); evaluateString("(defun tail-recursive (n) (begin (recur) 2))");
evaluateString("(tail-recursive 900)"); evaluateString("(tail-recursive 900)");
} }
@ -76,6 +77,18 @@ public class RECURTest extends SymbolAndFunctionCleaner {
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
} }
@Test
public void recurInTailPositionWithApply() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (apply 'recur (list (- n 1)))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test
public void recurInTailPositionWithFuncall() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (call 'recur (- n 1))))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
}
@Test @Test
public void recurCallsCurrentFunction_WithApply() { public void recurCallsCurrentFunction_WithApply() {
evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))");
@ -126,14 +139,24 @@ public class RECURTest extends SymbolAndFunctionCleaner {
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)"));
} }
// recur non-tail in apply call
// recur with no args, alters global variable @Test
public void recurWithNoArgs_AltersGlobalVariable() {
evaluateString("(defun tail-recursive () (if (= n 0) 'PASS (begin (setq n (- n 1)) (recur))))");
evaluateString("(setq n 200)");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive)"));
}
// recur with anonymous function @Test
public void recurWithLambda() {
SExpression lambdaTailCall = evaluateString("((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020)");
assertSExpressionsMatch(parseString("PASS"), lambdaTailCall);
}
// recur with nested anonymous function @Test
public void recurWithNestedLambda() {
// test scope after failure in function, is it global again? evaluateString("(defun nested-tail () ((lambda (n) (if (= n 0) 'PASS (recur (- n 1)))) 2020))");
assertSExpressionsMatch(parseString("PASS"), evaluateString("(nested-tail)"));
}
} }

View File

@ -1,8 +1,10 @@
package table; package table;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static sexpression.Nil.NIL; import static sexpression.Nil.NIL;
import static sexpression.Symbol.T; import static sexpression.Symbol.T;
@ -85,4 +87,18 @@ public class ExecutionContextTest {
assertEquals(NIL, executionContext.lookupSymbolValue("shadowed")); assertEquals(NIL, executionContext.lookupSymbolValue("shadowed"));
} }
@Test
public void restoreGlobalContext() {
SymbolTable global = executionContext.getScope();
SymbolTable scope1 = new SymbolTable(global);
SymbolTable scope2 = new SymbolTable(scope1);
SymbolTable scope3 = new SymbolTable(scope2);
executionContext.setScope(scope3);
assertThat(executionContext.getScope().isGlobal(), is(false));
executionContext.restoreGlobalScope();
assertThat(executionContext.getScope().isGlobal(), is(true));
assertThat(executionContext.getScope(), is(global));
}
} }