Refactor mutual tail call code

This commit is contained in:
Mike Cifelli 2019-07-18 19:29:05 -04:00
parent 24cf141aba
commit bfc8ccd25f
7 changed files with 42 additions and 42 deletions

View File

@ -1,12 +1,12 @@
package recursion package recursion
interface TailCall<T> { interface MutualTailCall<T> {
fun isComplete() = false fun isTerminal() = false
fun apply(): TailCall<T> fun apply(): MutualTailCall<T>
fun result(): T = throw UnsupportedOperationException() fun result(): T = throw UnsupportedOperationException()
operator fun invoke() = generateSequence(this) { it.apply() } operator fun invoke() = generateSequence(this) { it.apply() }
.first { it.isComplete() } .first { it.isTerminal() }
.result() .result()
} }

View File

@ -0,0 +1,14 @@
package recursion
object MutualTailCalls {
fun <T> recursiveCall(nextCall: () -> MutualTailCall<T>) = object : MutualTailCall<T> {
override fun apply() = nextCall()
}
fun <T> terminalValue(value: T) = object : MutualTailCall<T> {
override fun isTerminal() = true
override fun result() = value
override fun apply() = throw UnsupportedOperationException()
}
}

View File

@ -1,14 +0,0 @@
package recursion
object TailCalls {
fun <T> tailCall(nextCall: () -> TailCall<T>) = object : TailCall<T> {
override fun apply() = nextCall()
}
fun <T> done(value: T) = object : TailCall<T> {
override fun isComplete() = true
override fun result() = value
override fun apply() = throw UnsupportedOperationException()
}
}

View File

@ -1,9 +1,9 @@
package sexpression package sexpression
import recursion.TailCall import recursion.MutualTailCall
import recursion.TailCalls.done import recursion.MutualTailCalls.terminalValue
import recursion.TailCalls.tailCall import recursion.MutualTailCalls.recursiveCall
@DisplayName("list") @DisplayName("list")
open class Cons(open var first: SExpression, open var rest: SExpression) : SExpression(), Iterable<Cons> { open class Cons(open var first: SExpression, open var rest: SExpression) : SExpression(), Iterable<Cons> {
@ -14,16 +14,16 @@ open class Cons(open var first: SExpression, open var rest: SExpression) : SExpr
return toStringTailRecursive(StringBuilder("(")).invoke() return toStringTailRecursive(StringBuilder("(")).invoke()
} }
private fun toStringTailRecursive(leadingString: StringBuilder): TailCall<String> { private fun toStringTailRecursive(leadingString: StringBuilder): MutualTailCall<String> {
leadingString.append(first.toString()) leadingString.append(first.toString())
if (rest.isNull) if (rest.isNull)
return done(leadingString.append(")").toString()) return terminalValue(leadingString.append(")").toString())
else if (rest.isCons) { else if (rest.isCons) {
return tailCall { (rest as Cons).toStringTailRecursive(leadingString.append(" ")) } return recursiveCall { (rest as Cons).toStringTailRecursive(leadingString.append(" ")) }
} }
return done(leadingString.append(" . " + rest.toString() + ")").toString()) return terminalValue(leadingString.append(" . $rest)").toString())
} }
override fun iterator(): Iterator<Cons> = ConsIterator(this) override fun iterator(): Iterator<Cons> = ConsIterator(this)

View File

@ -2,8 +2,8 @@ package token
import error.LineColumnException import error.LineColumnException
import file.FilePosition import file.FilePosition
import recursion.TailCall import recursion.MutualTailCall
import recursion.TailCalls.done import recursion.MutualTailCalls.terminalValue
import sexpression.Cons import sexpression.Cons
import sexpression.Nil import sexpression.Nil
import sexpression.SExpression import sexpression.SExpression
@ -14,12 +14,12 @@ class RightParenthesis(text: String, position: FilePosition) : Token(text, posit
throw StartsWithRightParenthesisException(position) throw StartsWithRightParenthesisException(position)
} }
override fun parseListTail(getNextToken: () -> Token): TailCall<Cons> { override fun parseListTail(getNextToken: () -> Token): MutualTailCall<Cons> {
return done(Nil) return terminalValue(Nil)
} }
override fun parseListTailRecursive(start: Cons, end: Cons, getNextToken: () -> Token): TailCall<Cons> { override fun parseListTailRecursive(start: Cons, end: Cons, getNextToken: () -> Token): MutualTailCall<Cons> {
return done(start) return terminalValue(start)
} }
class StartsWithRightParenthesisException(position: FilePosition) : LineColumnException(position) { class StartsWithRightParenthesisException(position: FilePosition) : LineColumnException(position) {

View File

@ -1,8 +1,8 @@
package token package token
import file.FilePosition import file.FilePosition
import recursion.TailCall import recursion.MutualTailCall
import recursion.TailCalls.tailCall import recursion.MutualTailCalls.recursiveCall
import sexpression.Cons import sexpression.Cons
import sexpression.Nil import sexpression.Nil
import sexpression.SExpression import sexpression.SExpression
@ -19,18 +19,18 @@ abstract class Token(val text: String, val position: FilePosition) {
abstract fun parseSExpression(getNextToken: () -> Token): SExpression abstract fun parseSExpression(getNextToken: () -> Token): SExpression
open fun parseListTail(getNextToken: () -> Token): TailCall<Cons> { open fun parseListTail(getNextToken: () -> Token): MutualTailCall<Cons> {
val firstCons = Cons(parseSExpression(getNextToken), Nil) val firstCons = Cons(parseSExpression(getNextToken), Nil)
val next = getNextToken() val next = getNextToken()
return tailCall { next.parseListTailRecursive(firstCons, firstCons, getNextToken) } return recursiveCall { next.parseListTailRecursive(firstCons, firstCons, getNextToken) }
} }
protected open fun parseListTailRecursive(start: Cons, end: Cons, getNextToken: () -> Token): TailCall<Cons> { protected open fun parseListTailRecursive(start: Cons, end: Cons, getNextToken: () -> Token): MutualTailCall<Cons> {
val newEnd = Cons(parseSExpression(getNextToken), Nil) val newEnd = Cons(parseSExpression(getNextToken), Nil)
val next = getNextToken() val next = getNextToken()
end.rest = newEnd end.rest = newEnd
return tailCall { next.parseListTailRecursive(start, newEnd, getNextToken) } return recursiveCall { next.parseListTailRecursive(start, newEnd, getNextToken) }
} }
} }

View File

@ -2,16 +2,16 @@ package recursion
import org.junit.jupiter.api.Assertions.assertThrows import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import recursion.TailCalls.done import recursion.MutualTailCalls.terminalValue
import testutil.LispTestInstance import testutil.LispTestInstance
@LispTestInstance @LispTestInstance
class TailCallTest { class MutualTailCallTest {
@Test @Test
fun `tailCall does not support result`() { fun `tailCall does not support result`() {
val tailCall = object : TailCall<Nothing?> { val tailCall = object : MutualTailCall<Nothing?> {
override fun apply() = done(null) override fun apply() = terminalValue(null)
} }
assertThrows(UnsupportedOperationException::class.java) { tailCall.result() } assertThrows(UnsupportedOperationException::class.java) { tailCall.result() }
@ -19,6 +19,6 @@ class TailCallTest {
@Test @Test
fun `done does not support apply`() { fun `done does not support apply`() {
assertThrows(UnsupportedOperationException::class.java) { done(null).apply() } assertThrows(UnsupportedOperationException::class.java) { terminalValue(null).apply() }
} }
} }