Skip to content

Instantly share code, notes, and snippets.

@re-xyr
Last active March 15, 2025 01:17
Show Gist options
  • Save re-xyr/c069436906a2070bb96e64203511c682 to your computer and use it in GitHub Desktop.
Save re-xyr/c069436906a2070bb96e64203511c682 to your computer and use it in GitHub Desktop.
import Unification.Companion.convertibleTo
import Unification.Companion.unify
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
// Refs: everything a term can point at.

/** Anything a [Term] can refer to: locals, metas, classes and instances. */
sealed interface Def

/** A rigid (bound) variable, rendered as its bare name. */
class Local(val name: String) : Def {
    override fun toString(): String = name
}

/**
 * A metavariable. [solution] is mutable and may be set and cleared while searching.
 * Every construction takes the next [id] from [metaCounter].
 */
class Meta(val name: String, var solution: Term? = null) : Def {
    val id: Int = metaCounter++

    override fun toString(): String = "(?$name($id) := $solution)"

    companion object {
        /** Id handed to the next constructed [Meta]. */
        var metaCounter: Int = 0
    }
}

/** A type class with its parameter list. Compared by identity (no `equals` override). */
class Class(val name: String, val params: List<Param>) : Def {
    override fun toString(): String = "class $name $params"
}

/**
 * An instance declaration: parameters [params], instance-argument [premises], and the
 * class application [cls] it provides. Compared by identity (no `equals` override).
 */
class Instance(
    val name: String,
    val params: List<Param>,
    val premises: List<Premise>,
    val cls: CallCls,
) : Def {
    override fun toString(): String = "inst $name $params $premises: $cls"
}
// Terms

/** The term language: class/instance applications, metas, locals and `Type`. */
sealed interface Term

/** A class applied to arguments, printed like `R[A, B]`. */
data class CallCls(val cls: Class, val args: List<Term>) : Term {
    override fun toString(): String = cls.name + args.toString()
}

/** An instance applied to its parameter arguments and to the instance applications solving its premises. */
data class CallInst(val inst: Instance, val args: List<Term>, val premises: List<CallInst>) : Term {
    override fun toString(): String = inst.name + args.toString() + premises.toString()
}

/** An occurrence of a metavariable; prints the meta together with its current solution. */
data class Flex(val ref: Meta) : Term {
    override fun toString() = "$ref"
}

/** An occurrence of a rigid local variable. */
data class Rigid(val ref: Local) : Term {
    override fun toString() = "$ref"
}

/** The universe of types. */
object Type : Term {
    override fun toString() = "Type"
}
// Params

/** An explicit parameter `(ref : type)` of a class or an instance. */
data class Param(val ref: Local, val type: Term) {
    override fun toString() = "(" + ref + " : " + type + ")"
}

/** An instance-argument premise `(ref : type)`; its [type] is resolved by instance search. */
data class Premise(val ref: Local, val type: CallCls) {
    override fun toString() = "(" + ref + " : " + type + ")"
}
// Substitution

/** Replaces every occurrence of [from] in this term by [to]. */
fun Term.subst(from: Def, to: Term): Term = subst(mapOf(from to to))

/**
 * Simultaneously replaces each [Flex]/[Rigid] whose ref is a key of [mapping] by the mapped term.
 * Replacement terms are not substituted again, and meta solutions are not consulted.
 */
fun Term.subst(mapping: Map<out Def, Term>): Term = when (this) {
    Type -> Type
    is Rigid -> mapping[ref] ?: this
    is Flex -> mapping[ref] ?: this
    is CallCls -> CallCls(cls, args.map { arg -> arg.subst(mapping) })
    is CallInst -> CallInst(
        inst,
        args.map { arg -> arg.subst(mapping) },
        premises.map { premise -> premise.subst(mapping) as CallInst },
    )
}
// Zonk

/**
 * Replaces every solved meta in this term by its solution, following solution chains
 * (`?a := ?b`, `?b := A` zonks `?a` to `A`) so that no solved meta remains.
 * Unsolved metas are kept as [Flex].
 *
 * Assumes meta solutions are acyclic; a cyclic chain would not terminate.
 */
fun Term.zonk(): Term = when (this) {
    is CallCls -> copy(args = args.map { it.zonk() })
    is CallInst -> copy(args = args.map { it.zonk() }, premises = premises.map { it.zonk() as CallInst })
    // Zonk the solution too: unfolding a single level could leave another solved meta behind.
    is Flex -> ref.solution?.zonk() ?: this
    is Rigid -> this
    Type -> this
}
// Get all unresolved metas in a term

/**
 * Returns the unsolved metas occurring in this term, each listed once in order of first occurrence.
 *
 * Solved metas are looked through. If `?z := ?h` and `?h` is unsolved, `?h` is reported, so a
 * caller that zonks the term and renames these metas covers every meta left in the zonked term.
 */
fun Term.unsolvedMetas(): List<Meta> {
    val found = mutableSetOf<Meta>() // insertion-ordered; Meta uses identity equality
    fun collect(term: Term) {
        when (term) {
            is CallCls -> term.args.forEach { collect(it) }
            is CallInst -> {
                term.args.forEach { collect(it) }
                term.premises.forEach { collect(it) }
            }
            is Flex -> {
                val solution = term.ref.solution
                if (solution != null) collect(solution) else found.add(term.ref)
            }
            is Rigid, Type -> Unit
        }
    }
    collect(this)
    return found.toList()
}
// Reduction

/**
 * Normalizes this term with respect to the current meta solutions. Every solved meta, including
 * one whose solution is itself a solved meta, is replaced by its reduced solution.
 *
 * Unification relies on this: a solved meta left at the head of a term would be treated as an
 * unsolved one and bound again. Assumes meta solutions are acyclic.
 */
fun Term.reduce(): Term = when (this) {
    is Flex -> ref.solution?.reduce() ?: this
    is CallCls -> copy(args = args.map { it.reduce() })
    is CallInst -> copy(args = args.map { it.reduce() }, premises = premises.map { it.reduce() as CallInst })
    is Rigid -> this
    Type -> this
}
// Unification

/**
 * A single unification or convertibility check between two terms.
 *
 * Instances are created only by the companion's infix [unify] and [convertibleTo], which return
 * the unifier on success and `null` on failure. Bindings are recorded in [flex]. This class only
 * reads [Meta.solution] (through [reduce]) and never writes it.
 */
class Unification private constructor() {
    private val flexes: MutableMap<Meta, Term> = mutableMapOf()

    /** Meta bindings found by the check; each meta is bound at most once. */
    val flex: Map<Meta, Term> get() = flexes

    /**
     * Structural comparison shared by both modes, after reducing both sides by current solutions.
     *
     * Matches when both sides are one of:
     * - the same class (or instance) with argument lists of equal length related pairwise by [toCall];
     * - the same meta;
     * - the same local;
     * - `Type`.
     *
     * Premises of instance applications are not compared.
     */
    private fun basicUnify(preInst: Term, preProb: Term, toCall: (Term, Term) -> Boolean): Boolean {
        val inst = preInst.reduce()
        val prob = preProb.reduce()
        return when {
            inst is CallCls && prob is CallCls -> inst.cls == prob.cls && argsMatch(inst.args, prob.args, toCall)
            inst is CallInst && prob is CallInst -> inst.inst == prob.inst && argsMatch(inst.args, prob.args, toCall)
            inst is Flex && prob is Flex -> inst.ref === prob.ref
            inst is Rigid && prob is Rigid -> inst.ref === prob.ref
            else -> inst is Type && prob is Type
        }
    }

    /** Pairwise [toCall] over two argument lists; lists of different lengths never match. */
    private fun argsMatch(lhs: List<Term>, rhs: List<Term>, toCall: (Term, Term) -> Boolean): Boolean =
        lhs.size == rhs.size && lhs.indices.all { i -> toCall(lhs[i], rhs[i]) }

    /**
     * Reduces [term] and follows the bindings already recorded in [flexes].
     * If the result is a [Flex], that meta is unbound in this unifier.
     */
    private fun resolve(term: Term): Term {
        var current = term.reduce()
        while (current is Flex) {
            current = flexes[current.ref]?.reduce() ?: break
        }
        return current
    }

    // This version of unification finds solutions to metas on either side.
    // Used in finding matching instances.
    private fun unify(preInst: Term, preProb: Term): Boolean {
        // Resolving first means an already-bound meta is compared through its binding with full
        // unification, so `R[?X, ?X]` vs `R[A, ?b]` binds `?b := A`. It also never binds a meta
        // that is already bound, which keeps the bindings free of cycles.
        val inst = resolve(preInst)
        val prob = resolve(preProb)
        if (basicUnify(inst, prob, this::unify)) return true
        return when {
            inst is Flex -> {
                flexes[inst.ref] = prob
                true
            }
            prob is Flex -> {
                flexes[prob.ref] = inst
                true
            }
            else -> false
        }
    }

    // This version of unification treats unsolved metas as equal up to a consistent renaming.
    // Used in finding table entries.
    private fun convertible(preInst: Term, preProb: Term): Boolean {
        val inst = preInst.reduce()
        val prob = preProb.reduce()
        if (basicUnify(inst, prob, this::convertible)) return true
        if (inst !is Flex || prob !is Flex) return false
        val bound = flexes[inst.ref]
        return when {
            bound != null -> basicUnify(bound, prob, this::convertible)
            // The renaming must be injective. Otherwise `R[?a, ?b]` would count as a variant of
            // `R[?x, ?x]`, and answers of the more general entry would be reused for it.
            flexes.values.any { it is Flex && it.ref === prob.ref } -> false
            else -> {
                flexes[inst.ref] = prob
                true
            }
        }
    }

    companion object {
        /** Unifies this term with [rhs], solving metas on either side; `null` if they do not unify. */
        infix fun Term.unify(rhs: Term): Unification? {
            val u = Unification()
            return if (u.unify(this, rhs)) u else null
        }

        /**
         * Checks that this term and [rhs] are equal up to an injective renaming of unsolved metas.
         * The returned [flex] maps this term's metas to [rhs]'s; `null` if they are not convertible.
         */
        infix fun Term.convertibleTo(rhs: Term): Unification? {
            val u = Unification()
            return if (u.convertible(this, rhs)) u else null
        }
    }
}
// Instance Search Forest

/** A node of the instance-search forest. */
sealed interface Node {
    val children: List<Node>

    /** Every [Result] in this subtree: this node's direct [Result] children first, then deeper ones. */
    fun solutions(): List<Result> {
        val direct = children.filterIsInstance<Result>()
        val deeper = children.flatMap { child -> child.solutions() }
        return direct + deeper
    }
}

/**
 * A tabled query. [query] is the renamed ("nominal") query, and [subst] maps each meta that was
 * unsolved in the original query to the fresh meta standing for it in [query].
 */
data class Generator(
    val query: CallCls,
    val subst: Map<Meta, Flex>,
    override val children: MutableList<Node> = mutableListOf(),
) : Node

/** An instance application [apply] that still waits on [premises]. */
data class Consumer(
    val premises: List<Premise>,
    val apply: CallInst,
    override val children: MutableList<Node> = mutableListOf(),
) : Node

/** A leaf: a complete [solution] plus the meta solutions ([subst]) in force when it was produced. */
data class Result(
    val solution: CallInst,
    val subst: Map<Meta, Term>,
) : Node {
    override val children: List<Node> = emptyList()
}
// Resolution Algorithm

/**
 * Tabled instance resolution over [instances], keyed by the class each instance provides.
 *
 * Each distinct query (up to renaming of unsolved metas) gets one [Generator] in [table]. The
 * generator is tabled before its instances are explored, so a recursive query on the same goal
 * finds it and returns its answers instead of expanding it again.
 *
 * NOTE(review): when the reused generator is still being expanded, only the answers found so far
 * are returned. Answers it finds later are not delivered to that consumer.
 */
class Querier(
    val instances: Map<Class, List<Instance>>,
) {
    private val table: MutableList<Generator> = mutableListOf()

    /**
     * Returns the solutions of [query], with each [Result.solution] and [Result.subst] expressed in
     * terms of [query]'s own metas.
     *
     * @param layer indentation depth of the trace printed to stdout.
     * @throws IllegalStateException if a freshly tabled query is not convertible back to [query].
     */
    fun query(query: CallCls, layer: Int = 0): List<Result> {
        // Check if exists. Keep the renaming found by the lookup instead of recomputing it.
        val hit = table.firstNotNullOfOrNull { gen -> (gen.query convertibleTo query)?.let { gen to it } }
        val (existing, renaming) = if (hit != null) {
            println("${pad(layer)}Existing query: $query")
            hit
        } else {
            val gen = generate(query, layer)
            gen to checkNotNull(gen.query convertibleTo query) {
                "nominal query ${gen.query} is not convertible to $query"
            }
        }
        val mapping = renaming.flex
        val realSols = existing.solutions().map { sol -> sol.rebase(existing.subst, mapping) }
        println("${pad(layer)}Finished $query:\n ${pad(layer)}$realSols")
        return realSols
    }

    /**
     * Tables a new [Generator] for [query] and expands it with every instance of the queried class.
     *
     * Metas can change later, but table entries must reflect their values at this time. So solved
     * metas are zonked away and unsolved ones are replaced by fresh metas, and this renaming is
     * stored in [Generator.subst]. The search itself still runs on the original [query].
     */
    private fun generate(query: CallCls, layer: Int): Generator {
        val mapping = query.unsolvedMetas().associateWith { Flex(Meta(it.name)) }
        val nominalQuery = query.zonk().subst(mapping) as CallCls
        println("${pad(layer)}Generated: $nominalQuery")
        val gen = Generator(nominalQuery, mapping)
        // Table before exploring, so recursive sub-queries on the same goal hit this entry.
        table += gen
        for (inst in instances.getOrDefault(query.cls, listOf())) {
            inst.match(query, layer + 1)?.let { gen.children += it }
        }
        return gen
    }

    /**
     * Translates this result of a generator into the caller's metas. [nominal] first maps the
     * generator's original metas to its nominal ones. [renaming] then maps the nominal metas to
     * the caller's query metas, which handles alpha-equivalence.
     */
    private fun Result.rebase(nominal: Map<Meta, Flex>, renaming: Map<Meta, Term>): Result = copy(
        solution = solution.subst(nominal).subst(renaming) as CallInst,
        subst = subst
            .mapKeys { nominal[it.key]?.ref ?: it.key }
            .mapValues { it.value.subst(nominal) }
            .mapKeys { (renaming[it.key] as? Flex)?.ref ?: it.key }
            .mapValues { it.value.subst(renaming) },
    )

    /**
     * Tries this instance against [q]. Each parameter gets a fresh meta; the instance's conclusion
     * is unified with [q], and its premises become subgoals.
     *
     * @return the resulting search node, or `null` when the conclusion does not unify with [q].
     */
    fun Instance.match(q: CallCls, layer: Int): Node? {
        // Create metas for parameters
        val mapping = params.associate { it.ref to Flex(Meta(it.ref.name)) }
        // Subst metas for params in premises to get subgoals
        val subgoals = premises.map { it.copy(type = it.type.subst(mapping) as CallCls) }
        // Also subst metas for the result type, and unify with the query.
        val metaCls = cls.subst(mapping) as CallCls
        val unifier = metaCls unify q ?: return null
        // These are temporary solutions to metas (specific to the current consumer node).
        return withMetas(unifier.flex) {
            consume(subgoals, CallInst(this, params.map { mapping.getValue(it.ref) }, listOf()), layer)
        }
    }

    /**
     * Solves [subgoals] left to right, extending [apply] with each solved premise.
     *
     * @return a [Result] when no subgoal is left. Otherwise a [Consumer] with one child per
     *   solution of the first subgoal.
     */
    fun consume(subgoals: List<Premise>, apply: CallInst, layer: Int): Node {
        if (subgoals.isEmpty()) {
            // No subgoal left: a Result node is produced with the current meta solutions.
            println("${pad(layer)}Result: $apply")
            return Result(apply, solutionPool.toMap())
        }
        println("${pad(layer)}Consume: $subgoals")
        val cons = Consumer(subgoals, apply)
        // Solve the foremost subgoal, then the rest under each of its solutions.
        val subgoal = subgoals.first()
        val remaining = subgoals.drop(1)
        for (res in query(subgoal.type, layer + 1)) {
            withMetas(res.subst) {
                cons.children += consume(
                    // Substitute the solved premise into the rest of the telescope.
                    remaining.map { it.copy(type = it.type.subst(subgoal.ref, res.solution) as CallCls) },
                    // Add the solved subgoal to the premise list.
                    apply.copy(premises = apply.premises + res.solution),
                    layer + 1,
                )
            }
        }
        return cons
    }

    /** Indentation prefix for trace output at [layer]. */
    private fun pad(layer: Int): String = " ".repeat(layer)
}
/**
 * Meta solutions currently installed by active [withMetas] calls, keyed by meta.
 * A meta can be installed again by a nested call; see [withMetas].
 */
val solutionPool: MutableMap<Meta, Term> = mutableMapOf()

/**
 * Installs the solutions in [rep] (on [Meta.solution] and in [solutionPool]) while running [f].
 * Afterwards, even if [f] throws, both each meta's previous solution and its previous pool entry
 * are restored.
 *
 * Nested calls can re-solve a meta that an enclosing call already solved, because a `Result.subst`
 * is a snapshot of the whole pool. The previous pool entry is therefore restored rather than
 * removed; removing it would make the outer solution vanish from the pool while the meta itself
 * stays solved.
 *
 * @return the value of [f].
 */
fun <T> withMetas(rep: Map<Meta, Term>, f: () -> T): T {
    val previousSolutions = rep.keys.associateWith { it.solution }
    val previousPool = rep.keys.associateWith { solutionPool[it] }
    for ((ref, tm) in rep) {
        ref.solution = tm
        solutionPool[ref] = tm
    }
    try {
        return f()
    } finally {
        for (ref in rep.keys) {
            ref.solution = previousSolutions[ref]
            val old = previousPool[ref]
            if (old == null) solutionPool -= ref else solutionPool[ref] = old
        }
    }
}
// ====== TESTS ======
class QuerierTest {
    /**
     * Resolves two queries against a small reachability example. The facts are `R A B`, `R A C`
     * and `R C D`. The recursive rule is `I4 {X Y Z} (p : R X Y) (q : R Y Z) : R X Z`.
     */
    @Test
    fun x() {
        // class R (M N : Type)
        val R = Class("R", listOf(Param(Local("M"), Type), Param(Local("N"), Type)))
        val A = Local("A")
        val B = Local("B")
        val C = Local("C")
        val D = Local("D")
        val X = Local("X")
        val Y = Local("Y")
        val Z = Local("Z")
        val instances = mapOf(
            R to listOf(
                // instance I1 : R A B
                Instance("I1", listOf(), listOf(), CallCls(R, listOf(Rigid(A), Rigid(B)))),
                // instance I2 : R A C
                Instance("I2", listOf(), listOf(), CallCls(R, listOf(Rigid(A), Rigid(C)))),
                // instance I3: R C D
                Instance("I3", listOf(), listOf(), CallCls(R, listOf(Rigid(C), Rigid(D)))),
                // instance I4 {X Y Z : Type} (p : R X Y) (q: R Y Z) : R X Z
                Instance(
                    "I4",
                    listOf(Param(X, Type), Param(Y, Type), Param(Z, Type)),
                    listOf(
                        Premise(Local("p"), CallCls(R, listOf(Rigid(X), Rigid(Y)))),
                        Premise(Local("q"), CallCls(R, listOf(Rigid(Y), Rigid(Z)))),
                    ),
                    CallCls(R, listOf(Rigid(X), Rigid(Z))),
                ),
            ),
        )

        // query R A D: D is reached from A via A -> C (I2), then C -> D (I3).
        val closed = Querier(instances).query(CallCls(R, listOf(Rigid(A), Rigid(D))))
        assertTrue(
            closed.any { it.solution.inst.name == "I4" && it.solution.premises.map { p -> p.inst.name } == listOf("I2", "I3") },
            "expected I4 applied to I2 and I3 among $closed",
        )

        // query R A ??: every node reachable from A is found, with ?? solved accordingly.
        // A fresh Querier is used so no table entries carry over from the first query.
        val hole = Meta("?")
        val open = Querier(instances).query(CallCls(R, listOf(Rigid(A), Flex(hole))))
        for ((instName, target) in listOf("I1" to B, "I2" to C, "I4" to D)) {
            assertTrue(
                open.any { it.solution.inst.name == instName && it.subst[hole] == Rigid(target) },
                "expected $instName with ?? := ${target.name} among $open",
            )
        }
    }
}
@ice1000
Copy link

ice1000 commented Jun 17, 2021

Hot take: flexes should be called flices

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment