Skip to content

Instantly share code, notes, and snippets.

@makenowjust
Created August 15, 2024 12:18
Show Gist options
  • Save makenowjust/924e04c33d9760b6194101e4ce26de48 to your computer and use it in GitHub Desktop.
Save makenowjust/924e04c33d9760b6194101e4ce26de48 to your computer and use it in GitHub Desktop.
import scala.util.boundary
import scala.collection.mutable
import scala.compiletime.ops.double
import scala.annotation.tailrec
/** A Mealy machine over states `Q`, inputs `I` and outputs `O`.
  *
  * @param initial    the state every run starts from
  * @param transition maps a (state, input) pair to the emitted output and the successor state
  */
final case class Mealy[Q, I, O](
    initial: Q,
    transition: Map[(Q, I), (O, Q)]
):
  type State = Q

  /** Feeds `is` to the machine starting from [[initial]].
    *
    * @param is the input word, consumed left to right
    * @return the outputs in emission order, paired with the state reached at the end
    * @throws NoSuchElementException if `transition` has no entry for a visited (state, input) pair
    */
  def run(is: Seq[I]): (Seq[O], Q) =
    // Outputs are accumulated in reverse (cheap prepend) and flipped once at the end.
    val (reversedOutputs, last) = is.foldLeft((List.empty[O], initial)): (acc, i) =>
      val (outputs, state) = acc
      val (o, next) = transition((state, i))
      (o :: outputs, next)
    (reversedOutputs.reverse, last)
/** Interface to the system under learning, with inputs `I` and outputs `O`. */
trait SUL[I, O]:
  /** Runs the input word `is` on the system and returns the observed outputs.
    *
    * `LSharp.outputQuery` zips the result with `is`, so one output per input
    * symbol (in the same order) is expected.
    */
  def trace(is: Seq[I]): Seq[O]
  /** Looks for an input word on which the hypothesis `h` disagrees with the system.
    *
    * `LSharp.applyRule4` treats `None` as "no counterexample" and returns `h`
    * as the learned machine.
    */
  def findCounterExample(h: Mealy[?, I, O]): Option[Seq[I]]
  /** The input symbols; `LSharp` tries every one of them from each basis state. */
  def inputAlphabet: Set[I]
  /** The output symbols. NOTE(review): not read by any code visible in this file. */
  def outputAlphabet: Set[O]
/** An immutable observation tree: each edge is labelled with an input and carries
  * the output observed for it together with the subtree reached afterwards.
  *
  * @param edges outgoing edges of this node, keyed by input symbol
  */
final case class OTree[I, O](
    edges: Map[I, (O, OTree[I, O])] = Map.empty[I, (O, OTree[I, O])]
):
  /** Follows `is` from this node.
    *
    * Iterative on purpose: the walk uses constant stack regardless of the word
    * length and never calls `tail`, which is linear on indexed sequences.
    *
    * @param is the input word to follow
    * @return the node reached, or `None` as soon as an edge is missing
    */
  def get(is: Seq[I]): Option[OTree[I, O]] =
    val inputs = is.iterator
    var node: Option[OTree[I, O]] = Some(this)
    while node.isDefined && inputs.hasNext do
      val i = inputs.next()
      node = node.flatMap(_.edges.get(i)).map(_._2)
    node

  /** Returns a tree that additionally contains the path `ios`.
    *
    * Existing subtrees along the path are kept; the output stored on each
    * traversed edge is replaced by the one given in `ios`.
    *
    * @param ios the (input, output) pairs of the observed trace, in order
    * @return the updated tree; `this` itself when `ios` is empty
    */
  def inserted(ios: Seq[(I, O)]): OTree[I, O] =
    val blank = OTree[I, O]()
    // spine(k) is the node currently reached after the first k inputs, or an
    // empty node once the path leaves the existing tree.
    val spine = ios.scanLeft[OTree[I, O]](this): (node, io) =>
      node.edges.get(io._1).fold(blank)(_._2)
    // Rebuild bottom-up: start from the deepest node and re-attach each parent.
    // `foldRight` runs over a reversed copy, so no recursion per symbol happens.
    ios.zip(spine).foldRight(spine.last): (step, child) =>
      val ((i, o), parent) = step
      OTree(parent.edges.updated(i, (o, child)))

  /** Searches breadth-first for an input word on which the two trees record
    * different outputs.
    *
    * Only inputs present in both trees are followed.
    *
    * @param that the tree to compare against
    * @return a separating word of minimal length, or `None` when the observations
    *         recorded in both trees never disagree
    */
  infix def apart(that: OTree[I, O]): Option[Seq[I]] =
    val pending = mutable.Queue.empty[(Seq[I], OTree[I, O], OTree[I, O])]
    pending.enqueue((Seq.empty, this, that))
    boundary:
      while pending.nonEmpty do
        val (prefix, left, right) = pending.dequeue()
        for
          i <- left.edges.keySet ++ right.edges.keySet
          // `zip` on the two lookups yields a value only when both sides have the edge.
          ((o1, leftChild), (o2, rightChild)) <- left.edges.get(i) zip right.edges.get(i)
        do
          if o1 != o2 then boundary.break(Some(prefix :+ i))
          else pending.enqueue((prefix :+ i, leftChild, rightChild))
      None
/** Active learner that builds a Mealy machine for `sul` by growing an observation
  * tree and repeatedly applying four rules until a hypothesis survives the
  * equivalence check. (The class name suggests the L# algorithm.)
  *
  * Every step is traced to standard output.
  *
  * @param sul the system under learning
  */
final class LSharp[I, O](val sul: SUL[I, O]):
  /** A state is identified by an input word leading to it in the observation tree. */
  type S = Seq[I]

  /** All observations made so far; replaced on every output query. */
  var root: OTree[I, O] = OTree()
  /** Words that have been promoted to states of the hypothesis. */
  val basis = mutable.Set.empty[S]
  /** For each frontier word, the basis states it has not been shown apart from. */
  val frontier = mutable.Map.empty[S, Set[S]]

  /** Subtree of `root` at `s`; throws `NoSuchElementException` if `s` was never observed. */
  private def subtree(s: S): OTree[I, O] = root.get(s).get

  /** Runs the learning loop and returns the first hypothesis for which neither
    * the consistency check nor `sul.findCounterExample` reports a problem.
    */
  def learn(): Mealy[S, I, O] =
    basis.add(Seq.empty)

    // Announces a rule, runs it, and reports whether it changed anything.
    def traced(before: String, after: String)(rule: => Boolean): Boolean =
      println(before)
      val applied = rule
      if applied then println(after)
      applied

    @tailrec
    def loop(): Mealy[S, I, O] =
      println("Update frontier")
      updateFrontier()
      println(s"basis=${basis}, frontier=${frontier}")
      // `||` short-circuits: a later rule is only tried when all earlier ones did nothing.
      val progressed =
        traced("Apply Rule1", "Rule1 is applied.")(applyRule1()) ||
          traced("Apply Rule2", "Rule2 is applied.")(applyRule2()) ||
          traced("Apply Rule3", "Rule3 is applied.")(applyRule3())
      if progressed then loop()
      else
        println("Apply Rule4")
        applyRule4() match
          case Some(h) => h
          case None => loop()

    loop()

  /** Asks the SUL for the outputs of `is` and records the trace in `root`. */
  def outputQuery(is: Seq[I]): Unit =
    val observed = sul.trace(is)
    root = root.inserted(is.zip(observed))

  /** Promotes `q` to the basis and offers it as a candidate to every frontier
    * word it is not apart from.
    */
  def addBasis(q: S): Unit =
    println(s"Add ${q} to basis")
    basis.add(q)
    val promoted = subtree(q)
    frontier.mapValuesInPlace: (p, candidates) =>
      if (promoted apart subtree(p)).isDefined then candidates
      else candidates + q

  /** Registers `p` as a frontier word whose candidates are all basis states not apart from it. */
  def addFrontier(p: S): Unit =
    val target = subtree(p)
    frontier(p) = basis.iterator
      .filterNot(q => (subtree(q) apart target).isDefined)
      .toSet

  /** Drops, for every frontier word, the candidates that the current tree shows apart from it. */
  def updateFrontier(): Unit =
    frontier.mapValuesInPlace: (p, candidates) =>
      val target = subtree(p)
      candidates.filterNot(q => (subtree(q) apart target).isDefined)

  /** Builds a hypothesis whose states are the basis words.
    *
    * A transition that leaves the basis is redirected to the first candidate of
    * the frontier word it reaches. Throws if an edge, frontier entry or
    * candidate is missing.
    */
  def buildHypothesis(): Mealy[S, I, O] =
    val transitions =
      for
        q <- basis.iterator
        node = subtree(q)
        i <- sul.inputAlphabet.iterator
      yield
        val (o, _) = node.edges(i)
        val successor = q :+ i
        val target = if basis.contains(successor) then successor else frontier(successor).head
        (q, i) -> (o, target)
    Mealy(Seq.empty[I], transitions.toMap)

  /** Walks the observation tree and `h` in lock-step, breadth-first.
    *
    * @return the first word whose tree node is apart from the subtree of the
    *         hypothesis state reached by that word, or `None` if there is none
    */
  def checkConsistency(h: Mealy[S, I, O]): Option[S] =
    val pending = mutable.Queue.empty[(S, OTree[I, O], S)]
    pending.enqueue((Seq.empty, root, h.initial))
    var conflict = Option.empty[S]
    while conflict.isEmpty && pending.nonEmpty do
      val (word, node, state) = pending.dequeue()
      if (node apart subtree(state)).isDefined then conflict = Some(word)
      else
        for (i, (_, child)) <- node.edges do
          pending.enqueue((word :+ i, child, h.transition((state, i))._2))
    conflict

  /** Rule 1: promotes frontier words that have no candidate left.
    *
    * @return whether any such word existed when the rule started
    */
  def applyRule1(): Boolean =
    val isolated = frontier.iterator
      .collect { case (p, candidates) if candidates.isEmpty => p }
      .toSeq
    // The candidate set is checked again right before each promotion:
    // promoting one word calls `addBasis`, which may hand a candidate to a
    // word that was isolated when the list above was taken.
    isolated.foreach: q =>
      if frontier(q).isEmpty then
        frontier.remove(q)
        addBasis(q)
    isolated.nonEmpty

  /** Rule 2: makes sure every one-symbol extension of a basis word is observed
    * and is either in the basis or on the frontier.
    *
    * @return whether any extension was missing
    */
  def applyRule2(): Boolean =
    val missing = basis.iterator
      .flatMap(q => sul.inputAlphabet.map((q, _)))
      .map((q, c) => q :+ c)
      .filter(p => root.get(p).isEmpty || !(basis.contains(p) || frontier.contains(p)))
      .toSeq
    missing.foreach: p =>
      if root.get(p).isEmpty then outputQuery(p)
      addFrontier(p)
    missing.nonEmpty

  /** Rule 3: for frontier words with two or more candidates, queries the word
    * followed by a separating word of each candidate pair.
    *
    * @return whether any frontier word had two or more candidates
    */
  def applyRule3(): Boolean =
    val ambiguous = frontier.keys
      .filter(frontier(_).size >= 2)
      .toSeq
    for
      p <- ambiguous
      case Seq(q1, q2) <- frontier(p).toSeq.combinations(2)
    do
      val witness = (subtree(q1) apart subtree(q2)).get
      outputQuery(p ++ witness)
    ambiguous.nonEmpty

  /** Rule 4: builds a hypothesis and tests it, first against the observation
    * tree and then against the SUL.
    *
    * @return `Some(h)` when no conflict is found; otherwise the conflict is
    *         processed with `procCounterEx` and `None` is returned
    */
  def applyRule4(): Option[Mealy[S, I, O]] =
    val h = buildHypothesis()
    // The SUL is only consulted when the tree itself shows no conflict.
    val conflict = checkConsistency(h)
      .orElse(sul.findCounterExample(h).map(conflictingPrefix(h, _)))
    conflict match
      case Some(w) =>
        procCounterEx(h, w)
        None
      case None => Some(h)

  /** Records the counterexample `w` in the tree if needed and cuts it after the
    * first position where the tree node is apart from the hypothesis state.
    */
  private def conflictingPrefix(h: Mealy[S, I, O], w: Seq[I]): Seq[I] =
    println(s"counterexample=${w}")
    if root.get(w).isEmpty then outputQuery(w)
    var node = root
    var state = h.initial
    val cut = w.indices.find: k =>
      val i = w(k)
      node = node.edges.get(i).get._2
      state = h.transition((state, i))._2
      (node apart subtree(state)).isDefined
    w.slice(0, cut.get + 1)

  /** Shortens the conflicting word `w` by repeated halving until it is a basis
    * or frontier word, issuing one output query per step.
    */
  @tailrec
  def procCounterEx(h: Mealy[S, I, O], w: Seq[I]): Unit =
    println(s"h=${h}, w=${w}")
    if !(basis.contains(w) || frontier.contains(w)) then
      val frontierPrefix = frontier.keySet.find(s => w.startsWith(s)).get
      val mid = Math.floorDiv(w.length + frontierPrefix.length, 2)
      val (w1, w2) = w.splitAt(mid)
      val (_, r1) = h.run(w1)
      // Read before the query below replaces `root`; the comparison further
      // down deliberately uses this earlier snapshot.
      val t1 = subtree(w1)
      val (_, r) = h.run(w)
      val t = subtree(w)
      val witness = (subtree(r) apart t).get
      outputQuery(r1 ++ w2 ++ witness)
      if (subtree(r1) apart t1).isDefined then procCounterEx(h, w1)
      else procCounterEx(h, r1 ++ w2)
/** Demo: learns a machine over the alphabet {'0', '1'} and prints the result
  * together with the number of distinct traces and equivalence queries used.
  */
@main
def main(): Unit =
  var count = 0
  val cache = mutable.Map.empty[Seq[Char], Seq[Int]]

  val sul = new SUL[Char, Int]:
    /** 1 when the number of '1's is 3 modulo 4 and a '0' occurs, else 0. */
    def check(is: Seq[Char]): Int =
      val ones = is.count(_ == '1')
      if ones % 4 == 3 && is.contains('0') then 1 else 0

    /** Output for every non-empty prefix of `is`; each distinct word is computed once. */
    def trace(is: Seq[Char]): Seq[Int] =
      cache.getOrElseUpdate(is, is.indices.map(k => check(is.take(k + 1))))

    /** Tries the binary representations of 0..1024 and returns the first one
      * whose last hypothesis output differs from `check`.
      */
    def findCounterExample(h: Mealy[?, Char, Int]): Option[Seq[Char]] =
      count += 1
      Iterator
        .range(0, 1025)
        .map(k => k.toBinaryString.toSeq)
        .find(w => h.run(w)._1.last != check(w))

    def inputAlphabet: Set[Char] = Set('0', '1')
    def outputAlphabet: Set[Int] = Set(0, 1)

  val learned = new LSharp[Char, Int](sul).learn()
  println(learned)
  println()
  println(s"#MEMBER = ${cache.size}")
  println(s"#EQUIV = ${count}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment