From b4229c6ac1a70df3aa0150a006498b81d3756a39 Mon Sep 17 00:00:00 2001 From: Mike Cifelli Date: Thu, 16 Nov 2017 20:53:55 -0500 Subject: [PATCH] Implement RECUR --- src/function/UserDefinedFunction.java | 47 ++++++++++++++--- src/function/builtin/EVAL.java | 13 +++++ src/function/builtin/SET.java | 6 +-- src/function/builtin/special/RECUR.java | 54 ++++++++++++++++++++ src/table/ExecutionContext.java | 37 ++++++++++++++ src/table/FunctionTable.java | 2 + src/table/SymbolTable.java | 4 ++ test/function/builtin/special/RECURTest.java | 39 ++++++++++++++ 8 files changed, 190 insertions(+), 12 deletions(-) create mode 100644 src/function/builtin/special/RECUR.java create mode 100644 test/function/builtin/special/RECURTest.java diff --git a/src/function/UserDefinedFunction.java b/src/function/UserDefinedFunction.java index 11fd605..dee63ea 100644 --- a/src/function/UserDefinedFunction.java +++ b/src/function/UserDefinedFunction.java @@ -2,11 +2,14 @@ package function; import static function.builtin.EVAL.eval; import static java.text.MessageFormat.format; +import static recursion.tail.TailCalls.done; import static sexpression.Nil.NIL; import java.util.ArrayList; import error.LispException; +import recursion.tail.TailCall; +import recursion.tail.TailCalls; import sexpression.Cons; import sexpression.SExpression; import sexpression.Symbol; @@ -25,7 +28,7 @@ public class UserDefinedFunction extends LispFunction { private ArrayList formalParameters; private ArgumentValidator argumentValidator; private String keywordRestParameter; - private boolean isKeywordRest; + private boolean isKeywordRestPresent; public UserDefinedFunction(String name, Cons lambdaList, Cons body) { this.name = name; @@ -34,7 +37,7 @@ public class UserDefinedFunction extends LispFunction { this.executionContext = ExecutionContext.getInstance(); this.functionScope = executionContext.getScope(); this.keywordRestParameter = null; - this.isKeywordRest = false; + this.isKeywordRestPresent = false; createFormalParameters(lambdaList); setupArgumentValidator(); @@ -60,7 +63,7 @@ public class UserDefinedFunction extends LispFunction { } private Cons extractKeywordRestParameter(Cons lambdaList) { - isKeywordRest = true; + isKeywordRestPresent = true; lambdaList = advanceCons(lambdaList); keywordRestParameter = lambdaList.getFirst().toString(); lambdaList = advanceCons(lambdaList); @@ -78,7 +81,7 @@ public class UserDefinedFunction extends LispFunction { private void setupArgumentValidator() { argumentValidator = new ArgumentValidator(this.name); - if (isKeywordRest) + if (isKeywordRestPresent) argumentValidator.setMinimumNumberOfArguments(this.formalParameters.size()); else argumentValidator.setExactNumberOfArguments(this.formalParameters.size()); @@ -86,9 +89,25 @@ public class UserDefinedFunction extends LispFunction { @Override public SExpression call(Cons argumentList) { + return callTailRecursive(argumentList).invoke(); + } + + private TailCall callTailRecursive(Cons argumentList) { argumentValidator.validate(argumentList); - return evaluateInFunctionScope(argumentList); + // if recur + // clear recur indicator + // validate function list + // do not evaluate arguments - will have been done already if necessary + // tailCall(evaluatedArguments) + SExpression result = evaluateInFunctionScope(argumentList); + + if (executionContext.isRecur()) { + executionContext.clearRecur(); + return TailCalls.tailCall(() -> callTailRecursive((Cons) result)); + } + + return done(result); } private SExpression evaluateInFunctionScope(Cons argumentList) { @@ -111,7 +130,7 @@ public class UserDefinedFunction extends LispFunction { argumentList = (Cons) argumentList.getRest(); } - if (isKeywordRest) + if (isKeywordRestPresent) executionScope.put(keywordRestParameter, argumentList); return executionScope; @@ -120,7 +139,21 @@ public class UserDefinedFunction extends LispFunction { private SExpression evaluateBody() { SExpression lastEvaluation = NIL; - for (Cons expression = body; expression.isCons(); expression = (Cons) expression.getRest()) + // recur sets indicator in execution context + // eval can't be called when recur indicator is set - or error occurs (and recur indicator + // is cleared) + // recur checks to see if in function call, if not an error occurs (stack of objects + // (function call types) in execution context) + // if not, set recur indicator somewhere, (can't be in symbol table because LET creates + // scopes as well) + // + // on exception - clear function call counter and recur indicator in main loop + // recur evaluates its arguments (or not, according to function type) in its scope, then + // returns the list + // + // eval adds function call type to execution context stack of function calls when user + // defined function encountered + for (Cons expression = body; expression.isCons() /* and not recur indicator */; expression = (Cons) expression.getRest()) lastEvaluation = eval(expression.getFirst()); return lastEvaluation; diff --git a/src/function/builtin/EVAL.java b/src/function/builtin/EVAL.java index d5ec560..2d6490d 100644 --- a/src/function/builtin/EVAL.java +++ b/src/function/builtin/EVAL.java @@ -12,6 +12,7 @@ import error.LispException; import function.ArgumentValidator; import function.FunctionNames; import function.LispFunction; +import function.UserDefinedFunction; import sexpression.BackquoteExpression; import sexpression.Cons; import sexpression.LambdaExpression; @@ -28,6 +29,10 @@ public class EVAL extends LispFunction { return lookupFunction("EVAL").call(argumentList); } + public static Cons evalArgumentList(Cons argumentList) { + return ((EVAL) lookupFunction("EVAL")).evaluateArgumentList(argumentList); + } + public static LispFunction lookupFunctionOrLambda(SExpression functionExpression) { LispFunction function = lookupFunction(functionExpression.toString()); @@ -58,10 +63,12 @@ public class EVAL extends LispFunction { } private ArgumentValidator argumentValidator; + private ExecutionContext executionContext; public EVAL(String name) { this.argumentValidator = new ArgumentValidator(name); this.argumentValidator.setExactNumberOfArguments(1); + this.executionContext = ExecutionContext.getInstance(); } @Override @@ -109,6 +116,9 @@ public class EVAL extends LispFunction { } private SExpression callFunction(LispFunction function, Cons argumentList) { + if (function instanceof UserDefinedFunction) + executionContext.pushFunctionCall(function); + if (function.isArgumentListEvaluated()) argumentList = evaluateArgumentList(argumentList); @@ -117,6 +127,9 @@ public class EVAL extends LispFunction { if (function.isMacro()) result = eval(result); + if (function instanceof UserDefinedFunction) + executionContext.popFunctionCall(); + return result; } diff --git a/src/function/builtin/SET.java b/src/function/builtin/SET.java index 4101945..d160d7e 100644 --- a/src/function/builtin/SET.java +++ b/src/function/builtin/SET.java @@ -45,7 +45,7 @@ public class SET extends LispFunction { private SymbolTable findScopeOfSymbol(SExpression symbol) { SymbolTable table = executionContext.getScope(); - while (!isSymbolInTable(symbol, table) && !isGlobalTable(table)) + while (!isSymbolInTable(symbol, table) && !table.isGlobal()) table = table.getParent(); return table; @@ -55,8 +55,4 @@ public class SET extends LispFunction { return table.contains(symbol.toString()); } - private boolean isGlobalTable(SymbolTable table) { - return table.getParent() == null; - } - } diff --git a/src/function/builtin/special/RECUR.java b/src/function/builtin/special/RECUR.java new file mode 100644 index 0000000..01de941 --- /dev/null +++ b/src/function/builtin/special/RECUR.java @@ -0,0 +1,54 @@ +package function.builtin.special; + +import static function.builtin.EVAL.evalArgumentList; + +import error.LispException; +import function.ArgumentValidator; +import function.FunctionNames; +import function.LispFunction; +import function.LispSpecialFunction; +import sexpression.Cons; +import sexpression.SExpression; +import table.ExecutionContext; + +@FunctionNames({ "RECUR" }) +public class RECUR extends LispSpecialFunction { + + private ArgumentValidator argumentValidator; + private ExecutionContext executionContext; + + public RECUR(String name) { + this.argumentValidator = new ArgumentValidator(name); + this.executionContext = ExecutionContext.getInstance(); + } + + @Override + public SExpression call(Cons argumentList) { + argumentValidator.validate(argumentList); + + if (!executionContext.isInFunctionCall()) { + throw new RecurOutsideOfFunctionException(); + } + + LispFunction currentFunction = executionContext.getCurrentFunction(); + Cons recurArguments = argumentList; + + if (currentFunction.isArgumentListEvaluated()) + recurArguments = evalArgumentList(argumentList); + + executionContext.setRecur(); + + return recurArguments; + } + + public static class RecurOutsideOfFunctionException extends LispException { + + private static final long serialVersionUID = 1L; + + @Override + public String getMessage() { + return "recur called outide of function"; + } + } + +} diff --git a/src/table/ExecutionContext.java b/src/table/ExecutionContext.java index 71c1215..8303dab 100644 --- a/src/table/ExecutionContext.java +++ b/src/table/ExecutionContext.java @@ -1,5 +1,8 @@ package table; +import java.util.Stack; + +import function.LispFunction; import sexpression.SExpression; public class ExecutionContext { @@ -11,9 +14,13 @@ public class ExecutionContext { } private SymbolTable scope; + private Stack functionCalls; + private boolean recur; private ExecutionContext() { this.scope = new SymbolTable(); + this.functionCalls = new Stack<>(); + this.clearRecur(); } public SymbolTable getScope() { @@ -26,6 +33,8 @@ public class ExecutionContext { public void clearContext() { this.scope = new SymbolTable(); + this.functionCalls = new Stack<>(); + this.clearRecur(); } public SExpression lookupSymbolValue(String symbolName) { @@ -36,4 +45,32 @@ public class ExecutionContext { return null; } + public void pushFunctionCall(LispFunction function) { + functionCalls.push(function); + } + + public void popFunctionCall() { + functionCalls.pop(); + } + + public boolean isInFunctionCall() { + return !functionCalls.empty(); + } + + public LispFunction getCurrentFunction() { + return functionCalls.peek(); + } + + public boolean isRecur() { + return recur; + } + + public void setRecur() { + recur = true; + } + + public void clearRecur() { + recur = false; + } + } diff --git a/src/table/FunctionTable.java b/src/table/FunctionTable.java index ee2e026..ac69719 100644 --- a/src/table/FunctionTable.java +++ b/src/table/FunctionTable.java @@ -50,6 +50,7 @@ import function.builtin.special.LET_STAR; import function.builtin.special.OR; import function.builtin.special.PROGN; import function.builtin.special.QUOTE; +import function.builtin.special.RECUR; import function.builtin.special.SETQ; public class FunctionTable { @@ -94,6 +95,7 @@ public class FunctionTable { allBuiltIns.add(PRINT.class); allBuiltIns.add(PROGN.class); allBuiltIns.add(QUOTE.class); + allBuiltIns.add(RECUR.class); allBuiltIns.add(REST.class); allBuiltIns.add(SET.class); allBuiltIns.add(SETQ.class); diff --git a/src/table/SymbolTable.java b/src/table/SymbolTable.java index 8ab959b..cca909f 100644 --- a/src/table/SymbolTable.java +++ b/src/table/SymbolTable.java @@ -34,4 +34,8 @@ public class SymbolTable { return parent; } + public boolean isGlobal() { + return parent == null; + } + } diff --git a/test/function/builtin/special/RECURTest.java b/test/function/builtin/special/RECURTest.java new file mode 100644 index 0000000..6afc46d --- /dev/null +++ b/test/function/builtin/special/RECURTest.java @@ -0,0 +1,39 @@ +package function.builtin.special; + +import static testutil.TestUtilities.assertSExpressionsMatch; +import static testutil.TestUtilities.evaluateString; +import static testutil.TestUtilities.parseString; + +import org.junit.Test; + +import function.ArgumentValidator.BadArgumentTypeException; +import function.builtin.special.RECUR.RecurOutsideOfFunctionException; +import testutil.SymbolAndFunctionCleaner; + +public class RECURTest extends SymbolAndFunctionCleaner { + + @Test(expected = RecurOutsideOfFunctionException.class) + public void recurOutsideOfFunction_ThrowsException() { + evaluateString("(recur)"); + } + + @Test(expected = RecurOutsideOfFunctionException.class) + public void recurOutsideOfFunction_AfterFunctionCall_ThrowsException() { + evaluateString("(defun f (n) (if (= n 0) 'ZERO n))"); + evaluateString("(f 2)"); + evaluateString("(recur)"); + } + + @Test + public void recurCallsCurrentFunction() { + evaluateString("(defun tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); + assertSExpressionsMatch(parseString("PASS"), evaluateString("(tail-recursive 900)")); + } + + @Test(expected = BadArgumentTypeException.class) + public void recurInSpecialFunction_DoesNotEvaluateArguments() { + evaluateString("(define-special tail-recursive (n) (if (= n 0) 'PASS (recur (- n 1))))"); + evaluateString("(tail-recursive 900)"); + } + +}