Created
August 15, 2024 12:18
-
-
Save makenowjust/924e04c33d9760b6194101e4ce26de48 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.util.boundary | |
import scala.collection.mutable | |
import scala.compiletime.ops.double | |
import scala.annotation.tailrec | |
/** A deterministic Mealy machine described by an explicit transition table.
  *
  * @param initial    state the machine starts in
  * @param transition maps (current state, input) to (emitted output, next state)
  */
final case class Mealy[Q, I, O](
    initial: Q,
    transition: Map[(Q, I), (O, Q)]
):
  type State = Q

  /** Feeds `is` to the machine starting from `initial`.
    *
    * @return the outputs emitted (one per input, in order) and the state reached
    * @throws NoSuchElementException if some (state, input) pair has no transition
    */
  def run(is: Seq[I]): (Seq[O], Q) =
    is.foldLeft((Vector.empty[O], initial)) { case ((emitted, current), input) =>
      val (out, next) = transition((current, input))
      (emitted :+ out, next)
    }
/** The system under learning: the black box the learner queries. */
trait SUL[I, O]:
  /** Input symbols the learner may feed to the system. */
  def inputAlphabet: Set[I]

  /** Output symbols the system may emit. */
  def outputAlphabet: Set[O]

  /** Output query: the outputs observed when feeding `is` to the system.
    * The learner zips the result with `is`, so presumably one output per input.
    */
  def trace(is: Seq[I]): Seq[O]

  /** Equivalence query: an input word on which hypothesis `h` disagrees with
    * the system, or `None` when no disagreement is found (the learner then
    * accepts `h`).
    */
  def findCounterExample(h: Mealy[?, I, O]): Option[Seq[I]]
/** An immutable observation tree: each edge is labelled with an input and the
  * output observed for it, and leads to the subtree for the extended word.
  *
  * @param edges outgoing edges keyed by input symbol
  */
final case class OTree[I, O](
    edges: Map[I, (O, OTree[I, O])] = Map.empty[I, (O, OTree[I, O])]
):
  /** The subtree reached by following `is` from this node, or `None` if some
    * edge on the way is missing. The empty word yields this node itself.
    */
  def get(is: Seq[I]): Option[OTree[I, O]] =
    is.foldLeft(Option[OTree[I, O]](this)): (node, input) =>
      node.flatMap(_.edges.get(input).map(_._2))

  /** Returns a copy of this tree extended with the input/output path `ios`.
    * Existing edges are kept, but their output label is overwritten by the one in `ios`.
    */
  def inserted(ios: Seq[(I, O)]): OTree[I, O] =
    ios match
      case (input, output) +: rest =>
        val child = edges.get(input).fold(OTree[I, O]())(_._2)
        OTree(edges.updated(input, (output, child.inserted(rest))))
      case _ => this

  /** Searches breadth-first for a word defined in both trees on which their
    * outputs differ.
    *
    * @return the first such word found (a shortest one, by BFS), or `None` if the
    *         trees agree on every word they share
    */
  infix def apart(that: OTree[I, O]): Option[Seq[I]] =
    val pending = mutable.Queue.empty[(Seq[I], OTree[I, O], OTree[I, O])]
    pending.enqueue((Seq.empty, this, that))
    boundary:
      while pending.nonEmpty do
        val (prefix, left, right) = pending.dequeue()
        for
          input <- left.edges.keySet ++ right.edges.keySet
          (outLeft, nextLeft) <- left.edges.get(input)
          (outRight, nextRight) <- right.edges.get(input)
        do
          val word = prefix :+ input
          if outLeft != outRight then boundary.break(Some(word))
          pending.enqueue((word, nextLeft, nextRight))
      None
/** Active learner for Mealy machines following the four rules of the L# algorithm.
  *
  * State kept between rounds:
  *  - `root`: observation tree holding every output query answered so far;
  *  - `basis`: access sequences promoted to hypothesis states (rule 1 only promotes
  *    states apart from every basis state, so basis states are pairwise apart);
  *  - `frontier`: one-step successors of basis states, each mapped to the basis
  *    states it is not yet known to be apart from.
  *
  * Progress is logged to stdout with `println`.
  *
  * @param sul the system under learning answering output and equivalence queries
  */
final class LSharp[I, O](val sul: SUL[I, O]):
  /** An access sequence naming a node of the observation tree. */
  type S = Seq[I]

  var root: OTree[I, O] = OTree()
  val basis = mutable.Set.empty[S]
  val frontier = mutable.Map.empty[S, Set[S]]

  /** Runs the learning loop until an equivalence query finds no counterexample.
    *
    * Each round refreshes the frontier, then tries rules 1-3 in order and stops at
    * the first one that makes progress; when none does, rule 4 builds a hypothesis
    * and either returns it or processes a counterexample and starts a new round.
    *
    * @return the accepted hypothesis, whose states are basis access sequences
    */
  def learn(): Mealy[S, I, O] =
    basis.add(Seq.empty)

    @tailrec
    def loop(): Mealy[S, I, O] =
      println("Update frontier")
      updateFrontier()
      println(s"basis=${basis}, frontier=${frontier}")
      // `||` short-circuits: a later rule is tried only if earlier ones made no change.
      val progressed =
        tryRule("Rule1", applyRule1()) ||
          tryRule("Rule2", applyRule2()) ||
          tryRule("Rule3", applyRule3())
      if progressed then loop()
      else
        println("Apply Rule4")
        applyRule4() match
          case Some(h) => h
          case None    => loop()

    loop()

  /** Logs an attempt at rule `name`, evaluates `rule`, and logs if it made progress. */
  private def tryRule(name: String, rule: => Boolean): Boolean =
    println(s"Apply ${name}")
    val applied = rule
    if applied then println(s"${name} is applied.")
    applied

  /** Tree node for access sequence `s`; fails with a descriptive error if absent. */
  private def nodeOf(s: S): OTree[I, O] =
    root.get(s).getOrElse(sys.error(s"no observation-tree node for access sequence ${s}"))

  /** Asks the SUL for the outputs of `is` and records the path in `root`.
    *
    * @throws IllegalArgumentException if the SUL returns a different number of
    *         outputs than inputs (which `zip` would otherwise silently truncate)
    */
  def outputQuery(is: Seq[I]): Unit =
    val os = sul.trace(is)
    require(os.length == is.length, s"SUL returned ${os.length} outputs for ${is.length} inputs")
    root = root.inserted(is.zip(os))

  /** Promotes `q` to the basis and adds it as a candidate to every frontier state
    * it is not apart from. The caller removes `q` from the frontier beforehand.
    */
  def addBasis(q: S): Unit =
    println(s"Add ${q} to basis")
    basis.add(q)
    val qt = nodeOf(q)
    frontier.mapValuesInPlace: (p, qs) =>
      if (qt apart nodeOf(p)).isEmpty then qs + q else qs

  /** Registers `p` in the frontier with every basis state it is not apart from. */
  def addFrontier(p: S): Unit =
    val pt = nodeOf(p)
    frontier(p) = basis.iterator.filter(q => (nodeOf(q) apart pt).isEmpty).toSet

  /** Drops, for every frontier state, the candidates it is now known to be apart from. */
  def updateFrontier(): Unit =
    frontier.mapValuesInPlace: (p, qs) =>
      val pt = nodeOf(p)
      qs.filter(q => (nodeOf(q) apart pt).isEmpty)

  /** Builds a hypothesis whose states are the basis access sequences.
    *
    * An edge leaving basis state `q` on input `i` carries the output observed in
    * the tree and leads to `q :+ i` if that is a basis state, otherwise to the
    * (first) remaining candidate of frontier state `q :+ i`. Requires rules 1-3 to
    * be saturated so every such successor exists in the tree and has a candidate.
    */
  def buildHypothesis(): Mealy[S, I, O] =
    val transition = Map.newBuilder[(S, I), (O, S)]
    for q <- basis do
      val qt = nodeOf(q)
      for i <- sul.inputAlphabet do
        val (o, _) = qt.edges(i)
        val succ = q :+ i
        val target = if basis.contains(succ) then succ else frontier(succ).head
        transition.addOne((q, i) -> (o, target))
    Mealy(Seq.empty[I], transition.result())

  /** Breadth-first walk over the whole tree, pairing each node with the hypothesis
    * state reached by its access sequence.
    *
    * @return the first access sequence whose node is apart from its hypothesis
    *         state's node, or `None` if `h` agrees with everything observed
    */
  def checkConsistency(h: Mealy[S, I, O]): Option[S] =
    val queue = mutable.Queue.empty[(S, OTree[I, O], S)]
    queue.enqueue((Seq.empty, root, h.initial))
    boundary:
      while queue.nonEmpty do
        val (s, t, q) = queue.dequeue()
        if (t apart nodeOf(q)).isDefined then boundary.break(Some(s))
        for (i, (_, child)) <- t.edges do
          queue.enqueue((s :+ i, child, h.transition((q, i))._2))
      None

  /** Rule 1: promotes frontier states isolated from every basis state.
    *
    * @return whether any frontier state was isolated at the start of the call
    */
  def applyRule1(): Boolean =
    val isolatedStates = frontier.iterator
      .collect { case (p, candidates) if candidates.isEmpty => p }
      .toSeq
    // Re-check `frontier(q).isEmpty`: promoting an earlier state may have made it
    // a candidate of `q`, so `q` is no longer isolated.
    for q <- isolatedStates if frontier(q).isEmpty do
      frontier.remove(q)
      addBasis(q)
    isolatedStates.nonEmpty

  /** Rule 2: adds every missing one-step successor of a basis state to the
    * frontier, querying it first if it is not yet in the tree.
    */
  def applyRule2(): Boolean =
    val missing =
      for
        q <- basis.toSeq
        c <- sul.inputAlphabet.toSeq
        p = q :+ c
        if root.get(p).isEmpty || (!basis.contains(p) && !frontier.contains(p))
      yield p
    for p <- missing do
      if root.get(p).isEmpty then outputQuery(p)
      addFrontier(p)
    missing.nonEmpty

  /** Rule 3: for each frontier state with two or more candidates, runs the word
    * separating each pair of candidates from that state, so it gets apart from
    * at least one of them.
    */
  def applyRule3(): Boolean =
    val unidentifiedStates = frontier.iterator
      .collect { case (p, candidates) if candidates.size >= 2 => p }
      .toSeq
    for
      p <- unidentifiedStates
      case Seq(q1, q2) <- frontier(p).toSeq.combinations(2)
    do
      // Basis states are pairwise apart, so a separating word exists.
      val w = (nodeOf(q1) apart nodeOf(q2))
        .getOrElse(sys.error(s"basis states ${q1} and ${q2} are not apart"))
      outputQuery(p ++ w)
    unidentifiedStates.nonEmpty

  /** Rule 4: builds a hypothesis and checks it against the tree, then against the SUL.
    *
    * @return `Some(h)` when neither finds a disagreement; otherwise processes the
    *         evidence with `procCounterEx` and returns `None`
    */
  def applyRule4(): Option[Mealy[S, I, O]] =
    val h = buildHypothesis()
    // By-name argument of `orElse`: the equivalence oracle is only asked when
    // the tree itself contains no evidence against `h`.
    def fromCounterExample: Option[S] =
      sul.findCounterExample(h).map: ce =>
        println(s"counterexample=${ce}")
        apartPrefix(h, ce)
    checkConsistency(h).orElse(fromCounterExample) match
      case Some(w) =>
        procCounterEx(h, w)
        None
      case None => Some(h)

  /** Adds counterexample `ce` to the tree if missing, then returns the shortest
    * non-empty prefix whose tree node is apart from the hypothesis state reached
    * by that prefix.
    */
  private def apartPrefix(h: Mealy[S, I, O], ce: S): S =
    if root.get(ce).isEmpty then outputQuery(ce)
    // Walk tree and hypothesis in lock-step; element n is the pair reached after ce.take(n + 1).
    val walk = ce.iterator
      .scanLeft((root, h.initial)) { case ((t, r), i) => (t.edges(i)._2, h.transition((r, i))._2) }
      .drop(1)
    walk.zipWithIndex
      .collectFirst { case ((t, r), n) if (t apart nodeOf(r)).isDefined => ce.take(n + 1) }
      .getOrElse(sys.error(s"counterexample ${ce} does not separate the hypothesis from the tree"))

  /** Processes a word `w` whose tree node is apart from the hypothesis state it
    * reaches (required: the apartness witness is looked up below), by binary
    * search, until the conflict lands on a basis or frontier state.
    */
  @tailrec
  def procCounterEx(h: Mealy[S, I, O], w: Seq[I]): Unit =
    println(s"h=${h}, w=${w}")
    if !basis.contains(w) && !frontier.contains(w) then
      // The unique frontier state that is a prefix of `w` (basis is prefix-closed).
      val v = frontier.keysIterator
        .find(w.startsWith(_))
        .getOrElse(sys.error(s"no frontier prefix of ${w}"))
      val (w1, w2) = w.splitAt((w.length + v.length) / 2)
      val (_, r1) = h.run(w1)
      val (_, r) = h.run(w)
      val x = (nodeOf(r) apart nodeOf(w))
        .getOrElse(sys.error(s"${w} is not apart from its hypothesis state ${r}"))
      outputQuery(r1 ++ w2 ++ x)
      // Both nodes are read after the query so the check sees the grown tree.
      if (nodeOf(r1) apart nodeOf(w1)).isDefined then procCounterEx(h, w1)
      else procCounterEx(h, r1 ++ w2)
/** Demo: learns the Mealy machine that outputs 1 exactly when the input read so far
  * contains a '0' and its number of '1's is 3 modulo 4, then prints the hypothesis
  * together with the number of distinct output queries and of equivalence queries.
  */
@main
def main(): Unit =
  // Equivalence queries try every word of length 1..maxCounterExampleLength in
  // shortlex order. Enumerating binary numerals instead would never test words
  // with a leading '0' (such as "01"), leaving part of the language unchecked.
  val maxCounterExampleLength = 10
  var count = 0
  val cache = mutable.Map.empty[Seq[Char], Seq[Int]]
  val sul = new SUL[Char, Int]:
    def check(is: Seq[Char]): Int =
      if is.count(_ == '1') % 4 == 3 && is.contains('0') then 1 else 0
    def trace(is: Seq[Char]): Seq[Int] =
      cache.getOrElseUpdate(is, (1 to is.length).map(n => check(is.take(n))))
    def findCounterExample(h: Mealy[?, Char, Int]): Option[Seq[Char]] =
      count += 1
      val letters = inputAlphabet.toSeq.sorted
      Iterator
        .iterate(Seq(Seq.empty[Char]))(words => for w <- words; c <- letters yield w :+ c)
        .slice(1, maxCounterExampleLength + 1)
        .flatten
        .find: w =>
          val (os, _) = h.run(w)
          os.last != check(w)
    def inputAlphabet: Set[Char] = Set('0', '1')
    def outputAlphabet: Set[Int] = Set(0, 1)
  val learner = new LSharp[Char, Int](sul)
  val h = learner.learn()
  println(h)
  println()
  println(s"#MEMBER = ${cache.size}")
  println(s"#EQUIV = ${count}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment