Created
December 27, 2017 20:15
-
-
Save jnape/89821e751917578cc8c7b69fc1e9ec77 to your computer and use it in GitHub Desktop.
Type-level encoding and optimization of tail-recursive functions in Java using the Lambda library
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package spike; | |
import com.jnape.palatable.lambda.adt.choice.Choice2; | |
import com.jnape.palatable.lambda.adt.coproduct.CoProduct2; | |
import com.jnape.palatable.lambda.functions.Fn1; | |
import com.jnape.palatable.lambda.functions.Fn2; | |
import com.jnape.palatable.lambda.functions.builtin.fn2.Cons; | |
import com.jnape.palatable.lambda.functor.Bifunctor; | |
import java.util.Iterator; | |
import java.util.NoSuchElementException; | |
import java.util.function.Function; | |
import static com.jnape.palatable.lambda.adt.Maybe.just; | |
import static com.jnape.palatable.lambda.adt.Maybe.nothing; | |
import static com.jnape.palatable.lambda.adt.choice.Choice2.a; | |
import static com.jnape.palatable.lambda.adt.choice.Choice2.b; | |
import static com.jnape.palatable.lambda.adt.hlist.Tuple2.fill; | |
import static com.jnape.palatable.lambda.functions.builtin.fn1.Id.id; | |
import static com.jnape.palatable.lambda.functions.builtin.fn1.Last.last; | |
import static com.jnape.palatable.lambda.functions.builtin.fn2.Cons.cons; | |
import static com.jnape.palatable.lambda.functions.builtin.fn2.Partition.partition; | |
import static com.jnape.palatable.lambda.functions.builtin.fn2.Unfoldr.unfoldr; | |
import static spike.Spike.ContinuationStackFrame.recurse; | |
import static spike.Spike.ContinuationStackFrame.terminate; | |
import static spike.Spike.Corecursive.corecursive; | |
import static spike.Spike.OptimizedRecursiveCallStack.optimizedRecursiveCallStack; | |
import static spike.Spike.Recursive.recursive; | |
import static spike.Spike.Trampoline.trampoline; | |
public class Spike { | |
public static abstract class ContinuationStackFrame<A, B> implements CoProduct2<A, B, ContinuationStackFrame<A, B>>, Bifunctor<A, B, ContinuationStackFrame<?, ?>> { | |
@Override | |
public <C> ContinuationStackFrame<C, B> biMapL(Function<? super A, ? extends C> fn) { | |
throw new UnsupportedOperationException(); | |
} | |
@Override | |
public <C> ContinuationStackFrame<A, C> biMapR(Function<? super B, ? extends C> fn) { | |
throw new UnsupportedOperationException(); | |
} | |
@Override | |
public <C, D> ContinuationStackFrame<C, D> biMap(Function<? super A, ? extends C> lFn, | |
Function<? super B, ? extends D> rFn) { | |
throw new UnsupportedOperationException(); | |
} | |
public static <A, B> ContinuationStackFrame<A, B> recurse(A a) { | |
return new Recurse<>(a); | |
} | |
public static <A, B> ContinuationStackFrame<A, B> terminate(B b) { | |
return new Terminate<>(b); | |
} | |
private static final class Recurse<A, B> extends ContinuationStackFrame<A, B> { | |
private final A a; | |
private Recurse(A a) { | |
this.a = a; | |
} | |
@Override | |
public <R> R match(Function<? super A, ? extends R> aFn, Function<? super B, ? extends R> bFn) { | |
return aFn.apply(a); | |
} | |
@Override | |
public String toString() { | |
return "Recurse{" + | |
"a=" + a + | |
'}'; | |
} | |
} | |
private static final class Terminate<A, B> extends ContinuationStackFrame<A, B> { | |
private final B b; | |
private Terminate(B b) { | |
this.b = b; | |
} | |
@Override | |
public <R> R match(Function<? super A, ? extends R> aFn, Function<? super B, ? extends R> bFn) { | |
return bFn.apply(b); | |
} | |
@Override | |
public String toString() { | |
return "Terminate{" + | |
"b=" + b + | |
'}'; | |
} | |
} | |
} | |
public static final class Trampoline<A, B> implements Fn2<Function<? super A, ? extends CoProduct2<A, B, ?>>, A, B> { | |
private static final Trampoline INSTANCE = new Trampoline<>(); | |
@Override | |
public B apply(Function<? super A, ? extends CoProduct2<A, B, ?>> fn, A a) { | |
CoProduct2<? extends A, ? extends B, ?> next = fn.apply(a); | |
while (next.match(__ -> true, __ -> false)) | |
next = fn.apply(next.match(id(), __ -> null)); | |
return next.match(__ -> null, id()); | |
} | |
@SuppressWarnings("unchecked") | |
public static <A, B> Trampoline<A, B> trampoline() { | |
return INSTANCE; | |
} | |
public static <A, B> Fn1<A, B> trampoline(Function<? super A, ? extends CoProduct2<A, B, ?>> fn) { | |
return Trampoline.<A, B>trampoline().apply(fn); | |
} | |
public static <A, B> B trampoline(Function<? super A, ? extends CoProduct2<A, B, ?>> fn, A a) { | |
return trampoline(fn).apply(a); | |
} | |
} | |
public interface Recursive<A, B> extends Fn1<A, ContinuationStackFrame<A, B>> { | |
default Fn1<A, OptimizedRecursiveCallStack<A, B>> unroll() { | |
return a -> optimizedRecursiveCallStack(cons(recurse(a), fill(apply(a)) | |
.fmap(unfoldr(tc -> tc.match(next -> just(fill(apply(next))), __ -> nothing()))) | |
.into(Cons::cons))); | |
} | |
static <A, B> Recursive<A, B> recursive( | |
Function<? super A, ? extends CoProduct2<? extends A, ? extends B, ?>> f) { | |
return a -> f.apply(a).match(ContinuationStackFrame::recurse, ContinuationStackFrame::terminate); | |
} | |
default Fn1<A, B> trampoline() { | |
return Trampoline.trampoline(this); | |
} | |
} | |
public static final class Corecursive<A, B, C> implements Recursive<CoProduct2<A, B, ?>, C> { | |
private final Recursive<A, ? extends CoProduct2<? extends B, ? extends C, ?>> f; | |
private final Recursive<B, ? extends CoProduct2<? extends A, ? extends C, ?>> g; | |
private Corecursive(Recursive<A, ? extends CoProduct2<? extends B, ? extends C, ?>> f, | |
Recursive<B, ? extends CoProduct2<? extends A, ? extends C, ?>> g) { | |
this.f = f; | |
this.g = g; | |
} | |
@Override | |
public ContinuationStackFrame<CoProduct2<A, B, ?>, C> apply(CoProduct2<A, B, ?> ab) { | |
return ab.match(a -> f.apply(a).match(nextA -> recurse(Choice2.<A, B>a(nextA)), | |
bc -> bc.match(b -> recurse(Choice2.<A, B>b(b)), ContinuationStackFrame::terminate)), | |
b -> g.apply(b).match(nextB -> recurse(Choice2.<A, B>b(nextB)), | |
ac -> ac.match(a -> recurse(Choice2.<A, B>a(a)), ContinuationStackFrame::terminate))); | |
} | |
public static <A, B, C> Corecursive<A, B, C> corecursive( | |
Recursive<A, ? extends CoProduct2<? extends B, ? extends C, ?>> f, | |
Recursive<B, ? extends CoProduct2<? extends A, ? extends C, ?>> g | |
) { | |
return new Corecursive<>(f, g); | |
} | |
} | |
public static void main(String... args) { | |
Recursive<Integer, ContinuationStackFrame<Integer, Integer>> evens = recursive(i -> i == 1 | |
? terminate(terminate(i)) | |
: i % 2 == 0 | |
? recurse(i / 2) | |
: terminate(recurse(i))); | |
Recursive<Integer, Choice2<Integer, Integer>> odds = recursive(i -> i == 1 ? b(b(i)) : i % 2 == 1 ? a((i * 3) + 1) : b(a(i))); | |
Corecursive<Integer, Integer, Integer> conjecture = corecursive(evens, odds); | |
trampoline(corecursive(recursive(evens), recursive(odds))); | |
Choice2<Integer, Integer> arg = Choice2.a(9); | |
System.out.println(optimizedRecursiveCallStack(conjecture.unroll().apply(arg)).roll()); | |
System.out.println(trampoline(conjecture).apply(arg)); | |
} | |
public static final class OptimizedRecursiveCallStack<A, B> implements Iterable<ContinuationStackFrame<A, B>> { | |
private final Iterable<ContinuationStackFrame<A, B>> stack; | |
private OptimizedRecursiveCallStack(Iterable<ContinuationStackFrame<A, B>> stack) { | |
this.stack = stack; | |
} | |
public B roll() { | |
return last(partition(id(), stack)._2()).orElseThrow(NoSuchElementException::new); | |
} | |
@Override | |
public Iterator<ContinuationStackFrame<A, B>> iterator() { | |
return stack.iterator(); | |
} | |
public static <A, B> OptimizedRecursiveCallStack<A, B> optimizedRecursiveCallStack( | |
Iterable<ContinuationStackFrame<A, B>> stack) { | |
return new OptimizedRecursiveCallStack<>(stack); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment