Refactor mutual tail call code
This commit is contained in:
parent
24cf141aba
commit
bfc8ccd25f
@ -1,12 +1,12 @@
|
||||
package recursion
|
||||
|
||||
interface TailCall<T> {
|
||||
interface MutualTailCall<T> {
|
||||
|
||||
fun isComplete() = false
|
||||
fun apply(): TailCall<T>
|
||||
fun isTerminal() = false
|
||||
fun apply(): MutualTailCall<T>
|
||||
fun result(): T = throw UnsupportedOperationException()
|
||||
|
||||
operator fun invoke() = generateSequence(this) { it.apply() }
|
||||
.first { it.isComplete() }
|
||||
.first { it.isTerminal() }
|
||||
.result()
|
||||
}
|
14
src/main/kotlin/recursion/MutualTailCalls.kt
Normal file
14
src/main/kotlin/recursion/MutualTailCalls.kt
Normal file
@ -0,0 +1,14 @@
|
||||
package recursion
|
||||
|
||||
object MutualTailCalls {
|
||||
|
||||
fun <T> recursiveCall(nextCall: () -> MutualTailCall<T>) = object : MutualTailCall<T> {
|
||||
override fun apply() = nextCall()
|
||||
}
|
||||
|
||||
fun <T> terminalValue(value: T) = object : MutualTailCall<T> {
|
||||
override fun isTerminal() = true
|
||||
override fun result() = value
|
||||
override fun apply() = throw UnsupportedOperationException()
|
||||
}
|
||||
}
|
@ -1,14 +0,0 @@
|
||||
package recursion
|
||||
|
||||
object TailCalls {
|
||||
|
||||
fun <T> tailCall(nextCall: () -> TailCall<T>) = object : TailCall<T> {
|
||||
override fun apply() = nextCall()
|
||||
}
|
||||
|
||||
fun <T> done(value: T) = object : TailCall<T> {
|
||||
override fun isComplete() = true
|
||||
override fun result() = value
|
||||
override fun apply() = throw UnsupportedOperationException()
|
||||
}
|
||||
}
|
@ -1,9 +1,9 @@
|
||||
package sexpression
|
||||
|
||||
import recursion.TailCall
|
||||
import recursion.MutualTailCall
|
||||
|
||||
import recursion.TailCalls.done
|
||||
import recursion.TailCalls.tailCall
|
||||
import recursion.MutualTailCalls.terminalValue
|
||||
import recursion.MutualTailCalls.recursiveCall
|
||||
|
||||
@DisplayName("list")
|
||||
open class Cons(open var first: SExpression, open var rest: SExpression) : SExpression(), Iterable<Cons> {
|
||||
@ -14,16 +14,16 @@ open class Cons(open var first: SExpression, open var rest: SExpression) : SExpr
|
||||
return toStringTailRecursive(StringBuilder("(")).invoke()
|
||||
}
|
||||
|
||||
private fun toStringTailRecursive(leadingString: StringBuilder): TailCall<String> {
|
||||
private fun toStringTailRecursive(leadingString: StringBuilder): MutualTailCall<String> {
|
||||
leadingString.append(first.toString())
|
||||
|
||||
if (rest.isNull)
|
||||
return done(leadingString.append(")").toString())
|
||||
return terminalValue(leadingString.append(")").toString())
|
||||
else if (rest.isCons) {
|
||||
return tailCall { (rest as Cons).toStringTailRecursive(leadingString.append(" ")) }
|
||||
return recursiveCall { (rest as Cons).toStringTailRecursive(leadingString.append(" ")) }
|
||||
}
|
||||
|
||||
return done(leadingString.append(" . " + rest.toString() + ")").toString())
|
||||
return terminalValue(leadingString.append(" . $rest)").toString())
|
||||
}
|
||||
|
||||
override fun iterator(): Iterator<Cons> = ConsIterator(this)
|
||||
|
@ -2,8 +2,8 @@ package token
|
||||
|
||||
import error.LineColumnException
|
||||
import file.FilePosition
|
||||
import recursion.TailCall
|
||||
import recursion.TailCalls.done
|
||||
import recursion.MutualTailCall
|
||||
import recursion.MutualTailCalls.terminalValue
|
||||
import sexpression.Cons
|
||||
import sexpression.Nil
|
||||
import sexpression.SExpression
|
||||
@ -14,12 +14,12 @@ class RightParenthesis(text: String, position: FilePosition) : Token(text, posit
|
||||
throw StartsWithRightParenthesisException(position)
|
||||
}
|
||||
|
||||
override fun parseListTail(getNextToken: () -> Token): TailCall<Cons> {
|
||||
return done(Nil)
|
||||
override fun parseListTail(getNextToken: () -> Token): MutualTailCall<Cons> {
|
||||
return terminalValue(Nil)
|
||||
}
|
||||
|
||||
override fun parseListTailRecursive(start: Cons, end: Cons, getNextToken: () -> Token): TailCall<Cons> {
|
||||
return done(start)
|
||||
override fun parseListTailRecursive(start: Cons, end: Cons, getNextToken: () -> Token): MutualTailCall<Cons> {
|
||||
return terminalValue(start)
|
||||
}
|
||||
|
||||
class StartsWithRightParenthesisException(position: FilePosition) : LineColumnException(position) {
|
||||
|
@ -1,8 +1,8 @@
|
||||
package token
|
||||
|
||||
import file.FilePosition
|
||||
import recursion.TailCall
|
||||
import recursion.TailCalls.tailCall
|
||||
import recursion.MutualTailCall
|
||||
import recursion.MutualTailCalls.recursiveCall
|
||||
import sexpression.Cons
|
||||
import sexpression.Nil
|
||||
import sexpression.SExpression
|
||||
@ -19,18 +19,18 @@ abstract class Token(val text: String, val position: FilePosition) {
|
||||
|
||||
abstract fun parseSExpression(getNextToken: () -> Token): SExpression
|
||||
|
||||
open fun parseListTail(getNextToken: () -> Token): TailCall<Cons> {
|
||||
open fun parseListTail(getNextToken: () -> Token): MutualTailCall<Cons> {
|
||||
val firstCons = Cons(parseSExpression(getNextToken), Nil)
|
||||
val next = getNextToken()
|
||||
|
||||
return tailCall { next.parseListTailRecursive(firstCons, firstCons, getNextToken) }
|
||||
return recursiveCall { next.parseListTailRecursive(firstCons, firstCons, getNextToken) }
|
||||
}
|
||||
|
||||
protected open fun parseListTailRecursive(start: Cons, end: Cons, getNextToken: () -> Token): TailCall<Cons> {
|
||||
protected open fun parseListTailRecursive(start: Cons, end: Cons, getNextToken: () -> Token): MutualTailCall<Cons> {
|
||||
val newEnd = Cons(parseSExpression(getNextToken), Nil)
|
||||
val next = getNextToken()
|
||||
end.rest = newEnd
|
||||
|
||||
return tailCall { next.parseListTailRecursive(start, newEnd, getNextToken) }
|
||||
return recursiveCall { next.parseListTailRecursive(start, newEnd, getNextToken) }
|
||||
}
|
||||
}
|
||||
|
@ -2,16 +2,16 @@ package recursion
|
||||
|
||||
import org.junit.jupiter.api.Assertions.assertThrows
|
||||
import org.junit.jupiter.api.Test
|
||||
import recursion.TailCalls.done
|
||||
import recursion.MutualTailCalls.terminalValue
|
||||
import testutil.LispTestInstance
|
||||
|
||||
@LispTestInstance
|
||||
class TailCallTest {
|
||||
class MutualTailCallTest {
|
||||
|
||||
@Test
|
||||
fun `tailCall does not support result`() {
|
||||
val tailCall = object : TailCall<Nothing?> {
|
||||
override fun apply() = done(null)
|
||||
val tailCall = object : MutualTailCall<Nothing?> {
|
||||
override fun apply() = terminalValue(null)
|
||||
}
|
||||
|
||||
assertThrows(UnsupportedOperationException::class.java) { tailCall.result() }
|
||||
@ -19,6 +19,6 @@ class TailCallTest {
|
||||
|
||||
@Test
|
||||
fun `done does not support apply`() {
|
||||
assertThrows(UnsupportedOperationException::class.java) { done(null).apply() }
|
||||
assertThrows(UnsupportedOperationException::class.java) { terminalValue(null).apply() }
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user