Last active
March 15, 2025 01:17
-
-
Save re-xyr/c069436906a2070bb96e64203511c682 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Unification.Companion.convertibleTo | |
import Unification.Companion.unify | |
import org.junit.jupiter.api.Test | |
// Refs | |
/** A named reference that terms and declarations can point at. */
sealed interface Def

/** A rigid variable (a parameter or premise name); printed as its bare name. */
class Local(val name: String) : Def {
    override fun toString() = name
}

/**
 * A metavariable. [solution] is `null` while the meta is unsolved; callers assign and reset it.
 * Each construction takes the next value of the shared [metaCounter] as its [id], which [toString]
 * uses to tell same-named metas apart.
 */
class Meta(val name: String, var solution: Term? = null) : Def {
    val id: Int = metaCounter++

    override fun toString(): String {
        return "(?$name($id) := $solution)"
    }

    companion object {
        /** Global source of meta ids; a plain, unsynchronized counter. */
        var metaCounter: Int = 0
    }
}

/** A type class declaration with its parameter telescope [params]. */
class Class(val name: String, val params: List<Param>) : Def {
    override fun toString() = "class $name $params"
}

/**
 * An instance declaration: for any values of [params], and given proofs of every entry in [premises],
 * it provides [cls].
 */
class Instance(val name: String, val params: List<Param>, val premises: List<Premise>, val cls: CallCls) : Def {
    override fun toString() = "inst $name $params $premises: $cls"
}
// Terms | |
/** Terms of the object language. */
sealed interface Term

/** A class applied to arguments; printed as `Name[arg, ...]`. */
data class CallCls(val cls: Class, val args: List<Term>) : Term {
    override fun toString() = cls.name + args
}

/** An instance applied to its parameter arguments and to proofs of its premises; printed as `Name[args][premises]`. */
data class CallInst(val inst: Instance, val args: List<Term>, val premises: List<CallInst>) : Term {
    override fun toString() = inst.name + args + premises
}

/** An occurrence of a metavariable; printed via [Meta.toString]. */
data class Flex(val ref: Meta) : Term {
    override fun toString() = "$ref"
}

/** An occurrence of a rigid local; printed via [Local.toString]. */
data class Rigid(val ref: Local) : Term {
    override fun toString() = "$ref"
}

/** The universe of types. */
object Type : Term {
    override fun toString() = "Type"
}
// Params | |
/** A binder `(ref : type)` in a parameter telescope. */
data class Param(val ref: Local, val type: Term) {
    override fun toString() = "(" + ref + " : " + type + ")"
}

/** A named premise `(ref : type)`; its type is always a class application. */
data class Premise(val ref: Local, val type: CallCls) {
    override fun toString() = "(" + ref + " : " + type + ")"
}
// Substitution | |
/** Substitutes [to] for every occurrence of the single reference [from]. */
fun Term.subst(from: Def, to: Term): Term {
    val single: Map<Def, Term> = mapOf(from to to)
    return subst(single)
}

/**
 * Structurally replaces every [Flex] and [Rigid] whose reference is a key of [mapping].
 * It does not look inside a meta's current solution, and leaves [Type] unchanged.
 */
fun Term.subst(mapping: Map<out Def, Term>): Term {
    val recur: (Term) -> Term = { it.subst(mapping) }
    return when (this) {
        Type -> this
        is Rigid -> mapping.getOrDefault(ref, this)
        is Flex -> mapping.getOrDefault(ref, this)
        is CallCls -> copy(args = args.map(recur))
        is CallInst -> copy(
            args = args.map(recur),
            // substitution preserves the node kind, so premises stay CallInst
            premises = premises.map { recur(it) as CallInst },
        )
    }
}
// Zonk | |
/**
 * Instantiates every solved meta in a term.
 *
 * A [Flex] whose meta has a solution is replaced by that solution, zonked again. Chains such as
 * `?a := ?b`, `?b := A` therefore collapse to `A`, and no solved meta survives in the result.
 * Unsolved metas and rigid locals are kept as they are.
 */
fun Term.zonk(): Term = zonkAvoiding(emptySet())

/** [zonk] helper; [seen] holds the solved metas being expanded on the current path, so a cyclic solution stops. */
private fun Term.zonkAvoiding(seen: Set<Meta>): Term = when (this) {
    is CallCls -> copy(args = args.map { it.zonkAvoiding(seen) })
    is CallInst -> copy(
        args = args.map { it.zonkAvoiding(seen) },
        premises = premises.map { it.zonkAvoiding(seen) as CallInst },
    )
    is Flex -> {
        // read once: `solution` is a var, so it cannot be smart-cast
        val solution = ref.solution
        if (solution == null || ref in seen) this else solution.zonkAvoiding(seen + ref)
    }
    is Rigid -> this
    Type -> this
}
/**
 * Lists the unsolved metas occurring in a term.
 *
 * The search looks through the solutions of solved metas, so with `?a := ?b` and `?b` unsolved,
 * `?a` yields `?b`. A meta that occurs several times is listed several times.
 */
fun Term.unsolvedMetas(): List<Meta> = unsolvedMetasAvoiding(emptySet())

/** [unsolvedMetas] helper; [seen] holds the solved metas being expanded on the current path, so a cyclic solution stops. */
private fun Term.unsolvedMetasAvoiding(seen: Set<Meta>): List<Meta> = when (this) {
    is CallCls -> args.flatMap { it.unsolvedMetasAvoiding(seen) }
    is CallInst -> (args + premises).flatMap { it.unsolvedMetasAvoiding(seen) }
    is Flex -> {
        val solution = ref.solution
        when {
            solution == null -> listOf(ref)
            ref in seen -> listOf()
            else -> solution.unsolvedMetasAvoiding(seen + ref)
        }
    }
    is Rigid -> listOf()
    Type -> listOf()
}
// Reduction | |
/**
 * Normalizes a term against the current meta solutions.
 *
 * Chains of solved metas are followed to the end (`?a := ?b`, `?b := A` reduces `?a` to `A`), and
 * arguments and premises are normalized too. [Unification] reduces both sides before comparing, so
 * after this step a [Flex] at the head of a term is an unsolved meta.
 */
fun Term.reduce(): Term = reduceAvoiding(emptySet())

/** [reduce] helper; [seen] holds the solved metas being expanded on the current path, so a cyclic solution stops. */
private fun Term.reduceAvoiding(seen: Set<Meta>): Term = when (this) {
    is Flex -> {
        val solution = ref.solution
        if (solution == null || ref in seen) this else solution.reduceAvoiding(seen + ref)
    }
    is CallCls -> copy(args = args.map { it.reduceAvoiding(seen) })
    is CallInst -> copy(
        args = args.map { it.reduceAvoiding(seen) },
        premises = premises.map { it.reduceAvoiding(seen) as CallInst },
    )
    is Rigid -> this
    Type -> this
}
// Unification | |
/**
 * One run of first-order unification over [Term]s, recording meta assignments in [flex].
 *
 * Objects are created only through the companion helpers [unify] and [convertibleTo]. Each call
 * uses a fresh object, so assignments never leak between runs.
 */
class Unification private constructor() {
    private val flexes: MutableMap<Meta, Term> = mutableMapOf()

    /** Meta assignments recorded by this run (read-only view). */
    val flex: Map<Meta, Term> get() = flexes

    // Meta renaming built by [convertible], kept in both directions so it stays one-to-one.
    private val renamedTo: MutableMap<Meta, Meta> = mutableMapOf()
    private val renamedFrom: MutableMap<Meta, Meta> = mutableMapOf()

    /**
     * Structural comparison shared by both modes. Two terms match when they are:
     * - the same class, or the same instance, with pairwise-related arguments (for instances, premises too);
     * - the same meta;
     * - the same local;
     * - both [Type].
     *
     * Sub-terms are related with [toCall]. Assigning metas is left to the caller.
     */
    private fun basicUnify(preInst: Term, preProb: Term, toCall: (Term, Term) -> Boolean): Boolean {
        val (inst, prob) = preInst.reduce() to preProb.reduce()
        return when {
            inst is CallCls && prob is CallCls ->
                inst.cls == prob.cls && pairwise(inst.args, prob.args, toCall)
            inst is CallInst && prob is CallInst ->
                inst.inst == prob.inst &&
                    pairwise(inst.args, prob.args, toCall) &&
                    pairwise(inst.premises, prob.premises, toCall)
            inst is Flex && prob is Flex -> inst.ref === prob.ref
            inst is Rigid && prob is Rigid -> inst.ref === prob.ref
            else -> inst === Type && prob === Type
        }
    }

    /**
     * True when [xs] and [ys] have the same length and [f] accepts every aligned pair.
     * Pairs are checked left to right, stopping at the first failure.
     */
    private fun pairwise(xs: List<Term>, ys: List<Term>, f: (Term, Term) -> Boolean): Boolean =
        xs.size == ys.size && xs.indices.all { f(xs[it], ys[it]) }

    // This version of unification finds solutions to metas.
    // Used in finding matching instances.
    private fun unify(preInst: Term, preProb: Term): Boolean {
        val (inst, prob) = preInst.reduce() to preProb.reduce()
        if (basicUnify(inst, prob, this::unify)) return true
        // NOTE(review): an already-assigned meta is compared with basicUnify, not unify, so a binding that is
        // itself an unassigned meta cannot be extended here. This avoids chasing possibly cyclic bindings; confirm it is intended.
        return when {
            inst is Flex -> {
                val bound = flexes[inst.ref]
                if (bound == null) {
                    flexes[inst.ref] = prob
                    true
                } else basicUnify(bound, prob, this::unify)
            }
            prob is Flex -> {
                val bound = flexes[prob.ref]
                if (bound == null) {
                    flexes[prob.ref] = inst
                    true
                } else basicUnify(bound, inst, this::unify)
            }
            else -> false
        }
    }

    // This version of unification checks alpha-equivalence: the terms must agree up to a consistent
    // one-to-one renaming of unsolved metas. Used in finding table entries.
    private fun convertible(preInst: Term, preProb: Term): Boolean {
        val (inst, prob) = preInst.reduce() to preProb.reduce()
        if (inst is Flex && prob is Flex) return rename(inst, prob)
        return basicUnify(inst, prob, this::convertible)
    }

    /**
     * Extends the renaming with `inst.ref -> prob.ref`, failing if it conflicts with an existing pair.
     * A one-directional map would accept `R ?a ?b ~ R ?c ?c` and reuse wrong table entries.
     * Non-identity pairs are also exposed through [flex].
     */
    private fun rename(inst: Flex, prob: Flex): Boolean {
        val image = renamedTo[inst.ref]
        val preimage = renamedFrom[prob.ref]
        if (image != null || preimage != null) return image === prob.ref && preimage === inst.ref
        renamedTo[inst.ref] = prob.ref
        renamedFrom[prob.ref] = inst.ref
        if (inst.ref !== prob.ref) flexes[inst.ref] = prob
        return true
    }

    companion object {
        /** Unifies `this` (instance side) with [rhs] (problem side); returns the run with its assignments, or `null`. */
        infix fun Term.unify(rhs: Term): Unification? {
            val u = Unification()
            if (u.unify(this, rhs)) return u
            return null
        }

        /** Checks alpha-equivalence of `this` and [rhs]; returns the run whose [flex] holds the meta renaming, or `null`. */
        infix fun Term.convertibleTo(rhs: Term): Unification? {
            val u = Unification()
            if (u.convertible(this, rhs)) return u
            return null
        }
    }
}
// Instance Search Forest | |
/** A node of the instance-search forest built by [Querier]. */
sealed interface Node {
    val children: List<Node>

    /**
     * Collects every [Result] below this node. The node's direct [Result] children come first,
     * then the results found further down, child by child.
     */
    fun solutions(): List<Result> {
        val direct = children.filterIsInstance<Result>()
        val deeper = children.flatMap { child -> child.solutions() }
        return direct + deeper
    }
}

/**
 * A tabled query. [query] is the nominal copy stored in the table. [subst] maps each unsolved meta of
 * the original query to the fresh meta that replaced it in [query].
 */
data class Generator(
    val query: CallCls,
    val subst: Map<Meta, Flex>,
    override val children: MutableList<Node> = mutableListOf(),
) : Node

/** A partially applied instance [apply] still waiting for proofs of [premises], which are solved front to back. */
data class Consumer(
    val premises: List<Premise>,
    val apply: CallInst,
    override val children: MutableList<Node> = mutableListOf(),
) : Node

/** A complete [solution] together with the meta assignments [subst] in force when it was produced; always a leaf. */
data class Result(
    val solution: CallInst,
    val subst: Map<Meta, Term>,
) : Node {
    override val children: List<Node>
        get() = emptyList()
}
// Resolution Algorithm | |
/**
 * Tabled instance resolution over [instances], keyed by class.
 *
 * Each distinct query (up to alpha-renaming of unsolved metas) is explored once, as a [Generator]
 * kept in [table]. Later equivalent queries reuse that generator's results.
 */
class Querier(
    val instances: Map<Class, List<Instance>>,
) {
    /** Generators created so far; each stores a nominal query whose unsolved metas are fresh copies. */
    private val table: MutableList<Generator> = mutableListOf()

    /**
     * Resolves [query] and returns its solutions, expressed in terms of [query]'s own metas.
     *
     * If a tabled generator's nominal query is convertible to [query], that generator is reused. It may
     * still be under construction in a recursive call; then only the results collected so far are
     * returned. Otherwise a new generator is created and explored. [layer] only controls log indentation.
     */
    fun query(query: CallCls, layer: Int = 0): List<Result> {
        val indent = " ".repeat(layer)
        val found = table.find { (it.query convertibleTo query) != null }
        if (found != null) println("${indent}Existing query: $query")
        val existing = found ?: generate(query, layer)
        // Renaming from the generator's nominal metas to the metas of this particular query.
        val mapping = checkNotNull(existing.query convertibleTo query) {
            "tabled query ${existing.query} is no longer convertible to $query"
        }.flex
        val realSols = existing.solutions().map { sol ->
            sol.copy(
                solution = sol.solution.subst(existing.subst).subst(mapping) as CallInst,
                subst = sol.subst
                    // metas of the query that created the generator -> nominal metas
                    .mapKeys { (meta, _) -> existing.subst[meta]?.ref ?: meta }
                    .mapValues { (_, term) -> term.subst(existing.subst) }
                    // nominal metas -> metas of this query (alpha-equivalence)
                    .mapKeys { (meta, _) -> (mapping[meta] as? Flex)?.ref ?: meta }
                    .mapValues { (_, term) -> term.subst(mapping) },
            )
        }
        println("${indent}Finished $query:\n ${indent}$realSols")
        return realSols
    }

    /**
     * Creates, registers and explores a new generator for [query].
     *
     * Table entries must reflect meta values at creation time. So solved metas are zonked and unsolved
     * ones are replaced by fresh metas (recorded in [Generator.subst]). The original [query] is still
     * used for the actual instance matching.
     */
    private fun generate(query: CallCls, layer: Int): Generator {
        val mapping = query.unsolvedMetas().associateWith { Flex(Meta(it.name)) }
        val nominalQuery = query.zonk().subst(mapping) as CallCls
        println("${" ".repeat(layer)}Generated: $nominalQuery")
        val gen = Generator(nominalQuery, mapping)
        // Register before exploring so recursive sub-queries can find and reuse this entry.
        table += gen
        for (inst in instances.getOrDefault(query.cls, listOf())) {
            inst.match(query, layer + 1)?.let { gen.children += it }
        }
        return gen
    }

    /**
     * Tries to apply this instance to [q].
     *
     * Parameters become fresh metas, and the instance head is unified with [q]. On success, the premises
     * are solved as subgoals while the unifier's assignments are temporarily in force. Returns `null`
     * if the head does not unify.
     */
    fun Instance.match(q: CallCls, layer: Int): Node? {
        // Create metas for parameters
        val mapping = params.associate { it.ref to Flex(Meta(it.ref.name)) }
        // Subst metas for params in premises to get subgoals
        val subgoals = premises.map { it.copy(type = it.type.subst(mapping) as CallCls) }
        // Also subst metas into the result type, and unify with the query.
        val metaCls = cls.subst(mapping) as CallCls
        val unifier = metaCls unify q ?: return null
        // These are temporary solutions to metas (specific to the current consumer node).
        return withMetas(unifier.flex) {
            consume(subgoals, CallInst(this, params.map { mapping.getValue(it.ref) }, listOf()), layer)
        }
    }

    /**
     * Solves [subgoals] front to back, extending [apply] with one premise proof per solved subgoal.
     *
     * Returns a [Result] when nothing is left. Otherwise returns a [Consumer] with one child per
     * solution of the first subgoal.
     */
    fun consume(subgoals: List<Premise>, apply: CallInst, layer: Int): Node {
        val indent = " ".repeat(layer)
        if (subgoals.isEmpty()) {
            // No subgoal left: a solution is produced with the current meta assignments.
            println("${indent}Result: $apply")
            return Result(apply, solutionPool.toMap())
        }
        println("${indent}Consume: $subgoals")
        val cons = Consumer(subgoals, apply)
        val subgoal = subgoals.first()
        val rest = subgoals.drop(1)
        for (res in query(subgoal.type, layer + 1)) {
            // For each solution, try to solve the remaining subgoals under its meta assignments.
            withMetas(res.subst) {
                cons.children += consume(
                    // later premise types may mention this premise: substitute the proof in
                    rest.map { it.copy(type = it.type.subst(subgoal.ref, res.solution) as CallCls) },
                    apply.copy(premises = apply.premises + res.solution),
                    layer + 1,
                )
            }
        }
        return cons
    }
}
// Current temporary meta solutions, mirrored by withMetas.
// Querier.consume snapshots this map into each Result it produces.
val solutionPool: MutableMap<Meta, Term> = mutableMapOf()

/**
 * Temporarily sets each meta in [rep] to its term, and records it in [solutionPool], while running [f].
 * Afterwards, even if [f] throws, both the metas' previous solutions and their previous pool entries
 * are restored.
 *
 * Nested calls may re-bind a meta that an outer call already bound: `consume` passes whole-pool
 * snapshots. Restoring the previous pool entry, instead of just removing it, keeps the outer binding
 * visible.
 */
fun <T> withMetas(rep: Map<Meta, Term>, f: () -> T): T {
    val prevSolutions = rep.keys.associateWith { it.solution }
    val prevPool = rep.keys.associateWith { solutionPool[it] }
    for ((ref, tm) in rep) {
        ref.solution = tm
        solutionPool[ref] = tm
    }
    try {
        return f()
    } finally {
        for (ref in rep.keys) {
            ref.solution = prevSolutions[ref]
            val outer = prevPool[ref]
            if (outer == null) solutionPool -= ref else solutionPool[ref] = outer
        }
    }
}
// ====== TESTS ====== | |
/** Smoke test for [Querier]: runs two queries against a small transitive relation and prints the search trace. */
class QuerierTest {
    @Test
    fun x() {
        // class R (M N : Type)
        val R = Class("R", listOf(Param(Local("M"), Type), Param(Local("N"), Type)))
        val (A, B, C, D) = listOf("A", "B", "C", "D").map(::Local)
        val (X, Y, Z) = listOf("X", "Y", "Z").map(::Local)
        // R applied to two rigid locals
        fun rel(lhs: Local, rhs: Local) = CallCls(R, listOf(Rigid(lhs), Rigid(rhs)))

        val instances = mapOf(
            R to listOf(
                // instance I1 : R A B
                Instance("I1", listOf(), listOf(), rel(A, B)),
                // instance I2 : R A C
                Instance("I2", listOf(), listOf(), rel(A, C)),
                // instance I3 : R C D
                Instance("I3", listOf(), listOf(), rel(C, D)),
                // instance I4 {X Y Z : Type} (p : R X Y) (q : R Y Z) : R X Z
                Instance(
                    "I4",
                    listOf(X, Y, Z).map { Param(it, Type) },
                    listOf(Premise(Local("p"), rel(X, Y)), Premise(Local("q"), rel(Y, Z))),
                    rel(X, Z),
                ),
            ),
        )

        // query R A D, with a fresh table
        Querier(instances).query(rel(A, D))
        // query R A ??, with another fresh table (the meta is created only now, after the first run)
        Querier(instances).query(CallCls(R, listOf(Rigid(A), Flex(Meta("?")))))
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hot take: flexes should be called flices