Skip to content

Instantly share code, notes, and snippets.

@SegFaultAX
Last active October 6, 2018 18:21
Show Gist options
  • Save SegFaultAX/4d619249a4712092a3071567fa92e7d8 to your computer and use it in GitHub Desktop.
Save SegFaultAX/4d619249a4712092a3071567fa92e7d8 to your computer and use it in GitHub Desktop.
Functional composable applicative streaming folds [Java]
import java.util.Collection;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
/**
 * A left fold split into three parts: a seed accumulator ({@link #empty()}), a per-element
 * update ({@link #step}), and a final extraction ({@link #combine}).
 *
 * <p>Because the parts are kept separate, folds compose: {@link #map} post-processes the result,
 * and {@link #apply} / {@link #liftA2} run two folds together in a single pass over the input by
 * pairing their accumulators.
 *
 * @param <Acc>  accumulator type threaded through {@link #step}
 * @param <From> type of the input elements
 * @param <To>   result type produced by {@link #combine}
 */
public interface Fold<Acc, From, To> {
    /** Returns the seed accumulator; {@link #fold(Fold, Iterable)} calls it once per run. */
    Acc empty();

    /** Returns the accumulator obtained by folding element {@code x} into {@code acc}. */
    Acc step(Acc acc, From x);

    /** Turns the final accumulator into the fold's result. */
    To combine(Acc acc);

    /**
     * Builds a fold from its three parts.
     *
     * @param empty   supplies the seed accumulator at the start of each run
     * @param step    folds one element into the accumulator
     * @param combine extracts the result from the final accumulator
     * @return a fold delegating to the given functions
     * @throws NullPointerException if any argument is {@code null}; checked eagerly so the error
     *     surfaces where the fold is built rather than on its first use
     */
    static <A, F, T> Fold<A, F, T> using(
            Supplier<A> empty, BiFunction<A, F, A> step, Function<A, T> combine) {
        Objects.requireNonNull(empty, "empty");
        Objects.requireNonNull(step, "step");
        Objects.requireNonNull(combine, "combine");
        return new Fold<>() {
            @Override
            public A empty() {
                return empty.get();
            }

            @Override
            public A step(A a, F x) {
                return step.apply(a, x);
            }

            @Override
            public T combine(A a) {
                return combine.apply(a);
            }
        };
    }

    /**
     * Builds a fold whose result is the final accumulator itself.
     *
     * <p>Note: the very same {@code empty} instance seeds every run. A mutable seed (e.g. a list
     * that {@code step} mutates) would therefore leak state between runs; for mutable
     * accumulators use {@link #using(Supplier, BiFunction, Function)} with a supplier instead.
     *
     * @param empty seed accumulator (may be {@code null})
     * @param step  folds one element into the accumulator
     * @return a fold that returns its accumulator unchanged
     * @throws NullPointerException if {@code step} is {@code null}
     */
    static <A, F> Fold<A, F, A> using(A empty, BiFunction<A, F, A> step) {
        return using(() -> empty, step, Function.identity());
    }

    /**
     * Runs {@code f} over {@code xs} in iteration order: seeds with {@link #empty()}, applies
     * {@link #step} to each element, then passes the final accumulator to {@link #combine}.
     *
     * <p>Any {@link Iterable} whose elements are a subtype of the fold's input type is accepted,
     * not only a {@link Collection}; the input is traversed exactly once.
     *
     * @param f  the fold to run
     * @param xs the input elements
     * @return the fold's result; for empty input this is {@code combine(empty())}
     */
    static <A, F, T> T fold(Fold<A, F, T> f, Iterable<? extends F> xs) {
        A acc = f.empty();
        for (F x : xs) {
            acc = f.step(acc, x);
        }
        return f.combine(acc);
    }

    /** Runs this fold over {@code xs}; equivalent to {@code Fold.fold(this, xs)}. */
    default To fold(Collection<From> xs) {
        return fold(this, xs);
    }

    /**
     * Functor map: returns a fold with the same seed and step as {@code f} whose result is
     * {@code fn} applied to {@code f}'s result.
     */
    static <A, F, T1, T2> Fold<A, F, T2> map(Fold<A, F, T1> f, Function<T1, T2> fn) {
        return new Fold<>() {
            @Override
            public A empty() {
                return f.empty();
            }

            @Override
            public A step(A acc, F x) {
                return f.step(acc, x);
            }

            @Override
            public T2 combine(A acc) {
                return fn.apply(f.combine(acc));
            }
        };
    }

    /** Instance form of {@link #map(Fold, Function)}. */
    default <R> Fold<Acc, From, R> map(Function<To, R> fn) {
        return Fold.map(this, fn);
    }

    /**
     * Applicative unit: a fold that ignores its input and always yields {@code v}. Its
     * accumulator is always {@code null}.
     */
    static <F, T> Fold<Void, F, T> pure(T v) {
        return new Fold<>() {
            @Override
            public Void empty() {
                return null;
            }

            @Override
            public Void step(Void acc, F x) {
                return null;
            }

            @Override
            public T combine(Void acc) {
                return v;
            }
        };
    }

    /**
     * Applicative apply: steps {@code f} and {@code v} side by side on every element (one pass
     * over the input), then applies the function produced by {@code f} to the value produced by
     * {@code v}. The combined accumulator is a {@link Pair} of both accumulators.
     */
    static <A1, A2, F, T1, T2> Fold<Pair<A1, A2>, F, T2> apply(
            Fold<A1, F, Function<T1, T2>> f, Fold<A2, F, T1> v) {
        return new Fold<>() {
            @Override
            public Pair<A1, A2> empty() {
                return Pair.of(f.empty(), v.empty());
            }

            @Override
            public Pair<A1, A2> step(Pair<A1, A2> acc, F x) {
                return Pair.of(f.step(acc.left(), x), v.step(acc.right(), x));
            }

            @Override
            public T2 combine(Pair<A1, A2> acc) {
                return f.combine(acc.left()).apply(v.combine(acc.right()));
            }
        };
    }

    /** Runs {@code f1} and {@code f2} in a single pass and merges their results with {@code fn}. */
    static <A1, A2, T1, T2, F, R> Fold<Pair<A1, A2>, F, R> liftA2(
            BiFunction<T1, T2, R> fn,
            Fold<A1, F, T1> f1,
            Fold<A2, F, T2> f2) {
        Function<T1, Function<T2, R>> curried = a -> b -> fn.apply(a, b);
        return Fold.apply(f1.map(curried), f2);
    }

    /** Curried form of {@link #liftA2(BiFunction, Fold, Fold)}. */
    static <A1, A2, T1, T2, F, R> Fold<Pair<A1, A2>, F, R> liftA2(
            Function<T1, Function<T2, R>> fn,
            Fold<A1, F, T1> f1,
            Fold<A2, F, T2> f2) {
        return Fold.apply(f1.map(fn), f2);
    }
}
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.function.Function;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Unit tests for {@link Fold}: plain folding, the functor ({@code map}) and applicative
 * ({@code pure}, {@code apply}, both {@code liftA2} overloads) combinators, empty input, and
 * running the same fold more than once.
 */
public class FoldTest {
    /** Sums its inputs, seeded with 0. */
    private final Fold<Integer, Integer, Integer> sumFold =
        Fold.using(0, (Integer a, Integer b) -> a + b);

    /** Counts its inputs, ignoring their values; seeded with 0. */
    private final Fold<Integer, Integer, Integer> lenFold =
        Fold.using(0, (a, b) -> a + 1);

    @Test
    public void testSimpleFold() {
        assertThat(sumFold.fold(List.of(1, 2, 3, 4))).isEqualTo(10);
        assertThat(Fold.fold(lenFold, List.of(1, 2, 3, 4))).isEqualTo(4);
    }

    @Test
    public void testFoldEmptyInputYieldsSeed() {
        // With no elements, step never runs and the result is combine(empty()).
        assertThat(sumFold.fold(List.of())).isEqualTo(0);
        assertThat(Fold.fold(lenFold, List.of())).isEqualTo(0);
    }

    @Test
    public void testFoldIsReusable() {
        // A fold is a description, not a running computation: each run starts again from the seed.
        assertThat(sumFold.fold(List.of(1, 2))).isEqualTo(3);
        assertThat(sumFold.fold(List.of(5))).isEqualTo(5);
    }

    @Test
    public void testFoldFunctor() {
        assertThat(Fold.map(sumFold, v -> Integer.toString(v)).fold(List.of(1, 2, 3, 4))).isEqualTo("10");
    }

    @Test
    public void testFoldApplicativePure() {
        assertThat(Fold.pure(10).fold(List.of(1, 2))).isEqualTo(10);
        assertThat(Fold.pure(10).fold(List.of())).isEqualTo(10);
    }

    @Test
    public void testFoldApplicativeApply() {
        Fold<Integer, Integer, Function<Integer, Pair<Integer, Integer>>> pairingSum =
            Fold.map(sumFold, Pair.curried());
        assertThat(Fold.apply(pairingSum, lenFold).fold(List.of(1, 2, 3, 4))).isEqualTo(Pair.of(10, 4));
    }

    @Test
    public void testFoldApplicativeLift() {
        Fold<Pair<Integer, Integer>, Integer, Pair<Integer, Integer>> sumAndLen =
            Fold.liftA2(Pair::of, sumFold, lenFold);
        assertThat(sumAndLen.fold(List.of(1, 2, 3, 4))).isEqualTo(Pair.of(10, 4));
    }

    @Test
    public void testFoldApplicativeLiftCurried() {
        // Exercises the Function<T1, Function<T2, R>> overload of liftA2.
        Function<Integer, Function<Integer, Pair<Integer, Integer>>> pairUp = Pair.curried();
        Fold<Pair<Integer, Integer>, Integer, Pair<Integer, Integer>> sumAndLen =
            Fold.liftA2(pairUp, sumFold, lenFold);
        assertThat(sumAndLen.fold(List.of(1, 2, 3, 4))).isEqualTo(Pair.of(10, 4));
        assertThat(sumAndLen.fold(List.of())).isEqualTo(Pair.of(0, 0));
    }
}
import java.util.Objects;
import java.util.function.Function;
/**
 * A pair of two components, either of which may be {@code null}. Components are held in
 * {@code final} fields with no setters; two pairs are equal when they have the same runtime
 * class and {@link Objects#equals} components.
 *
 * @param <A> type of the left component
 * @param <B> type of the right component
 */
public class Pair<A, B> {
    private final A first;
    private final B second;

    /** Creates a pair with {@code a} on the left and {@code b} on the right. */
    public Pair(A a, B b) {
        first = a;
        second = b;
    }

    /** Static factory; equivalent to {@code new Pair<>(a, b)}. */
    public static <A, B> Pair<A, B> of(A a, B b) {
        return new Pair<A, B>(a, b);
    }

    /**
     * Returns the pair constructor in curried form, so that
     * {@code curried().apply(a).apply(b)} equals {@code of(a, b)}.
     */
    public static <A, B> Function<A, Function<B, Pair<A, B>>> curried() {
        return left -> right -> new Pair<>(left, right);
    }

    /** Returns the left component. */
    public A left() {
        return first;
    }

    /** Returns the right component. */
    public B right() {
        return second;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o != null && o.getClass() == getClass()) {
            Pair<?, ?> other = (Pair<?, ?>) o;
            return Objects.equals(first, other.first) && Objects.equals(second, other.second);
        }
        return false;
    }

    @Override
    public int hashCode() {
        return Objects.hash(first, second);
    }

    /** Renders as {@code Pair{a=<left>, b=<right>}}. */
    @Override
    public String toString() {
        return new StringBuilder()
            .append("Pair{")
            .append("a=").append(first)
            .append(", b=").append(second)
            .append('}')
            .toString();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment