Skip to content

Instantly share code, notes, and snippets.

@Garciat
Last active December 3, 2025 14:51
Show Gist options
  • Select an option

  • Save Garciat/a6ca3c9195d5b1d997badecd73282e38 to your computer and use it in GitHub Desktop.

Select an option

Save Garciat/a6ca3c9195d5b1d997badecd73282e38 to your computer and use it in GitHub Desktop.
package org.example;
import static java.lang.reflect.AccessFlag.PUBLIC;
import static java.lang.reflect.AccessFlag.STATIC;
import static java.util.Objects.requireNonNull;
import static org.example.TypeClassSystem.witness;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.example.ZeroOneMore.More;
import org.example.ZeroOneMore.One;
import org.example.ZeroOneMore.Zero;
// === Type Class System ===
/**
 * Marks an interface as a type class whose instances can be resolved reflectively by {@code
 * TypeClassSystem}.
 *
 * <p>Retained at runtime because resolution checks for this annotation via reflection. Restricted
 * to type declarations so that misplaced uses are rejected by the compiler instead of being
 * silently ignored.
 */
@Retention(RetentionPolicy.RUNTIME)
@java.lang.annotation.Target(java.lang.annotation.ElementType.TYPE)
@interface TypeClass {
  /**
   * Marks a method as a witness (instance provider) for a type class. Resolution only considers
   * {@code public static} methods carrying this annotation; the method's parameters are the
   * witnesses it requires in turn.
   */
  @Retention(RetentionPolicy.RUNTIME)
  @java.lang.annotation.Target(java.lang.annotation.ElementType.METHOD)
  @interface Witness {}
}
/**
 * A type token: instantiate as an anonymous subclass ({@code new Ty<Foo<Bar>>() {}}) so that the
 * full generic type argument is recorded in the class file and recoverable at runtime.
 *
 * @param <T> the captured type
 */
interface Ty<T> {
  /**
   * Returns the type argument supplied to {@code Ty} by this object's class.
   *
   * <p>Searches all directly implemented generic interfaces for the parameterized {@code Ty}, so
   * the result does not depend on interface declaration order.
   *
   * @return the captured type argument (possibly a type variable if the token was created inside
   *     a generic context)
   * @throws IllegalStateException if this object's class does not directly implement a
   *     parameterized {@code Ty} (e.g. a raw {@code new Ty() {}}, or a Ty inherited via a
   *     superclass)
   */
  default Type type() {
    for (Type iface : getClass().getGenericInterfaces()) {
      if (iface instanceof ParameterizedType pt && pt.getRawType() == Ty.class) {
        return requireNonNull(pt.getActualTypeArguments()[0]);
      }
    }
    throw new IllegalStateException(
        "Ty must be created as a parameterized anonymous subclass, e.g. new Ty<Foo>() {}; got: "
            + getClass().getName());
  }
}
/**
 * Reflective resolver for type class instances ("witnesses").
 *
 * <p>Given a target such as {@code Show<List<Integer>>}, it collects {@code @TypeClass.Witness}
 * methods declared on the type class itself and on the raw classes of the target's type
 * arguments, keeps the ones whose generic return type unifies with the target, and recursively
 * summons each matching method's parameters. Exactly one rule must match at every step.
 */
class TypeClassSystem {
  /** When {@code true}, prints the candidate rules considered for every summoned type. */
  private static final boolean DEBUG = false;

  /**
   * Resolves a witness for the type captured by {@code ty}.
   *
   * @param ty a type token, typically {@code new Ty<>() {}} with the type inferred from context
   * @return the resolved witness instance
   * @throws WitnessResolutionException if, at some step, no rule or more than one rule matches
   * @throws IllegalArgumentException if a summoned type is not a parameterized type class
   */
  public static <T> T witness(Ty<T> ty) {
    return switch (summon(ty.type())) {
      case Either.Left<SummonError, Object>(SummonError error) ->
          throw new WitnessResolutionException(error);
      case Either.Right<SummonError, Object>(Object instance) -> {
        // The matched rule's return type was unified with the requested type.
        @SuppressWarnings("unchecked")
        T typedInstance = (T) instance;
        yield typedInstance;
      }
    };
  }

  /** Thrown by {@link #witness} when resolution fails; the message is the formatted error chain. */
  public static class WitnessResolutionException extends RuntimeException {
    private WitnessResolutionException(SummonError error) {
      super(error.format());
    }
  }

  /** Describes why resolution failed for a given target type. */
  sealed interface SummonError {
    /** No witness rule matched the target. */
    record NotFound(Type target) implements SummonError {}

    /** More than one witness rule matched the target. */
    record Ambiguous(Type target) implements SummonError {}

    /** A single rule matched {@code target}, but summoning one of its requirements failed. */
    record Nested(Type target, SummonError cause) implements SummonError {}

    /** Renders this error and its chain of causes as an indented, human-readable message. */
    default String format() {
      return switch (this) {
        case NotFound(Type target) -> "No witness found for type: " + target.getTypeName();
        case Ambiguous(Type target) -> "Multiple witnesses found for type: " + target.getTypeName();
        case Nested(Type target, SummonError cause) ->
            "While summoning witness for type: "
                + target.getTypeName()
                + "\nCaused by: "
                + cause.format().indent(2);
      };
    }
  }

  /**
   * Summons a witness for {@code target}: exactly one rule must match, and every requirement of
   * that rule must itself be summonable.
   *
   * <p>NOTE(review): there is no cycle detection; a rule whose requirements lead back to the same
   * target would recurse until the stack overflows.
   */
  private static Either<SummonError, Object> summon(Type target) {
    List<WitnessRule> rules = findRules(target);
    if (DEBUG) {
      System.out.println("[DEBUG] While summoning witness for type: " + target);
      System.out.println("[DEBUG] Found rules:");
      rules.forEach(rule -> System.out.println("[DEBUG] " + rule));
    }
    record Candidate(WitnessRule rule, Type[] requirements) {
      static Optional<Candidate> of(WitnessRule rule, Type target) {
        return rule.tryMatch(target).map(requirements -> new Candidate(rule, requirements));
      }
    }
    List<Candidate> candidates =
        rules.stream().flatMap(rule -> Candidate.of(rule, target).stream()).toList();
    return switch (ZeroOneMore.of(candidates)) {
      case One<Candidate>(Candidate(WitnessRule rule, Type[] requirements)) ->
          summonAll(List.of(requirements))
              .map(dependencies -> rule.instantiate(dependencies.toArray()))
              .mapLeft(error -> new SummonError.Nested(target, error));
      case Zero<Candidate> ignore -> Either.left(new SummonError.NotFound(target));
      case More<Candidate> ignore -> Either.left(new SummonError.Ambiguous(target));
    };
  }

  /** Summons every target in order, stopping at the first failure. */
  private static Either<SummonError, List<Object>> summonAll(List<Type> targets) {
    return Either.traverse(targets, TypeClassSystem::summon);
  }

  /**
   * Collects candidate rules for a type class target: witnesses declared on the type class itself
   * plus those declared on the raw classes of its type arguments.
   *
   * @throws IllegalArgumentException if {@code target} is not a parameterized type whose raw type
   *     is annotated with {@code @TypeClass}
   */
  static List<WitnessRule> findRules(Type target) {
    return switch (target) {
      case ParameterizedType pt when pt.getRawType() instanceof Class<?> cls && isTypeClass(cls) ->
          Stream.concat(rulesOf(cls), rulesOf(pt.getActualTypeArguments())).toList();
      default -> throw new IllegalArgumentException("Type is not a type class: " + target);
    };
  }

  private static Stream<WitnessRule> rulesOf(Type[] types) {
    return Arrays.stream(types).flatMap(TypeClassSystem::rulesOf);
  }

  /** Witness rules declared on the raw class of {@code type}; empty for variables/wildcards. */
  static Stream<WitnessRule> rulesOf(Type type) {
    return switch (type) {
      case Class<?> cls -> rulesOf(cls);
      case ParameterizedType pt when pt.getRawType() instanceof Class<?> cls -> rulesOf(cls);
      default -> Stream.empty();
    };
  }

  private static Stream<WitnessRule> rulesOf(Class<?> cls) {
    return Arrays.stream(cls.getDeclaredMethods())
        .filter(TypeClassSystem::isWitnessMethod)
        .map(WitnessRule::new);
  }

  private static boolean isTypeClass(Class<?> cls) {
    return cls.isAnnotationPresent(TypeClass.class);
  }

  private static boolean isWitnessMethod(Method m) {
    return m.accessFlags().contains(PUBLIC)
        && m.accessFlags().contains(STATIC)
        && m.isAnnotationPresent(TypeClass.Witness.class);
  }

  /**
   * Structurally matches {@code base} (a witness method's generic return type, which may contain
   * type variables) against {@code reference}.
   *
   * @return the type-variable bindings that make {@code base} equal to {@code reference}, or
   *     empty if no such bindings exist
   */
  private static Optional<Map<Type, Type>> unify(Type base, Type reference) {
    return switch (base) {
      case TypeVariable<?> tv -> Optional.of(Map.of(tv, reference));
      case ParameterizedType ptBase
          when reference instanceof ParameterizedType ptReference
              && ptBase.getRawType().equals(ptReference.getRawType()) ->
          Optionals.sequence(
                  Lists.zipExact(
                      ptBase.getActualTypeArguments(),
                      ptReference.getActualTypeArguments(),
                      TypeClassSystem::unify))
              .flatMap(TypeClassSystem::mergeBindings);
      case Type ty when ty.equals(reference) -> Optional.of(Map.of());
      default -> Optional.empty();
    };
  }

  /**
   * Combines the bindings produced for each type argument. A type variable that occurs more than
   * once (e.g. in {@code Pair<A, A>}) must be bound to equal types everywhere; conflicting
   * bindings mean the types do not unify, so the result is empty rather than an exception.
   */
  private static Optional<Map<Type, Type>> mergeBindings(List<Map<Type, Type>> bindings) {
    Map<Type, Type> merged = new HashMap<>();
    for (Map<Type, Type> binding : bindings) {
      for (Map.Entry<Type, Type> entry : binding.entrySet()) {
        Type previous = merged.putIfAbsent(entry.getKey(), entry.getValue());
        if (previous != null && !previous.equals(entry.getValue())) {
          return Optional.empty();
        }
      }
    }
    return Optional.of(merged);
  }

  /** Replaces type variables in {@code type} according to {@code map}, recursing into arguments. */
  private static Type substitute(Map<Type, Type> map, Type type) {
    return switch (type) {
      case TypeVariable<?> tv -> map.getOrDefault(tv, tv);
      case ParameterizedType pt ->
          new SomeParameterizedType(
              substituteAll(map, pt.getActualTypeArguments()), pt.getRawType(), pt.getOwnerType());
      default -> type;
    };
  }

  private static Type[] substituteAll(Map<Type, Type> map, Type[] types) {
    return Arrays.stream(types).map(type -> substitute(map, type)).toArray(Type[]::new);
  }

  /** A candidate rule backed by a {@code public static @TypeClass.Witness} method. */
  private record WitnessRule(Method method) {
    /**
     * If the method's generic return type unifies with {@code target} and the unification binds
     * every method type parameter, returns the method's parameter types with those bindings
     * substituted (the requirements to summon next); otherwise empty.
     */
    private Optional<Type[]> tryMatch(Type target) {
      return unify(method.getGenericReturnType(), target)
          .filter(map -> Maps.containsAllKeys(map, method.getTypeParameters()))
          .map(map -> substituteAll(map, method.getGenericParameterTypes()));
    }

    /** Invokes the witness method with already-summoned dependencies as its arguments. */
    private Object instantiate(Object[] dependencies) {
      try {
        return method.invoke(null, dependencies);
      } catch (java.lang.reflect.InvocationTargetException e) {
        // Surface what the witness method itself threw rather than the reflective wrapper.
        throw new RuntimeException("Witness method failed: " + this, e.getCause());
      } catch (Exception e) {
        throw new RuntimeException("Could not invoke witness method: " + this, e);
      }
    }

    @Override
    public String toString() {
      return method.toGenericString();
    }
  }

  /**
   * Minimal {@link ParameterizedType} for types built by {@link #substitute}.
   *
   * <p>{@code equals}/{@code hashCode} follow the {@code ParameterizedType} contract (and the JDK
   * implementation's formula), so instances compare equal to reflective types with the same raw
   * type, owner and arguments. The default record equality would compare the argument array by
   * identity.
   */
  private record SomeParameterizedType(
      Type[] getActualTypeArguments, Type getRawType, Type getOwnerType)
      implements ParameterizedType {
    @Override
    public boolean equals(Object o) {
      return o instanceof ParameterizedType that
          && java.util.Objects.equals(getOwnerType, that.getOwnerType())
          && java.util.Objects.equals(getRawType, that.getRawType())
          && Arrays.equals(getActualTypeArguments, that.getActualTypeArguments());
    }

    @Override
    public int hashCode() {
      return Arrays.hashCode(getActualTypeArguments)
          ^ java.util.Objects.hashCode(getOwnerType)
          ^ java.util.Objects.hashCode(getRawType);
    }

    @Override
    public String toString() {
      return getRawType.getTypeName()
          + Arrays.stream(getActualTypeArguments)
              .map(Type::getTypeName)
              .collect(Collectors.joining(", ", "<", ">"));
    }
  }
}
// === Example Usage ===
/** Example usage: summons witnesses for several type classes and prints the results. */
public class Main {
  public static void main(String[] args) {
    Map<String, List<Optional<Integer>>> m1 =
        Map.of(
            "a", List.of(Optional.of(1), Optional.<Integer>empty()),
            "b", List.of(Optional.of(2), Optional.of(3)));
    System.out.printf("show(m1) = %s%n", Show.show(witness(new Ty<>() {}), m1));

    List<Sum<Integer>> sums = List.of(new Sum<>(3), new Sum<>(5), new Sum<>(10));
    System.out.printf(
        "combineAll(%s) = %s%n", sums, Monoid.combineAll(witness(new Ty<>() {}), sums).value());

    System.out.printf("eq(m1, m1) = %s%n", Eq.eq(witness(new Ty<>() {}), m1, m1));

    Optional<Integer> m5 = Optional.of(5);
    Optional<Integer> m10 = Optional.of(10);
    System.out.printf(
        "compare(%s, %s) = %s%n", m5, m10, Ord.compare(witness(new Ty<>() {}), m5, m10));

    Arbitrary<Function<Optional<Integer>, List<Optional<Integer>>>> arbFunc =
        witness(new Ty<>() {});
    // Generate a random function with seed 42 and size 10, then apply it to a sample input.
    var f = arbFunc.arbitrary().generate(42L, 10);
    Optional<Integer> input = Optional.of(5);
    // Label the output with the argument actually passed to f.
    System.out.printf("f(%s) = %s%n", input, f.apply(input));
  }
}
// === Type Classes and Instances ===
/** Type class for rendering values of type {@code A} as human-readable strings. */
@TypeClass
interface Show<A> {
  /** Renders {@code a} as a string. */
  String show(A a);

  /** Renders {@code a} with an explicitly supplied (typically summoned) instance. */
  static <A> String show(Show<A> showA, A a) {
    String rendered = showA.show(a);
    return rendered;
  }

  /** Decimal rendering of integers. */
  @TypeClass.Witness
  static Show<Integer> integerShow() {
    return i -> i.toString();
  }

  /** Wraps the string in double quotes (contents are not escaped). */
  @TypeClass.Witness
  static Show<String> stringShow() {
    return s -> {
      String quote = "\"";
      return quote + s + quote;
    };
  }

  /** Renders {@code Some(x)} for a present value and {@code None} for an empty optional. */
  @TypeClass.Witness
  static <A> Show<Optional<A>> optionalShow(Show<A> showA) {
    return optA -> optA.isPresent() ? "Some(" + showA.show(optA.get()) + ")" : "None";
  }

  /** Renders elements in list order as {@code [e1, e2, ...]}. */
  @TypeClass.Witness
  static <A> Show<List<A>> listShow(Show<A> showA) {
    return listA -> {
      StringBuilder out = new StringBuilder("[");
      String separator = "";
      for (A element : listA) {
        out.append(separator).append(showA.show(element));
        separator = ", ";
      }
      return out.append("]").toString();
    };
  }

  /** Renders entries in the map's iteration order as {@code {k1: v1, k2: v2, ...}}. */
  @TypeClass.Witness
  static <K, V> Show<Map<K, V>> mapShow(Show<K> showK, Show<V> showV) {
    Function<Map.Entry<K, V>, String> showEntry =
        entry -> showK.show(entry.getKey()) + ": " + showV.show(entry.getValue());
    return mapKV ->
        mapKV.entrySet().stream().map(showEntry).collect(Collectors.joining(", ", "{", "}"));
  }
}
/** Type class for equality comparison of values of type {@code A}. */
@TypeClass
interface Eq<A> {
  /** Returns whether {@code a1} and {@code a2} are equal under this instance. */
  boolean eq(A a1, A a2);

  /** Compares {@code a1} and {@code a2} with an explicitly supplied (typically summoned) instance. */
  static <A> boolean eq(Eq<A> eqA, A a1, A a2) {
    return eqA.eq(a1, a2);
  }

  /** Equality via {@link Integer#equals}. */
  @TypeClass.Witness
  static Eq<Integer> integerEq() {
    return Integer::equals;
  }

  /** Equality via {@link String#equals}. */
  @TypeClass.Witness
  static Eq<String> stringEq() {
    return String::equals;
  }

  /** Two optionals are equal if both are empty, or both are present with {@code eqA}-equal values. */
  @TypeClass.Witness
  static <A> Eq<Optional<A>> optionalEq(Eq<A> eqA) {
    return (optA1, optA2) ->
        optA1.isPresent() && optA2.isPresent()
            ? eqA.eq(optA1.get(), optA2.get())
            : optA1.isEmpty() && optA2.isEmpty();
  }

  /** Two lists are equal if they have the same size and pairwise {@code eqA}-equal elements. */
  @TypeClass.Witness
  static <A> Eq<List<A>> listEq(Eq<A> eqA) {
    return (listA1, listA2) -> {
      if (listA1.size() != listA2.size()) {
        return false;
      }
      // Walk both lists with iterators: List.get(i) is O(i) on non-RandomAccess lists such as
      // LinkedList, which would make an index-based loop quadratic.
      var it1 = listA1.iterator();
      var it2 = listA2.iterator();
      while (it1.hasNext() && it2.hasNext()) {
        if (!eqA.eq(it1.next(), it2.next())) {
          return false;
        }
      }
      return true;
    };
  }

  /**
   * Two maps are equal if they have the same size and every entry of the first has an entry in
   * the second whose key and value are both equal under {@code eqK}/{@code eqV}.
   *
   * <p>The search is quadratic because keys are compared with {@code eqK}, not with the maps' own
   * {@code equals}/{@code hashCode}, so direct lookup is not available.
   */
  @TypeClass.Witness
  static <K, V> Eq<Map<K, V>> mapEq(Eq<K> eqK, Eq<V> eqV) {
    return (map1, map2) -> {
      if (map1.size() != map2.size()) {
        return false;
      }
      for (Map.Entry<K, V> entry1 : map1.entrySet()) {
        boolean found = false;
        for (Map.Entry<K, V> entry2 : map2.entrySet()) {
          if (eqK.eq(entry1.getKey(), entry2.getKey())
              && eqV.eq(entry1.getValue(), entry2.getValue())) {
            found = true;
            break;
          }
        }
        if (!found) {
          return false;
        }
      }
      return true;
    };
  }
}
/** Outcome of a three-way comparison. Constants are declared in ascending order: LT, EQ, GT. */
enum Ordering {
  /** The first operand orders before the second. */
  LT,
  /** The operands order equally. */
  EQ,
  /** The first operand orders after the second. */
  GT;
}
/** Type class for orderings; refines {@link Eq} by deriving equality from {@link #compare}. */
@TypeClass
interface Ord<A> extends Eq<A> {
  /** Three-way comparison of {@code a1} against {@code a2}. */
  Ordering compare(A a1, A a2);

  /** Equality is defined as {@code compare(a1, a2) == EQ}. */
  @Override
  default boolean eq(A a1, A a2) {
    Ordering outcome = compare(a1, a2);
    return outcome == Ordering.EQ;
  }

  /** Compares with an explicitly supplied (typically summoned) instance. */
  static <A> Ordering compare(Ord<A> ordA, A a1, A a2) {
    return ordA.compare(a1, a2);
  }

  /** Returns whether {@code a1} orders strictly before {@code a2}. */
  static <A> boolean lt(Ord<A> ordA, A a1, A a2) {
    return Ordering.LT == ordA.compare(a1, a2);
  }

  /** Natural numeric order of integers. */
  @TypeClass.Witness
  static Ord<Integer> integerOrd() {
    return (a1, a2) ->
        switch (Integer.signum(Integer.compare(a1, a2))) {
          case -1 -> Ordering.LT;
          case 1 -> Ordering.GT;
          default -> Ordering.EQ;
        };
  }

  /** Empty orders before any present value; two present values compare with {@code ordA}. */
  @TypeClass.Witness
  static <A> Ord<Optional<A>> optionalOrd(Ord<A> ordA) {
    return (optA1, optA2) -> {
      if (optA1.isEmpty()) {
        return optA2.isEmpty() ? Ordering.EQ : Ordering.LT;
      }
      if (optA2.isEmpty()) {
        return Ordering.GT;
      }
      return ordA.compare(optA1.get(), optA2.get());
    };
  }
}
/** Type class for an associative binary operation with an identity element. */
@TypeClass
interface Monoid<A> {
  /** Combines two values. */
  A combine(A a1, A a2);

  /** The identity element for {@link #combine}. */
  A identity();

  /** Left-folds {@code elements} with {@code combine}, starting from {@code identity()}. */
  static <A> A combineAll(Monoid<A> monoid, List<A> elements) {
    A accumulator = monoid.identity();
    var cursor = elements.iterator();
    while (cursor.hasNext()) {
      accumulator = monoid.combine(accumulator, cursor.next());
    }
    return accumulator;
  }

  /** String concatenation, with the empty string as identity. */
  @TypeClass.Witness
  static Monoid<String> stringMonoid() {
    return new Monoid<>() {
      @Override
      public String identity() {
        return "";
      }

      @Override
      public String combine(String left, String right) {
        return left + right;
      }
    };
  }
}
/** Type class for basic arithmetic: addition, multiplication, and their identities. */
@TypeClass
interface Num<A> {
  A add(A a1, A a2);

  A mul(A a1, A a2);

  /** Additive identity. */
  A zero();

  /** Multiplicative identity. */
  A one();

  /** Ordinary {@code int} arithmetic (wraps on overflow, as Java's {@code +} and {@code *} do). */
  @TypeClass.Witness
  static Num<Integer> integerNum() {
    return new Num<>() {
      @Override
      public Integer zero() {
        return 0;
      }

      @Override
      public Integer one() {
        return 1;
      }

      @Override
      public Integer add(Integer x, Integer y) {
        return x + y;
      }

      @Override
      public Integer mul(Integer x, Integer y) {
        return x * y;
      }
    };
  }
}
/** Wrapper that selects the additive monoid for a numeric type {@code A}. */
record Sum<A>(A value) {
  /** Additive monoid: the identity is {@code num.zero()} and combination is {@code num.add}. */
  @TypeClass.Witness
  public static <A> Monoid<Sum<A>> monoid(Num<A> num) {
    return new Monoid<>() {
      @Override
      public Sum<A> identity() {
        return new Sum<>(num.zero());
      }

      @Override
      public Sum<A> combine(Sum<A> left, Sum<A> right) {
        A total = num.add(left.value(), right.value());
        return new Sum<>(total);
      }
    };
  }
}
/** Type class for random-number generator state {@code G} that can be advanced and split. */
@TypeClass
interface RandomGen<G> {
  /** Draws an integer and returns it with the generator state to use next. */
  Pair<Integer, G> next(G gen);

  /** Derives two generator states from {@code gen}. */
  Pair<G, G> split(G gen);

  /**
   * Instance for {@link java.util.Random}. The generator is mutable, so {@code next} advances
   * {@code gen} in place and returns the same object.
   */
  @TypeClass.Witness
  static RandomGen<java.util.Random> javaUtilRandomGen() {
    return new RandomGen<>() {
      @Override
      public Pair<java.util.Random, java.util.Random> split(java.util.Random gen) {
        // Arguments evaluate left to right, so the first child is seeded by the first draw.
        return new Pair<>(
            new java.util.Random(gen.nextLong()), new java.util.Random(gen.nextLong()));
      }

      @Override
      public Pair<Integer, java.util.Random> next(java.util.Random gen) {
        int drawn = gen.nextInt();
        return new Pair<>(drawn, gen);
      }
    };
  }
}
/** Type class for values of type {@code A} that can be drawn from any {@link RandomGen}. */
@TypeClass
interface Random<A> {
  /** Draws a value of type {@code A}, returning it with the updated generator state. */
  <G> Pair<A, G> random(RandomGen<G> randomGen, G gen);

  /** Integers are drawn directly from the generator's {@code next}. */
  @TypeClass.Witness
  static Random<Integer> integerRandom() {
    return new Random<>() {
      @Override
      public <G> Pair<Integer, G> random(RandomGen<G> source, G state) {
        Pair<Integer, G> drawn = source.next(state);
        return drawn;
      }
    };
  }
}
/**
 * A random generator of {@code A} values, driven by an explicit {@code seed} and a {@code size}
 * hint that bounds generated structures. The combinators below derive their results only from
 * {@code (seed, size)} and the generators they wrap.
 */
@FunctionalInterface
interface Gen<A> {
  /** Produces a value for the given seed and size hint. */
  A generate(long seed, int size);

  /** Returns a generator that applies {@code f} to each value this generator produces. */
  default <B> Gen<B> map(Function<A, B> f) {
    return (seed, size) -> f.apply(generate(seed, size));
  }

  /**
   * Returns a generator that feeds this generator's value (at {@code seed}) into {@code f} and
   * runs the resulting generator at {@code seed + 1}.
   */
  // TODO: This is a naive implementation; in a real implementation, the seed management would be
  // more sophisticated.
  default <B> Gen<B> flatMap(Function<A, Gen<B>> f) {
    return (seed, size) -> f.apply(generate(seed, size)).generate(seed + 1, size);
  }

  /** Returns this generator with its seed shifted by {@code n}. */
  default Gen<A> variant(int n) {
    return (seed, size) -> generate(seed + n, size);
  }

  /**
   * Returns a generator of lists whose length is drawn from {@code [0, size)}; when {@code size <=
   * 0} the list is empty.
   */
  default Gen<List<A>> listOf() {
    // chooseInt's upper bound is exclusive and an empty range is rejected by java.util.Random, so
    // a non-positive size must short-circuit instead of calling chooseInt(0, size).
    return sized(size -> size <= 0 ? vectorOf(0) : chooseInt(0, size).flatMap(this::vectorOf));
  }

  /** Returns a generator of lists of exactly {@code length} elements, the i-th at {@code seed + i}. */
  default Gen<List<A>> vectorOf(int length) {
    return (seed, size) -> {
      List<A> result = new ArrayList<>();
      for (int i = 0; i < length; i++) {
        result.add(generate(seed + i, size));
      }
      return result;
    };
  }

  /**
   * Returns a generator of integers in {@code [low, high)} (upper bound exclusive), seeded
   * deterministically; generation throws {@link IllegalArgumentException} if {@code low >= high}.
   */
  static Gen<Integer> chooseInt(int low, int high) {
    return (seed, size) -> new java.util.Random(seed).nextInt(low, high);
  }

  /** Returns a generator whose shape depends on the current size hint. */
  static <A> Gen<A> sized(Function<Integer, Gen<A>> gen) {
    return (seed, size) -> gen.apply(size).generate(seed, size);
  }
}
/** Type class for types with a default random generator, in the style of QuickCheck. */
@TypeClass
interface Arbitrary<A> {
  /** Returns the generator for {@code A}. */
  Gen<A> arbitrary();

  /**
   * Integers from {@link Gen#chooseInt} over {@code [MIN_VALUE, MAX_VALUE)}; the upper bound is
   * exclusive, so {@link Integer#MAX_VALUE} itself is never produced. The size hint is not used.
   */
  @TypeClass.Witness
  static Arbitrary<Integer> integerArbitrary() {
    return () -> {
      int lowest = Integer.MIN_VALUE;
      int highestExclusive = Integer.MAX_VALUE;
      return Gen.chooseInt(lowest, highestExclusive);
    };
  }

  /**
   * A 0/1 draw at {@code seed} picks the case: 0 yields a present value generated at {@code seed +
   * 1}, anything else yields an empty optional.
   */
  @TypeClass.Witness
  static <A> Arbitrary<Optional<A>> optionalArbitrary(Arbitrary<A> arbA) {
    return () -> {
      Gen<A> element = arbA.arbitrary();
      return (seed, size) -> {
        int coin = Gen.chooseInt(0, 2).generate(seed, size);
        return coin == 0 ? Optional.of(element.generate(seed + 1, size)) : Optional.<A>empty();
      };
    };
  }

  /** Lists of arbitrary elements, with length governed by {@link Gen#listOf}. */
  @TypeClass.Witness
  static <A> Arbitrary<List<A>> listArbitrary(Arbitrary<A> arbA) {
    return () -> {
      Gen<A> element = arbA.arbitrary();
      return element.listOf();
    };
  }

  /**
   * Random functions: each input perturbs the output generator via {@code coarb}, and the perturbed
   * generator is run with the seed and size that produced the function.
   */
  @TypeClass.Witness
  static <A, B> Arbitrary<Function<A, B>> functionArbitrary(
      CoArbitrary<A> coarb, Arbitrary<B> arbB) {
    return () -> {
      Gen<B> outputs = arbB.arbitrary();
      return (seed, size) -> {
        Function<A, B> fn = input -> coarb.coarbitrary(input, outputs).generate(seed, size);
        return fn;
      };
    };
  }
}
/**
 * Type class for values that can perturb a generator, so that a random function can depend on
 * its input (QuickCheck's {@code CoArbitrary}).
 */
@TypeClass
interface CoArbitrary<A> {
  /** Returns {@code genB} perturbed according to {@code a}. */
  <B> Gen<B> coarbitrary(A a, Gen<B> genB);

  /** Shifts the seed by the integer's value. */
  @TypeClass.Witness
  static CoArbitrary<Integer> integerCoArbitrary() {
    return new CoArbitrary<>() {
      @Override
      public <B> Gen<B> coarbitrary(Integer a, Gen<B> genB) {
        int shift = a;
        return genB.variant(shift);
      }
    };
  }

  /** Empty shifts by 0; a present value is coarbitrary'd with {@code coarbA} and shifted by 1. */
  @TypeClass.Witness
  static <A> CoArbitrary<Optional<A>> optionalCoArbitrary(CoArbitrary<A> coarbA) {
    return new CoArbitrary<>() {
      @Override
      public <B> Gen<B> coarbitrary(Optional<A> optA, Gen<B> genB) {
        return optA.isEmpty()
            ? genB.variant(0)
            : coarbA.coarbitrary(optA.get(), genB).variant(1);
      }
    };
  }

  /** Shifts by the list size, then folds each element in order, shifting by 1 after each. */
  @TypeClass.Witness
  static <A> CoArbitrary<List<A>> listCoArbitrary(CoArbitrary<A> coarbA) {
    return new CoArbitrary<>() {
      @Override
      public <B> Gen<B> coarbitrary(List<A> listA, Gen<B> genB) {
        Gen<B> accumulated = genB.variant(listA.size());
        for (A element : listA) {
          Gen<B> perturbed = coarbA.coarbitrary(element, accumulated);
          accumulated = perturbed.variant(1);
        }
        return accumulated;
      }
    };
  }

  /**
   * Perturbs by the function's outputs on a randomly generated list of inputs, so that functions
   * with different observable behavior tend to select different generators.
   */
  @TypeClass.Witness
  static <A, B> CoArbitrary<Function<A, B>> functionCoArbitrary(
      Arbitrary<A> arbA, CoArbitrary<B> coarbB) {
    return new CoArbitrary<>() {
      @Override
      public <C> Gen<C> coarbitrary(Function<A, B> f, Gen<C> genC) {
        Gen<List<A>> sampleInputs = Arbitrary.listArbitrary(arbA).arbitrary();
        CoArbitrary<List<B>> outputsCoArb = CoArbitrary.listCoArbitrary(coarbB);
        return sampleInputs.flatMap(xs -> outputsCoArb.coarbitrary(Lists.map(xs, f), genC));
      }
    };
  }
}
// === Utility Types ===
/**
 * An immutable pair of values.
 *
 * @param first the first component
 * @param second the second component
 */
record Pair<A, B>(
    A first,
    B second) {
}
/** Classification of a list by how many elements it holds: none, exactly one, or several. */
sealed interface ZeroOneMore<A> {
  /** The list was empty. */
  record Zero<A>() implements ZeroOneMore<A> {}

  /** The list held exactly one element. */
  record One<A>(A value) implements ZeroOneMore<A> {}

  /** The list held two or more elements (the list itself is kept, not copied). */
  record More<A>(List<A> values) implements ZeroOneMore<A> {}

  /** Classifies {@code list} by its size. */
  static <A> ZeroOneMore<A> of(List<A> list) {
    int count = list.size();
    if (count == 0) {
      return new Zero<>();
    }
    if (count == 1) {
      return new One<>(list.get(0));
    }
    return new More<>(list);
  }
}
/** A value that is either a {@code Left} (used here for errors) or a {@code Right} (a result). */
sealed interface Either<L, R> {
  record Left<L, R>(L value) implements Either<L, R> {}

  record Right<L, R>(R value) implements Either<L, R> {}

  static <L, R> Either<L, R> left(L value) {
    return new Left<>(value);
  }

  static <L, R> Either<L, R> right(R value) {
    return new Right<>(value);
  }

  /** Transforms a {@code Right} value with {@code f}; a {@code Left} passes through. */
  default <X> Either<L, X> map(Function<? super R, ? extends X> f) {
    return switch (this) {
      case Left<L, R>(L value) -> left(value);
      case Right<L, R>(R value) -> right(f.apply(value));
    };
  }

  /** Transforms a {@code Left} value with {@code f}; a {@code Right} passes through. */
  default <X> Either<X, R> mapLeft(Function<? super L, ? extends X> f) {
    return switch (this) {
      case Left<L, R>(L value) -> left(f.apply(value));
      case Right<L, R>(R value) -> right(value);
    };
  }

  /** Eliminates the either by applying whichever function matches its case. */
  default <A> A fold(
      Function<? super L, ? extends A> fLeft, Function<? super R, ? extends A> fRight) {
    return switch (this) {
      case Left<L, R> l -> fLeft.apply(l.value());
      case Right<L, R> r -> fRight.apply(r.value());
    };
  }

  /**
   * Collects all {@code Right} values in order, or returns the first {@code Left} encountered.
   * Null list entries match neither case and are skipped.
   */
  static <L, R> Either<L, List<R>> sequence(List<Either<L, R>> eithers) {
    List<R> collected = new ArrayList<>();
    for (Either<L, R> candidate : eithers) {
      if (candidate instanceof Right<L, R> r) {
        collected.add(r.value());
        continue;
      }
      if (candidate instanceof Left<L, R> l) {
        return left(l.value());
      }
    }
    return right(collected);
  }

  /** Maps every element with {@code f}, then {@link #sequence sequences} the results. */
  static <A, L, R> Either<L, List<R>> traverse(List<A> list, Function<? super A, Either<L, R>> f) {
    List<Either<L, R>> mapped = Lists.map(list, f);
    return sequence(mapped);
  }
}
// === Utility classes ===
class Lists {
public static <A, B> List<B> map(List<A> list, Function<? super A, ? extends B> f) {
return list.stream().map(f).collect(Collectors.toList());
}
public static <T, R> List<R> zipExact(
T[] a1, T[] a2, BiFunction<? super T, ? super T, ? extends R> f) {
if (a1.length != a2.length) {
throw new IllegalArgumentException("Arrays must have the same length");
}
return IntStream.range(0, a1.length).<R>mapToObj(i -> f.apply(a1[i], a2[i])).toList();
}
}
/** Small map helpers. */
class Maps {
  /**
   * Merges {@code maps} left to right with {@link #merge}.
   *
   * @return the merged map; an immutable empty map when {@code maps} is empty
   * @throws IllegalArgumentException if any key appears in more than one map
   */
  public static <K, V> Map<K, V> concat(List<Map<K, V>> maps) {
    return maps.stream().reduce(Map.of(), Maps::merge);
  }

  /**
   * Returns a new map containing the entries of both inputs; neither input is modified.
   *
   * @throws IllegalArgumentException if a key of {@code m2} is already present in {@code m1}
   */
  public static <K, V> Map<K, V> merge(Map<K, V> m1, Map<K, V> m2) {
    Map<K, V> result = new HashMap<>(m1);
    for (Map.Entry<K, V> entry : m2.entrySet()) {
      // Check presence explicitly: put() also returns null when the existing value is null, so
      // testing its return value would miss duplicates of null-valued keys.
      if (result.containsKey(entry.getKey())) {
        throw new IllegalArgumentException("Duplicate key: " + entry.getKey());
      }
      result.put(entry.getKey(), entry.getValue());
    }
    return result;
  }

  /** Returns whether every element of {@code typeParameters} is a key of {@code map}. */
  public static <K, V> boolean containsAllKeys(Map<K, V> map, K[] typeParameters) {
    return Arrays.stream(typeParameters).allMatch(map::containsKey);
  }
}
/** Small {@link Optional} helpers. */
class Optionals {
  /**
   * Applies {@code f} to both values when both are present; otherwise returns empty. A null
   * result from {@code f} is not allowed ({@code Optional.of} rejects it).
   */
  public static <A, B, C> Optional<C> apply(
      Optional<A> optA, Optional<B> optB, BiFunction<? super A, ? super B, ? extends C> f) {
    if (optA.isEmpty() || optB.isEmpty()) {
      return Optional.empty();
    }
    return Optional.of(f.apply(optA.get(), optB.get()));
  }

  /** Returns all values in order if every optional is present, or empty at the first gap. */
  public static <A> Optional<List<A>> sequence(List<Optional<A>> optionals) {
    List<A> values = new ArrayList<>(optionals.size());
    for (Optional<A> candidate : optionals) {
      if (candidate.isEmpty()) {
        return Optional.empty();
      }
      values.add(candidate.get());
    }
    return Optional.of(values);
  }

  /** Left-folds {@code optionals} with {@code f}, short-circuiting to empty on the first failure. */
  public static <A, B> Optional<B> fold(
      List<A> optionals, B identity, BiFunction<? super B, ? super A, Optional<? extends B>> f) {
    B accumulator = identity;
    for (A element : optionals) {
      Optional<? extends B> step = f.apply(accumulator, element);
      if (step.isEmpty()) {
        return Optional.empty();
      }
      accumulator = step.get();
    }
    return Optional.of(accumulator);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment