-- TMDA handwriting learning
-- GitHub gist by @aavogt: aavogt/364725b73790fdec19944e9e76622654
-- (created August 5, 2025)
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# LANGUAGE RankNTypes #-}
{-# HLINT ignore "Use foldr" #-}
{-# HLINT ignore "Eta reduce" #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
module Learning where
import Common
import Data.DTW hiding (Path, ix)
import Data.Graph.Inductive.Query.SP (sp)
import Data.Graph.Inductive hiding ((&))
import qualified Data.HashSet as S
import Data.HashSet (HashSet)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Hashable
import Control.Monad.State
import Data.Monoid (Sum)
import Data.Foldable
import Data.Traversable
import Data.Maybe
import Data.Map (Map)
import Linear.V
import qualified Data.Vector as V
import Data.Proxy
import Linear
import Data.Coerce
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import Data.Semigroup
import Data.These
import Data.Align
import Control.Monad.Cont
import Control.Applicative
import Control.Monad.Trans.Maybe
import Debug.Trace
import Control.Monad.Trans.Writer
import Control.Lens.Unsound (adjoin)
import qualified Data.Map as M
import Data.IntSet (IntSet)
import Data.IntMap (IntMap)
import Data.Sequence (Seq)
import GHC.TypeLits (Nat)
import Data.Tuple (swap)
-- | Search state threaded through 'tmda'.
--
-- Paths are stored as @(cost, nodes)@ with the node list reversed: its head
-- is the last node of the path.
data TMDAS n = TMDAS {
  -- | Candidate paths waiting outside the queue, indexed first by the path's
  -- second-to-last node and then by its last node.
  _nqp :: IntMap (IntMap [(V n Double, [Node])]),
  -- | Permanent paths per last node, most recently settled first.
  _perm :: IntMap [(V n Double, [Node])],
  -- | Tentative paths, at most one per last node (see 'Q').
  _queue :: Q n }
-- | Priority queue of paths ordered by cost vector.
--
-- A side index maps a path's last node to its cost, so the path ending on a
-- given 'Node' can be found. For now this assumes Q.getPath never needs to
-- find Nodes in the middle of a path.
--
-- NOTE(review): the 'Map' is keyed by cost alone. Inserting a path whose cost
-- equals that of a queued path ending at a different node overwrites that
-- path in the 'Map', while the node index still points at the shared cost.
data Q n = Q (Map (V n Double) [Node]) (IntMap (V n Double))
  deriving (Eq, Show)

-- | Left-biased union of both components.
instance Semigroup (Q n) where
  (<>) (Q byCost byNode) (Q byCost' byNode') =
    Q (M.union byCost byCost') (IM.union byNode byNode')

instance Monoid (Q n) where
  mempty = Q M.empty IM.empty

instance AsEmpty (Q n)
-- | Empty search state: nothing waiting, nothing permanent, empty queue.
instance Monoid (TMDAS n) where
  mempty = TMDAS IM.empty IM.empty mempty

-- | Field-wise, left-biased combination of two search states.
instance Semigroup (TMDAS n) where
  (<>) (TMDAS nqpA permA queueA) (TMDAS nqpB permB queueB) =
    TMDAS (IM.union nqpA nqpB) (IM.union permA permB) (queueA <> queueB)

makeLenses ''TMDAS
-- | Wrapper that lets vector-valued weights satisfy the 'Real' constraint on
-- fgl's 'spTree'.
--
-- 'Eq', 'Ord' and 'Num' are the wrapped vector's own instances.
-- 'toRational' only looks at the first component, which is "good enough" to
-- satisfy the constraint.
newtype RealV v a = RealV (v a)
  deriving (Num, Eq, Ord, Functor, Applicative)
makeWrapped ''RealV

instance (Real a, Finite v, Num (v a), Ord (v a)) => Real (RealV v a) where
  toRational (RealV xs) = firstComponent (toVector (toV xs))
    where
      firstComponent = toRational . V.head
-- | Dynamic-time-warping distance from @track@ to every entry of a database.
--
-- Despite the name, no neighbour selection happens here. Every database
-- entry is returned, in database order, paired with its distance; take the k
-- smallest to get the k nearest neighbours.
--
-- Pen lifts are ignored: multi-stroke input (Tracks) has to be concatenated
-- into a single 'Track' first.
dtwKNN :: (Eq a, Hashable a) => Track -> [(Track, a)] -- ^ db
       -> [(Double,a)] -- ^ [(Distance, word)]
dtwKNN track db =
  [ (cost (fastDtw dist shrinkMean 2 track reference), label)
  | (reference, label) <- db ]
-- | Translate a track so that its first point lies at the origin.
--
-- Every later point keeps its offset from the first one. An empty track is
-- returned unchanged; previously it hit an incomplete pattern.
shiftTo00 :: Track -> Track
shiftTo00 ((a,b) :< xs) = (0,0) :< fmap (\(c,d) -> (c-a,d-b)) xs
shiftTo00 track = track
-- | Euclidean distance between two points.
--
-- The second term previously read @(c-d)^2@, mixing the other point's two
-- coordinates instead of comparing the y-coordinates.
dist :: P -> P -> Double
dist (a,b) (c,d) = sqrt ((a-c)^2 + (b-d)^2)
-- | Halve a track's resolution by replacing each consecutive pair of points
-- with their midpoint.
--
-- A trailing unpaired point, or an empty track, is kept as is. This is the
-- coarsening step passed to 'fastDtw'.
shrinkMean :: Track -> Track
shrinkMean track = case track of
  (x0, y0) :< ((x1, y1) :< rest) ->
    ((x0 + x1) / 2, (y0 + y1) / 2) :< shrinkMean rest
  _ -> track
-- | Layered graph of word candidates for a phrase.
--
-- TODO: connect this to CorpusGrams.hs and 'tmda'. How should the candidates
-- be displayed, and how should the choice be controlled?
--
-- Node layout:
--
-- * Node 0 is the start (@Left True@) and node 1 is the end (@Left False@).
-- * Every candidate @(v, a)@ in column @i@ of @knns@ gets its own node,
--   numbered from 2 upwards and labelled @Right (a, v, i)@.
--
-- Edges go from the start to each node of the first column, from each node of
-- column @i@ to each node of column @i+1@, and from each node of the last
-- column to the end. Each path from 0 to 1 therefore picks exactly one
-- candidate per column. (The end edges used to point from node 1 into the
-- last column, which made node 1 unreachable.)
--
-- Edge labels are @(v, w)@:
--
-- * Between columns, @v@ is the @<>@ of both endpoints' candidate costs and
--   @w@ is the 2-gram weight (@wMissing@ when the pair is absent).
-- * The start and end edges carry @wEnds@ plus the cost of their inner
--   endpoint. Every inner node's cost thus enters a start-to-end path twice,
--   whichever column it is in, including the single-column case.
phraseGraph :: (Hashable a, Eq a, Monoid v)
            => w -- ^ weight for a missing 2-gram
            -> w -- ^ weight for edges to the two Left Bool nodes
            -> [[(v,a)]] -- ^ `knns`
            -> HashMap a (HashMap a (Sum w)) -- ^ 2-grams (will be from module "CorpusGrams")
            -> Gr (Either Bool (a, v, Int)) (v,w)
            -- ^ Int is the index into `knns`, Bool is `start`, v is the cost to see the vertexes involved
            -- and w is the cost of the edge
phraseGraph wMissing wEnds knns grams =
  mkGraph (outerNodes <> innerNodes) (outerEdges <> innerEdges)
  where
    outerNodes = [(0, Left True), (1, Left False)]
    innerNodes = [ (j, Right lbl) | (j, lbl) <- concat nodes ]
    firstColumn = concat (take 1 nodes)
    lastColumn = concat (take 1 (reverse nodes))
    -- start -> first column, and last column -> end
    outerEdges =
      [ (0, j, (wj, wEnds)) | (j, (_, wj, _)) <- firstColumn ]
        <> [ (j, 1, (wj, wEnds)) | (j, (_, wj, _)) <- lastColumn ]
    -- number the candidates 2, 3, ... column by column
    nodes = flip evalState 2 $
      ifor knns \i ss ->
        for ss \(nodeWt, s) -> do
          j <- id <<+= 1
          return (j, (s, nodeWt, i))
    innerEdges = concat (zipWith link nodes (drop 1 nodes))
    link xs ys =
      [ (nx, ny, (wx <> wy, fromMaybe wMissing (grams ^? ix lx . ix ly . _Wrapped)))
      | (nx, (lx, wx, _)) <- xs
      , (ny, (ly, wy, _)) <- ys ]
-- | Shortest path from node 0 to node 1, or 'Nothing' if node 1 is
-- unreachable. Node 0 is the start and node 1 the end in 'phraseGraph'.
--
-- Only edge labels are used, so any node label type is accepted. The old
-- @Either Bool (a, Int)@ label could not match 'phraseGraph''s output.
-- 'phraseGraph''s @(v, w)@ edge labels must still be reduced to a single
-- 'Real' weight (e.g. with 'emap') before calling this.
pgSP1 :: Real w => Gr lbl w -> Maybe Path
pgSP1 gr = sp 0 1 gr
-- In what way are candidate paths distinguished?
-- Perhaps I should replace the HashSet with a Map Double Text.
-- This is a multiobjective problem: pick the word that fits the context
-- vs. the word that fits what was written.
-- Using that, I will find the candidates with:
-- map dtwKNN
-- * multi objective A* search
-- | Preprocessing orders (section 2.6) of the Targeted Multiobjective
-- Dijkstra Algorithm, https://arxiv.org/abs/2110.10978.
--
-- Returns one graph per cost dimension @n@. In the @i@-th graph every edge
-- cost is rotated left by @i@: component @j@ of the new cost is component
-- @(i + j) mod n@ of the old one.
gEPerm :: forall gr a n b. (Dim n, DynGraph gr) => gr a (V n b) -> [gr a (V n b)]
gEPerm gr = map (\i -> emap (rotateBy i) gr) [0 .. n - 1]
  where
    n = reflectDim (Proxy :: Proxy n)
    rotateBy i (V x) = V (V.backpermute x (V.generate n (\j -> mod (i + j) n)))
-- | Dominance bound @beta@ and heuristic @pis@ used by 'tmda'.
--
-- For each cyclic objective order from 'gEPerm', a lexicographic
-- shortest-path tree towards @t@ is grown on the reversed graph. Each tree's
-- cumulative costs are rotated back into the original objective order before
-- the trees are combined. Previously the rotated vectors were mixed, and
-- compared lexicographically rather than componentwise.
--
-- * @pis@ maps each node that reaches @t@ to the componentwise minimum of
--   those costs.
-- * @beta@ is the componentwise maximum at @s@, plus a small epsilon.
--
-- Fails with 'error' when @t@ is not reachable from @s@.
mkBetaPis :: forall w n a. (Fractional w, Real w, Dim n, Finite (V n)) => Gr a (V n w)
          -> Node
          -> Node
          -> (V n w, IntMap (V n w))
mkBetaPis gr s t = (beta, snd <$> bounds)
  where
    n = reflectDim (Proxy :: Proxy n)
    -- inverse of gEPerm's rotation by i
    unrotate :: Int -> V n w -> V n w
    unrotate i (V x) = V (V.generate n (\k -> x V.! mod (k - i) n))
    -- cost to t of every node that reaches t, one map per objective order;
    -- if points are duplicated within a tree they should be equal
    perOrder :: [IntMap (V n w)]
    perOrder =
      [ IM.map (\(RealV c) -> unrotate i c)
          (foldMap (IM.fromList . unLPath) (spTree t (emap RealV g)))
      | (i, g) <- zip [0 ..] (gEPerm (grev gr)) ]
    -- (componentwise max, componentwise min) over all orders
    bounds :: IntMap (V n w, V n w)
    bounds = IM.unionsWith
      (\(hi, lo) (hi', lo') -> (liftA2 max hi hi', liftA2 min lo lo'))
      [ IM.map (\c -> (c, c)) m | m <- perOrder ]
    beta = case IM.lookup s bounds of
      Just (hi, _) -> (+ 1e-8) <$> hi -- what should epsilon be?
      Nothing -> error "mkBetaPis: target is unreachable from source"
-- | Traversal to the value at key @j@ of the inner container stored at key
-- @i@. It focuses nothing if either key is missing.
ix2 i j = ix i . ix j
-- | Lens to the (optional) value at key @j@ of the inner container at key @i@.
-- Through 'non', a missing inner container reads as 'mempty', and an inner
-- container that becomes 'mempty' after a write is removed from the outer one.
at2 i j = at i . non mempty . at j
-- | Strict dominance (definitions 1 and 8).
--
-- @dlt x y@ holds when @x@ is at least as good as @y@ in every component and
-- strictly better in at least one.
--
-- NOTE(review): definition 8 is written with an equality sign, so it is
-- unclear whether it means 'dle' or 'dlt'.
dlt :: (Dim n, Ord a) => V n a -> V n a -> Bool
dlt x y = x `dle` y && or ((<) <$> x <*> y)

-- | Weak dominance: @dle x y@ holds when every component of @x@ is @<=@ the
-- matching component of @y@.
dle :: (Dim n, Ord a) => V n a -> V n a -> Bool
dle x y = and ((<=) <$> x <*> y)
-- | Remove and return the queued path with the smallest cost vector in 'V''s
-- 'Ord' order, or 'Nothing' when the queue is empty. The path's last node is
-- also dropped from the node index.
qExtractMin :: MonadState (TMDAS n) m => m (Maybe (V n Double, [Node]))
qExtractMin = queue %%= pop
  where
    pop q@(Q byCost byNode) = case M.minViewWithKey byCost of
      Nothing -> (Nothing, q)
      Just ((k, path), rest) ->
        (Just (k, path), Q rest (IM.delete (head path) byNode))
-- | Whether the queue holds a path ending at the given node.
qContains :: Node -> Q n -> Bool
qContains v (Q _ byNode) = v `IM.member` byNode
-- | Replace the queued path ending at the new path's last node with the new
-- path and cost.
--
-- Returns 'Nothing' if no path ending at that node is queued. The new cost is
-- not checked to be smaller.
qDecreaseKey :: (V n Double, [Node]) -> Q n -> Maybe (Q n)
qDecreaseKey (c, p:ps) (Q m s) = fmap replace (IM.lookup p s)
  where
    replace old = Q (M.insert c (p:ps) (M.delete old m)) (IM.insert p c s)
-- | The queued path (with its cost) that ends at the given node, if any.
qGetPath :: Node -> Q n -> Maybe (V n Double, [Node])
qGetPath n (Q m s) = IM.lookup n s >>= \k -> fmap ((,) k) (M.lookup k m)
-- | Queue a path under its cost and index it by its last node.
--
-- An existing 'Map' entry with the same cost is overwritten.
qInsert :: (V n Double, [Node]) -> Q n -> Q n
qInsert entry@(k, n:_) (Q m s) = Q (M.insert k (snd entry) m) (IM.insert n k s)
-- | Queue holding only the one-node path at @n@, with zero cost.
qInit :: Dim n => Node -> Q n
qInit n = Q (M.fromList [(0, [n])]) (IM.fromList [(n, 0)])
-- | The graph in figure 1 (presumably of the TMDA paper, arXiv:2110.10978).
--
-- Nodes 0..3 are labelled 's', 't', 'v', 'w'. Edges carry two-dimensional
-- costs.
fig1gr :: Gr Char (V 2 Double)
fig1gr = mkGraph labelledNodes weightedEdges
  where
    labelledNodes = zip [0 ..] "stvw"
    weightedEdges =
      [ (src, dst, toV c)
      | (src, dst, c) <- [ (0, 1, V2 1 10)
                         , (0, 2, V2 1 1)
                         , (0, 3, V2 2 2)
                         , (2, 1, V2 2 4)
                         , (2, 3, V2 2 0)
                         , (3, 1, V2 1 2) ] ]
-- | 'tmda' on 'fig1gr', from node 0 ('s') to node 1 ('t').
--
-- For reference, 'fig1gr' has four s-t paths, with costs (1,10), (3,5),
-- (3,4) and (4,3). The Pareto-optimal ones are (1,10), (3,4) and (4,3).
tmdaTest :: [(V 2 Double, [Node])]
tmdaTest = tmda fig1gr 0 1
-- | Targeted Multiobjective Dijkstra Algorithm (arXiv:2110.10978): paths from
-- @s@ to @t@ in the order they become permanent at @t@.
--
-- Each result is paired with its cost vector, and its node list is reversed:
-- the head is @t@ and the last element is @s@. The queue holds at most one
-- tentative path per node. Other candidates wait in @nqp@, keyed by their
-- last edge @(u, v)@, until @v@ is extracted again.
--
-- Fixes relative to the previous version:
--
-- * Waiting candidates for @v@ are looked up under @v@'s predecessors; the
--   successors were used before, so candidates were lost.
-- * The selected candidate is removed from @nqp@; otherwise it was
--   re-queued forever.
-- * Nodes that cannot reach @t@ (no heuristic value) are skipped instead of
--   crashing.
tmda :: forall gr (a :: *) (n :: Nat) t.
        (Dim n)
     => Gr a (V n Double)
     -> Node
     -> Node
     -> [(V n Double, [Node])]
tmda g s t = loop `evalState` (state0 & queue .~ (qInit s :: Q n))
  where
    (beta, pis) = mkBetaPis g s t
    state0 :: TMDAS n
    state0 = mempty

    loop = do
      mp <- qExtractMin
      case mp of
        Just p@(_, v:_) -> do
          perm . at v . non mempty %= (p :)
          -- nqp is keyed (predecessor, node): v's waiting candidates live
          -- under v's predecessors
          mnext <- nextQueuePath p (lpre g v)
          traverse_ (\next -> queue %= qInsert next) mnext
          unless (v == t) $ do
            permV <- use perm
            for_ (propagateCandidates permV p) propagate
          if v == t then (p :) <$> loop else loop
        _ -> return []

    -- Offer pnew (whose last edge is v -> w) to the queue. Whichever of pnew
    -- and w's queued path loses the comparison goes to nqp.
    propagate pnew@(cpw, w:v:_) = do
      queued <- uses queue (qGetPath w)
      case queued of
        Nothing -> queue %= qInsert pnew
        Just q@(cq, _:x:_)
          | cpw < cq -> do
              queue %= fromJust . qDecreaseKey pnew
              nqp . at2 x w . non mempty %= (q :<)
          | otherwise -> nqp . at2 v w . non mempty %= (:> pnew)

    -- Lines 2, 3 and 4 of propagate: extend p along each outgoing edge, and
    -- drop extensions whose cost plus heuristic is bounded by beta, by a
    -- permanent path at t, or by a permanent path at the new node. Nodes
    -- without a heuristic value cannot reach t.
    propagateCandidates permMap (cpv, p@(v:_)) =
      [ (cpv + cvw, w : p)
      | (w, cvw) <- lsuc g v
      , Just piW <- [IM.lookup w pis]
      , let cbpnew = cpv + cvw + piW
            bound = beta : permMap ^.. ix t . folded . _1
                      ++ permMap ^.. ix w . folded . _1
      , not (any (`dle` cbpnew) bound) ]

    -- Pick the lexicographically smallest eligible waiting path into v from
    -- the nqp lists of v's predecessors, and remove it from its list.
    -- At the very start this returns Nothing, which is correct.
    nextQueuePath (cpStar, v:_) preds = do
      permVal <- use perm
      let targetBound = beta : permVal ^.. ix t . folded . _1
          keep (cpv, _) = maybe False
            (\piV -> not (any (`dlt` (cpv + piV)) targetBound))
            (IM.lookup v pis)
          stop pv@(cpv, _) = do
            let cpsvle = anyOf (ix v . folded . _1) (`dlt` cpv) permVal
                done = not (cpsvle || cpStar `dlt` cpv)
            when done $ tell (Just (Min pv))
            return done
          -- distinct keys, as required by ixes/adjoin
          predNodes = IS.toList (IS.fromList (preds ^.. folded . _1))
      best <- nqp . ixes predNodes . ix v
        %%= swap . runWriter . filterBreak keep stop
      case getMin <$> best of
        Just pv@(_, _:u:_) -> do
          nqp . ix u . ix v %= filter (/= pv)
          return (Just pv)
        other -> return other
    nextQueuePath _ _ = error "tmda.nextQueuePath: empty path"
-- | Traversal focusing the values at each of the given keys, in list order.
--
-- Absent keys are skipped. It is built from the unsound 'adjoin', so the keys
-- should be distinct; a repeated key is visited more than once.
ixes :: Ixed m => [Common.Index m] -> Traversal' m (IxValue m)
ixes [] = ignored
ixes (k : rest) = adjoin (ix k) (ixes rest)
-- | Monadic filter with early exit.
--
-- Elements failing @keep@ are dropped. Elements passing it are kept and
-- given to @stop@. As soon as @stop@ returns 'True', that element and the
-- untouched remainder are returned as they are, so later elements are not
-- filtered.
filterBreak :: (AsEmpty as, Cons as as a a, Monad m) => (a -> Bool) -- ^ filter
            -> (a -> m Bool) -- ^ True to stop traversing. Only applied to elements that pass the filter
            -> as
            -> m as
filterBreak keep stop = go
  where
    go (x :< rest)
      | not (keep x) = go rest
      | otherwise = do
          done <- stop x
          if done then return (x :< rest) else (x :<) <$> go rest
    go Empty = return Empty
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment