
@aruld
Forked from jonbodner/YFact.java
Created October 27, 2012 20:12
Generic Y Combinator in Java 8 using lambdas
// Based on code from http://www.arcfn.com/2009/03/y-combinator-in-arc-and-java.html and the generic version https://gist.github.com/2571928
class YFact {
    // T function returning a T
    // T -> T
    public static interface Func<T> {
        T apply(T n);
    }

    // Higher-order function returning a T function
    // F: F -> (T -> T)
    private static interface FuncToTFunc<T> {
        Func<T> apply(FuncToTFunc<T> x);
    }

    // Next comes the meat. We define the Y combinator, apply it to the factorial input function,
    // and apply the result to the input argument. The result is the factorial.
    // Formulation: λr.(λf.(f f)) λf.(r λx.((f f) x))
    public static <T> Func<T> Y(final Func<Func<T>> r) {
        return ((FuncToTFunc<T>) f -> f.apply(f))
                .apply(f -> r.apply(x -> f.apply(f).apply(x)));
    }

    public static void main(String[] args) {
        System.out.println(
                // Y combinator
                Y(
                        // Recursive function generator
                        new Func<Func<Integer>>() {
                            public Func<Integer> apply(final Func<Integer> f) {
                                return n -> n == 0 ? 1 : n * f.apply(n - 1);
                            }
                        }
                ).apply(
                        // Argument: the number whose factorial is printed
                        Integer.parseInt(args[0])));
    }
}
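The inner lambda `x -> f.apply(f).apply(x)` in `Y` is the λx.((f f) x) part of the formulation, and it matters in Java: arguments are evaluated eagerly, so the textbook form λf.(r (f f)) keeps calling `f.apply(f)` before `r` ever runs. A minimal sketch, assuming the same `Func`/`FuncToTFunc` shapes as `YFact`; the class name `EtaDemo` and the method `naiveY` are hypothetical and exist only for the contrast:

```java
public class EtaDemo {
    interface Func<T> { T apply(T n); }
    interface FuncToTFunc<T> { Func<T> apply(FuncToTFunc<T> x); }

    // Eta-expanded form, as in YFact: f.apply(f) is only evaluated once x arrives.
    static <T> Func<T> Y(Func<Func<T>> r) {
        return ((FuncToTFunc<T>) f -> f.apply(f))
                .apply(f -> r.apply(x -> f.apply(f).apply(x)));
    }

    // Hypothetical naive form: f.apply(f) is evaluated before r is called,
    // so building the function recurses until the stack runs out.
    static <T> Func<T> naiveY(Func<Func<T>> r) {
        return ((FuncToTFunc<T>) f -> f.apply(f))
                .apply(f -> r.apply(f.apply(f)));
    }

    public static void main(String[] args) {
        Func<Func<Integer>> fact = f -> n -> n == 0 ? 1 : n * f.apply(n - 1);
        System.out.println(Y(fact).apply(5)); // prints 120
        try {
            naiveY(fact);
            System.out.println("naiveY returned");
        } catch (StackOverflowError e) {
            System.out.println("naiveY overflowed the stack");
        }
    }
}
```

Catching `StackOverflowError` is only for demonstration here; the point is that the naive version never gets as far as producing a function.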
@spullara

You can simplify the Y() call as:

System.out.println(
    // Y combinator
    Y((Func<Integer> f) -> (Integer n) -> n == 0 ? 1 : n * f.apply(n - 1)).apply(
        // Argument
        Integer.parseInt(args[0])));

I'm not sure about converting this to method references; I don't see an easy way to do it.
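One possible route to a method reference, sketched here as an assumption rather than something from the thread, is to move the generator into a static method whose signature matches `Func<Func<Integer>>` (a `Func<Integer>` in, a `Func<Integer>` out). The names `MethodRefDemo` and `factGen` are hypothetical:

```java
public class MethodRefDemo {
    interface Func<T> { T apply(T n); }
    interface FuncToTFunc<T> { Func<T> apply(FuncToTFunc<T> x); }

    // Same eta-expanded Y combinator as in YFact.
    static <T> Func<T> Y(Func<Func<T>> r) {
        return ((FuncToTFunc<T>) f -> f.apply(f))
                .apply(f -> r.apply(x -> f.apply(f).apply(x)));
    }

    // Hypothetical generator: given the "recursive call" f, return one factorial step.
    static Func<Integer> factGen(Func<Integer> f) {
        return n -> n == 0 ? 1 : n * f.apply(n - 1);
    }

    public static void main(String[] args) {
        // factGen's declared parameter type lets the compiler infer T = Integer.
        System.out.println(Y(MethodRefDemo::factGen).apply(5)); // prints 120
    }
}
```

The trade-off is that the generator now has a name, although `Y` itself still never refers to it.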

@spullara

If you don't like the (Integer n) type annotation in there, you can switch to if/else. The ?: operator has issues with type inference.

         Y((Func<Integer> f) -> n -> { if (n == 0) return 1; else return n * f.apply(n - 1); }).apply(
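Another way to give the compiler the type it needs, without annotating either lambda parameter, is an explicit type witness on the generic `Y` method. A self-contained sketch, assuming the `Func` and `FuncToTFunc` interfaces from `YFact`; the class name `WitnessDemo` and the helper `factorial` are hypothetical, and the argument is passed in directly instead of read from `args[0]`:

```java
public class WitnessDemo {
    interface Func<T> { T apply(T n); }
    interface FuncToTFunc<T> { Func<T> apply(FuncToTFunc<T> x); }

    // Same eta-expanded Y combinator as in YFact.
    static <T> Func<T> Y(Func<Func<T>> r) {
        return ((FuncToTFunc<T>) f -> f.apply(f))
                .apply(f -> r.apply(x -> f.apply(f).apply(x)));
    }

    static int factorial(int n) {
        // The witness <Integer> fixes T, so f and k are inferred from the target type.
        return WitnessDemo.<Integer>Y(f -> k -> k == 0 ? 1 : k * f.apply(k - 1)).apply(n);
    }

    public static void main(String[] args) {
        System.out.println(factorial(5)); // prints 120
    }
}
```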

@aruld
Author

aruld commented Nov 21, 2012

Nice, Sam! Yeah, I ran into a type error with the latter approach, but it can be fixed by providing a type hint as before.

((final Func<Integer> f) -> (Integer n) -> { if (n == 0) return 1; else return n * f.apply(n - 1); })
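Assigning the generator to a typed local variable is another form of type hint: the declared type `Func<Func<Integer>>` gives the block-bodied lambda its target type, so neither `f` nor `n` needs an annotation. A self-contained sketch under that assumption, with hypothetical names `LocalTargetDemo` and `factorial()` and a fixed range of inputs in place of `args[0]`:

```java
public class LocalTargetDemo {
    interface Func<T> { T apply(T n); }
    interface FuncToTFunc<T> { Func<T> apply(FuncToTFunc<T> x); }

    // Same eta-expanded Y combinator as in YFact.
    static <T> Func<T> Y(Func<Func<T>> r) {
        return ((FuncToTFunc<T>) f -> f.apply(f))
                .apply(f -> r.apply(x -> f.apply(f).apply(x)));
    }

    static Func<Integer> factorial() {
        // The declaration supplies the target type, so f and n are inferred.
        Func<Func<Integer>> gen = f -> n -> {
            if (n == 0) return 1;
            else return n * f.apply(n - 1);
        };
        return Y(gen);
    }

    public static void main(String[] args) {
        Func<Integer> fact = factorial();
        for (int i = 0; i <= 5; i++) {
            System.out.println(i + "! = " + fact.apply(i));
        }
    }
}
```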

By the way, I like Brian's version, which looks much better.

class Y {

  interface SelfApplicable<T> {
    T apply(SelfApplicable<T> a);
  }

  interface Func<X, Y> {
    Y apply(X x);
  }

  public static void main(String[] args) {
    // The Y combinator

    SelfApplicable<Func<Func<Func<Integer, Integer>, Func<Integer, Integer>>, Func<Integer, Integer>>> Y =
        y -> f -> x -> f.apply(y.apply(y).apply(f)).apply(x);

    // The fixed point generator
    Func<Func<Func<Integer, Integer>, Func<Integer, Integer>>, Func<Integer, Integer>> Fix = Y.apply(Y);

    // The higher order function describing factorial
    Func<Func<Integer, Integer>, Func<Integer, Integer>> F = fac -> x -> x == 0 ? 1 : x * fac.apply(x - 1);

    // The factorial function itself
    Func<Integer, Integer> factorial = Fix.apply(F);

    for (int i = 0; i < 12; i++) {
      System.out.println(factorial.apply(i));
    }
  }
}
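Brian's `Fix` is hard-wired to `Integer -> Integer`. The same self-application can be wrapped in a generic method so the fixed-point generator is reusable for other recursive functions. A sketch, assuming `java.util.function.Function` in place of the custom `Func`; the class name `GenericFix`, the method `fix`, and the Fibonacci example are hypothetical illustrations, not part of the original thread:

```java
import java.util.function.Function;

public class GenericFix {
    interface SelfApplicable<T> {
        T apply(SelfApplicable<T> a);
    }

    // Hypothetical generic version of Brian's Fix: turns a "step" function
    // (recursive call -> function) into its fixed point, for any A -> B.
    static <A, B> Function<Function<Function<A, B>, Function<A, B>>, Function<A, B>> fix() {
        SelfApplicable<Function<Function<Function<A, B>, Function<A, B>>, Function<A, B>>> y =
                self -> f -> x -> f.apply(self.apply(self).apply(f)).apply(x);
        return y.apply(y);
    }

    public static void main(String[] args) {
        // Step function for Fibonacci; rec stands for the recursive call.
        Function<Function<Long, Long>, Function<Long, Long>> fibStep =
                rec -> n -> n < 2 ? n : rec.apply(n - 1) + rec.apply(n - 2);
        // The type witness is needed because fix() has no arguments to infer A and B from.
        Function<Long, Long> fib = GenericFix.<Long, Long>fix().apply(fibStep);
        for (long i = 0; i <= 10; i++) {
            System.out.println(fib.apply(i));
        }
    }
}
```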

@benhardy

I found this easier to understand after giving each function interface its own name as the order got higher.

import java.util.function.Function;

public class YCombinator {
    interface Hopper<T,R> {
        Function<T,R> hop(Function<T,R> inFunc);
    }
    interface Fixer<T,R> {
        Function<T,R> fix(Hopper<T,R> toFix);
    }
    interface SelfApply<X> {
        X self(SelfApply<X> me);
    }
    static <T,R> SelfApply<Fixer<T,R>> combinator() {
        return me -> hopper -> input -> hopper.hop(me.self(me).fix(hopper)).apply(input);
    }
    static <T,R> Fixer<T,R> fixer() {
        final SelfApply<Fixer<T,R>> y = combinator();
        return y.self(y);
    }
    public static void main(String[] args) {
        final Hopper<Integer,Integer> factorialDefinition = deeper ->
                n -> (n > 0 ? n * deeper.apply(n - 1) : 1);
        final Fixer<Integer, Integer> fixer = fixer();
        final Function<Integer,Integer> factorial = fixer.fix(factorialDefinition);

        for (int i = 0; i < 12; i++) {
            System.out.printf("%3d => %d\n", i, factorial.apply(i));
        }
    }
}
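For comparison, if named recursion were allowed, a fixer could be written directly as fix(h) = h(fix(h)), eta-expanded so it only recurses when an input arrives. Seeing this version makes it clearer what the self-applying `combinator()` achieves: the same fixed point without any method calling itself by name. A sketch reusing the `Hopper` shape from above; the class name `DirectFix` and the static method `fix` are hypothetical:

```java
import java.util.function.Function;

public class DirectFix {
    interface Hopper<T, R> {
        Function<T, R> hop(Function<T, R> inFunc);
    }

    // Hypothetical direct fixed point: uses named recursion (fix calls fix),
    // which is exactly what the self-applying combinator avoids.
    static <T, R> Function<T, R> fix(Hopper<T, R> hopper) {
        // Eta-expanded so fix(hopper) is only evaluated when an input arrives.
        return input -> hopper.hop(fix(hopper)).apply(input);
    }

    public static void main(String[] args) {
        Hopper<Integer, Integer> factorialDefinition = deeper ->
                n -> (n > 0 ? n * deeper.apply(n - 1) : 1);
        Function<Integer, Integer> factorial = fix(factorialDefinition);
        for (int i = 0; i < 12; i++) {
            System.out.printf("%3d => %d%n", i, factorial.apply(i));
        }
    }
}
```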
