Created
August 5, 2025 20:03
-
-
Save aavogt/364725b73790fdec19944e9e76622654 to your computer and use it in GitHub Desktop.
TMDA handwriting learning
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| {-# LANGUAGE BlockArguments #-} | |
| {-# LANGUAGE TupleSections #-} | |
| {-# LANGUAGE TypeFamilies #-} | |
| {-# LANGUAGE FlexibleContexts #-} | |
| {-# LANGUAGE ScopedTypeVariables #-} | |
| {-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
| {-# LANGUAGE TemplateHaskell #-} | |
| {-# LANGUAGE FlexibleInstances #-} | |
| {-# LANGUAGE MultiParamTypeClasses #-} | |
| {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} | |
| {-# LANGUAGE RankNTypes #-} | |
| {-# HLINT ignore "Use foldr" #-} | |
| {-# HLINT ignore "Eta reduce" #-} | |
| {-# OPTIONS_GHC -Wno-incomplete-patterns #-} | |
| {-# LANGUAGE ViewPatterns #-} | |
| {-# LANGUAGE DataKinds #-} | |
| {-# LANGUAGE PolyKinds #-} | |
| module Learning where | |
| import Common | |
| import Data.DTW hiding (Path, ix) | |
| import Data.Graph.Inductive.Query.SP (sp) | |
| import Data.Graph.Inductive hiding ((&)) | |
| import qualified Data.HashSet as S | |
| import Data.HashSet (HashSet) | |
| import Data.HashMap.Strict (HashMap) | |
| import qualified Data.HashMap.Strict as HM | |
| import Data.Hashable | |
| import Control.Monad.State | |
| import Data.Monoid (Sum) | |
| import Data.Foldable | |
| import Data.Traversable | |
| import Data.Maybe | |
| import Data.Map (Map) | |
| import Linear.V | |
| import qualified Data.Vector as V | |
| import Data.Proxy | |
| import Linear | |
| import Data.Coerce | |
| import qualified Data.IntMap as IM | |
| import qualified Data.IntSet as IS | |
| import Data.Semigroup | |
| import Data.These | |
| import Data.Align | |
| import Control.Monad.Cont | |
| import Control.Applicative | |
| import Control.Monad.Trans.Maybe | |
| import Debug.Trace | |
| import Control.Monad.Trans.Writer | |
| import Control.Lens.Unsound (adjoin) | |
| import qualified Data.Map as M | |
| import Data.IntSet (IntSet) | |
| import Data.IntMap (IntMap) | |
| import Data.Sequence (Seq) | |
| import GHC.TypeLits (Nat) | |
| import Data.Tuple (swap) | |
-- | State threaded through 'tmda' (via 'State').
data TMDAS n = TMDAS
  { _nqp   :: IntMap (IntMap [(V n Double, [Node])])
    -- ^ candidate ("next queue path") lists; 'tmda' writes them as
    -- @nqp[x][w]@ for paths that reach @w@ through the node @x@ before it
  , _perm  :: IntMap [(V n Double, [Node])]
    -- ^ permanent paths keyed by the node they end at, newest first
  , _queue :: Q n
    -- ^ queue of tentative paths
  }
-- | Queue of tentative paths that can also be searched by the node a path
-- ends on (the first element of its node list, since paths are built by
-- consing the newest node onto the front).
--
-- For now 'qGetPath' is assumed never to need nodes in the middle of a path.
--
-- NOTE(review): the cost map is keyed by the cost vector alone, so two
-- paths with equal costs overwrite each other there while both end nodes
-- stay in the node map. Confirm this cannot happen in 'tmda'.
data Q n = Q
  (Map (V n Double) [Node]) -- ^ cost -> path (end node first)
  (IntMap (V n Double))     -- ^ end node -> cost of its queued path
  deriving (Eq, Show)
-- | Left-biased union of both indexes, i.e. the same as '<>' on 'Map' and
-- 'IntMap'.
instance Semigroup (Q n) where
  Q byCost byNode <> Q byCost' byNode' =
    Q (M.union byCost byCost') (IM.union byNode byNode')

-- | The empty queue.
instance Monoid (Q n) where
  mempty = Q M.empty IM.empty

-- | Relies entirely on the class's default methods.
instance AsEmpty (Q n)
-- | Field-wise combination: each field uses its own '<>'.
instance Semigroup (TMDAS n) where
  TMDAS nqp1 perm1 queue1 <> TMDAS nqp2 perm2 queue2 =
    TMDAS (nqp1 <> nqp2) (perm1 <> perm2) (queue1 <> queue2)

-- | All fields empty.
instance Monoid (TMDAS n) where
  mempty = TMDAS IM.empty IM.empty mempty

makeLenses ''TMDAS
-- | Wrapper whose only purpose is to give vector-valued weights the 'Real'
-- instance that fgl's 'spTree' demands. Arithmetic and ordering are those
-- of the wrapped type. 'toRational' looks only at the first component, and
-- fails ('V.head') for a zero-dimensional vector.
newtype RealV v a = RealV (v a)
  deriving (Num, Eq, Ord, Functor, Applicative)

makeWrapped ''RealV

instance (Real a, Finite v, Num (v a), Ord (v a)) => Real (RealV v a) where
  toRational (RealV xs) = toRational firstComponent
    where
      firstComponent = V.head (toVector (toV xs))
-- | Dynamic-time-warping cost from @track@ to every entry of the database.
--
-- Pen lifts are ignored: a multi-stroke input has to be flattened from
-- Tracks into a single 'Track' first (the original note says @S.concat@
-- does this). The result follows database order and is neither sorted nor
-- cut down to k entries. Choosing the nearest neighbours is left to the caller.
dtwKNN :: (Eq a, Hashable a) => Track -> [(Track, a)] -- ^ db
       -> [(Double,a)] -- ^ [(Distance, word)]
dtwKNN track db =
  [ (cost (fastDtw dist shrinkMean 2 track sample), label)
  | (sample, label) <- db
  ]
-- | Translate a track so that its first point lands on the origin.
-- The empty track has no first point and is returned unchanged. Before
-- this change it was a pattern-match failure.
shiftTo00 :: Track -> Track
shiftTo00 ((a, b) :< xs) = (0, 0) :< fmap (\(c, d) -> (c - a, d - b)) xs
shiftTo00 track = track
-- | Euclidean distance between two points. Used as the point metric for
-- dynamic time warping in 'dtwKNN'.
dist :: P -> P -> Double
dist (a, b) (c, d) = sqrt $ (a - c) ^ (2 :: Int) + (b - d) ^ (2 :: Int)
-- | Halve the resolution of a track by replacing each consecutive pair of
-- points with their midpoint. A trailing unpaired point (and the empty
-- track) is kept as it is. Used as the coarsening step of 'fastDtw'.
shrinkMean :: Track -> Track
shrinkMean track = case track of
  (x1, y1) :< ((x2, y2) :< rest) ->
    ((x1 + x2) / 2, (y1 + y2) / 2) :< shrinkMean rest
  _ -> track
-- | Build the candidate lattice ("phrase graph") for a written phrase.
--
-- TODO: connect this to CorpusGrams.hs and 'tmda'. How should the
-- candidates be displayed, and how should the user control the choice?
--
-- Node 0 (@Left True@) is the start and node 1 (@Left False@) the end.
-- Each candidate in @knns@ gets a fresh node id, counting up from 2, and
-- the label @Right (word, nodeCost, column)@. Edges go
--
--   * from the start node to every candidate of the first column,
--   * from each candidate of one column to each candidate of the next,
--     weighted by the combined node costs and the 2-gram weight
--     (@wMissing@ when the 2-gram is unknown),
--   * from every candidate of the last column to the end node,
--
-- so that each path from node 0 to node 1 picks one candidate per column.
phraseGraph :: (Hashable a, Eq a, Monoid v)
            => w -- ^ weight for a missing 2-gram
            -> w -- ^ weight for edges to the two Left Bool nodes
            -> [[(v,a)]] -- ^ `knns`
            -> HashMap a (HashMap a (Sum w)) -- ^ 2-grams (will be from module "CorpusGrams")
            -> Gr (Either Bool (a, v, Int)) (v,w)
            -- ^ Int is the index into `knns`, Bool is `start`, v is the cost to see the vertexes involved
            -- and w is the cost of the edge
phraseGraph wMissing wEnds knns grams
  = mkGraph (outerNodes <> innerNodes) (outerEdges <> innerEdges)
  where
    outerNodes = [(0, Left True), (1, Left False)]
    -- start -> first column and last column -> end. The end edges must
    -- point INTO node 1; edges out of node 1 would leave it unreachable
    -- from node 0 in this directed graph.
    outerEdges = toListOf (_head . folded . _1 . to (0,, (mempty, wEnds))
                           <> _last . folded . _1 . to (, 1, (mempty, wEnds)))
                          nodes
    innerNodes = concat nodes <&> _2 %~ Right
    -- number the candidates column by column, starting at 2
    nodes = evalState ?? 2 $
      ifor knns \i ss ->
        for ss \(nodeWt, s) -> do
          j <- id <<+= 1
          return (j, (s, nodeWt, i))
    innerEdges = concat $ zipWith bigramEdges nodes (drop 1 nodes)
    -- complete bipartite edges between two adjacent columns
    bigramEdges xs ys =
      [ (nx, ny, (wx <> wy, fromMaybe wMissing $ grams ^? ix lx . ix ly . _Wrapped))
      | (nx, (lx, wx, _)) <- xs
      , (ny, (ly, wy, _)) <- ys
      ]
-- | Shortest path through a phrase graph from the start node 0 to the end
-- node 1, which is the numbering 'phraseGraph' uses.
--
-- 'sp' only looks at edge weights, so the node-label type is left generic.
-- Before, it was fixed to @Either Bool (a, Int)@, which does not match what
-- 'phraseGraph' builds. 'phraseGraph' labels its edges with @(v, w)@ pairs,
-- so those still have to be mapped to a 'Real' weight (e.g. with 'emap')
-- before calling this.
pgSP1 :: Real w => Gr l w -> Maybe Path
pgSP1 gr = sp 0 1 gr
| -- in what way are candidate paths distinguished? | |
| -- perhaps I should replace HashSet with Map Double Text | |
| -- this is a multiobjective problem: pick the word that fits the context | |
| -- vs. the word that fits what was written | |
| -- using that I will find | |
| -- map dtwKNN | |
| -- * multi objective A* search | |
| -- Targeted Multiobjective Dijkstra Algorithm (TMDA): https://arxiv.org/abs/2110.10978 | |
-- | Generates the preprocessing orders ((2.6) in the paper): one copy of the
-- graph per cyclic rotation of the edge cost vectors. In copy @i@, cost
-- component @j@ is the original component @(i + j) `mod` n@, so every
-- objective is the first component in exactly one copy.
gEPerm :: forall gr a n b. (Dim n, DynGraph gr) => gr a (V n b) -> [gr a (V n b)]
gEPerm gr = map rotatedBy [0 .. n - 1]
  where
    n = reflectDim (Proxy :: Proxy n)
    rotatedBy i = emap (rotate i) gr
    rotate i (V costs) = V (V.generate n (\j -> costs V.! ((i + j) `mod` n)))
| -- | dominance bound Beta and heuristic Pis | |
| -- Returns @(beta, pis)@ for source @s@ and target @t@. The graph is | |
| -- reversed ('grev'), so a shortest-path tree rooted at @t@ ('spTree') holds, | |
| -- for each node it reaches, the cost of a path from that node to @t@ in @gr@. | |
| -- One tree is computed for each rotated copy produced by 'gEPerm'. Edge costs | |
| -- are wrapped in 'RealV' only to satisfy the 'Real' constraint of 'spTree'. | |
| -- For each node, the 'Max' and 'Min' of these costs across the copies (under | |
| -- the 'Ord' of the wrapped vectors) are kept. beta is the 'Max' at @s@ with | |
| -- 1e-8 added (the @(+1e-8)@ step); pis maps each node to its 'Min'. | |
| -- NOTE(review): the costs from copy i are still in that copy's rotated | |
| -- component order when Max/Min combine them. Nothing rotates them back | |
| -- to the objective order of @gr@; confirm this is intended. | |
| -- NOTE(review): Max/Min select one whole vector, not a componentwise max/min. | |
| -- NOTE(review): @m IM.! s@ fails if @s@ is in no tree, and nodes that are | |
| -- in no tree get no entry in pis. | |
| mkBetaPis :: (Fractional w, Real w, Dim n, Finite (V n)) => Gr a (V n w) | |
| -> Node | |
| -> Node | |
| -> (V n w, IntMap (V n w)) | |
| mkBetaPis gr s t = grev gr | |
| & gEPerm | |
| & (coerce `asTypeOf` map (emap RealV)) | |
| <&> spTree t | |
| <&> foldMap (IM.fromList . unLPath) -- if points are duplicated they should be equal | |
| <&> fmap (\x -> (Max x, Min x)) | |
| & IM.unionsWith (<>) | |
| <&> _1 %~ view (_Wrapped . _Wrapped . to (+1e-8)) -- what should epsilon be? | |
| <&> _2 %~ view (_Wrapped . _Wrapped) | |
| & \ m -> (m IM.! s & fst, fmap snd m) | |
-- | Two-level 'ix': focus on the value at @outer@, then at @inner@ within it.
ix2 outer inner = ix outer . ix inner

-- | Two-level 'at' for maps of maps. A missing outer entry reads as 'mempty',
-- and an outer entry that becomes 'mempty' is removed (the 'non' behaviour).
at2 outer inner = at outer . non mempty . at inner
-- | Dominance tests on cost vectors (Definitions 1 and 8 of the paper).
--
-- @x `dle` y@: @x@ is no worse than @y@ in every component.
-- @x `dlt` y@: in addition, @x@ is strictly better in at least one component.
--
-- Definition 8 is written with a non-strict relation (it includes
-- equality), so it is still unresolved whether each use needs 'dle' or
-- 'dlt'.
dlt :: (Dim n, Ord a) => V n a -> V n a -> Bool
dlt x y = dle x y && V.or (V.zipWith (<) (toVector x) (toVector y))

dle :: (Dim n, Ord a) => V n a -> V n a -> Bool
dle x y = V.and (V.zipWith (<=) (toVector x) (toVector y))
-- | Pop the queue entry with the smallest cost key, using the 'Ord' of 'V'
-- (the order of the cost map).
--
-- Returns 'Nothing' and leaves the queue unchanged when it is empty.
-- Otherwise the entry is removed from both indexes of 'Q': from the cost map,
-- and from the node map under the path's end node (the first element of its
-- node list).
qExtractMin :: MonadState (TMDAS n) m => m (Maybe (V n Double, [Node]))
qExtractMin = queue %%= extract
  where
    extract q@(Q m s) = case M.minViewWithKey m of
      Nothing -> (Nothing, q)
      Just ((k, path), m') ->
        let s' = case path of
              v : _ -> IM.delete v s
              -- never produced by qInsert/qInit, but avoid a partial 'head'
              []    -> s
        in (Just (k, path), Q m' s')
-- | Whether some queued path ends at the given node. Only the node index is
-- consulted.
qContains :: Node -> Q n -> Bool
qContains v (Q _ byNode) = v `IM.member` byNode
-- | Replace the queued path for the node that the new path ends at (the
-- first element of its node list) with the new path and its cost.
--
-- Returns 'Nothing' if that node has no queue entry, or if the new path has
-- an empty node list. An empty list used to be a pattern-match failure.
-- No cost comparison happens here; callers decide whether the new cost is
-- smaller.
--
-- NOTE(review): the old cost key is deleted from the cost map, which would
-- also drop another node's path that happened to share that exact cost.
qDecreaseKey :: (V n Double, [Node]) -> Q n -> Maybe (Q n)
qDecreaseKey (_, []) _ = Nothing
qDecreaseKey (c, path@(p : _)) (Q m s) = do
  oldCost <- IM.lookup p s
  pure (Q (M.insert c path (M.delete oldCost m)) (IM.insert p c s))
-- | The queued path ending at the given node, together with its cost.
-- 'Nothing' if the node is not in the node index, or if its cost key is
-- missing from the cost map.
qGetPath :: Node -> Q n -> Maybe (V n Double, [Node])
qGetPath n (Q m s) =
  IM.lookup n s >>= \k -> fmap ((,) k) (M.lookup k m)
-- | Add a path to the queue, indexed by its cost and by its end node (the
-- first element of its node list).
--
-- A path with an empty node list has no end node to index it under, so it
-- is ignored instead of failing a pattern match.
--
-- NOTE(review): if the end node already has a queued path, that path's old
-- cost key stays in the cost map; a path with an equal cost is overwritten.
qInsert :: (V n Double, [Node]) -> Q n -> Q n
qInsert (_, []) q = q
qInsert (k, path@(v : _)) (Q m s) = Q (M.insert k path m) (IM.insert v k s)
-- | A queue holding only the single-node path @[src]@ at cost zero.
qInit :: Dim n => Node -> Q n
qInit src = Q (M.fromList [(origin, [src])]) (IM.fromList [(src, origin)])
  where
    origin = 0
-- | The graph from figure 1 (presumably of the TMDA paper). Nodes 0..3 are
-- labelled s, t, v, w, and every arc carries a two-objective cost.
fig1gr :: Gr Char (V 2 Double)
fig1gr = mkGraph labelled [ (src, dst, toV c) | (src, dst, c) <- arcs ]
  where
    labelled = zip [0 ..] "stvw"
    arcs =
      [ (0, 1, V2 1 10)
      , (0, 2, V2 1 1)
      , (0, 3, V2 2 2)
      , (2, 1, V2 2 4)
      , (2, 3, V2 2 0)
      , (3, 1, V2 1 2)
      ]
-- | Run 'tmda' on 'fig1gr' from s (node 0) to t (node 1). The original
-- author noted that the output "seems to be okay".
--
-- For comparison: the s-t paths of 'fig1gr' cost s-t (1,10), s-v-t (3,5),
-- s-w-t (3,4) and s-v-w-t (4,3). (3,5) is dominated by (3,4), so the
-- efficient costs are (1,10), (3,4) and (4,3).
tmdaTest :: [(V 2 Double, [Node])]
tmdaTest = tmda fig1gr 0 1
| -- | Targeted multiobjective Dijkstra from @s@ to @t@ on @g@. Returns the | |
| -- permanent paths that end at @t@, in extraction order, each given as its | |
| -- cost and its node list (end node first). The queue starts with the | |
| -- single path [s]. Each iteration extracts the queue minimum p (ending at | |
| -- v), records it in 'perm', queues the next candidate for v | |
| -- ('nextQueuePath') and, unless v == t, extends p along v's outgoing arcs | |
| -- ('propagateCandidates'). The loop stops when the queue is empty. | |
| tmda :: forall gr (a :: *) (n :: Nat) t. | |
| (Dim n) | |
| => Gr a (V n Double) | |
| -> Node | |
| -> Node | |
| -> [(V n Double, [Node])] | |
| tmda g s t = loop `evalState` (state0&queue .~ (qInit s :: Q n)) where | |
| (beta, pis) = mkBetaPis g s t | |
| state0 :: TMDAS n | |
| state0 = mempty | |
| loop = do | |
| mp <- qExtractMin | |
| case mp of | |
| Just p@(_, v:_) -> do | |
| perm . at v . non mempty %= (p:) | |
| -- NOTE(review): 'propagate' writes NQP lists as nqp[pred][node] (at2 x w, | |
| -- at2 v w), but 'nextQueuePath' reads nqp[u][v] for the nodes u passed | |
| -- here, which are v's successors. Predecessors ('lpre g v') look intended. | |
| -- Before switching, check termination: the selected NQP entry is never | |
| -- removed, and 'nextQueuePath' filters with the strict 'dlt'. | |
| mnext <- nextQueuePath s t beta pis p (lsuc g v) | |
| traverse_ (\next -> queue %= qInsert next) mnext | |
| unless (v == t) $ do | |
| permV <- use perm | |
| for_ (propagateCandidates g s t permV beta pis p) | |
| $ propagate pis | |
| if v == t then (p :) <$> loop else loop | |
| _ -> return [] | |
| -- propagate: if w has no queued path, pnew is queued. If pnew is cheaper | |
| -- (by the Ord of V) than w's queued path q, pnew replaces q in the queue | |
| -- and q moves to the front of nqp[x][w] (x = node before w in q). | |
| -- Otherwise pnew is appended to nqp[v][w]. fromJust is safe because | |
| -- qGetPath w just found w in the node index. | |
| -- NOTE(review): a queued path with only one node does not match _:x:_. | |
| propagate pis pnew@(cpw, w:v:_) = do | |
| path <- uses queue (qGetPath w) | |
| case path of | |
| Nothing -> queue %= qInsert pnew | |
| Just q@(cq, _:x:_) -> if cpw < cq then do | |
| queue %= fromJust . qDecreaseKey pnew | |
| nqp . at2 x w . non mempty %= (q :<) | |
| else nqp . at2 v w . non mempty %= (:> pnew) | |
-- | Lines 2-4 of the paper's @propagate@ procedure: extend the permanent
-- path @p@ (cost @cpv@, ending at @v@) along every outgoing arc @(v, w)@ and
-- keep the extensions that survive pruning. An extension with cost
-- @c = cpv + cvw@ is pruned when any of the following holds:
--
--   * @w@ has no entry in @pis@ (presumably because @w@ cannot reach @t@;
--     see 'mkBetaPis'). Before, @IM.!@ crashed here;
--   * @c + pis!w@ is weakly dominated ('dle') by @beta@ or by the cost of a
--     permanent path at @t@;
--   * @c@ itself is weakly dominated by the cost of a permanent path at
--     @w@. Those stored costs carry no heuristic, so they are compared with
--     the plain cost at @w@, not with @c + pis!w@. The old comparison
--     discarded paths that no permanent path at @w@ dominates.
--
-- The source @s@ is not used.
propagateCandidates g s t perm beta pis (cpv, p@(v:_)) =
  [ (cpw, w : p)
  | (w, cvw) <- lsuc g v
  , Just piw <- [IM.lookup w pis]
  , let cpw = cpv + cvw
        targetBound = beta : perm ^.. ix t . folded . _1
        nodeBound = perm ^.. ix w . folded . _1
  , not $ any (`dle` (cpw + piw)) targetBound
  , not $ any (`dle` cpw) nodeBound
  ]
| -- | Pick the next queue path for node @v@, the end node of the path | |
| -- @(cpStar, v:_)@ that 'tmda' has just made permanent. The lists nqp[u][v], | |
| -- for every node u in @neighs@, are scanned with 'filterBreak' and written | |
| -- back. An entry is dropped when its cost + pis!v is strictly dominated | |
| -- ('dlt') by @beta@ or by a permanent path at @t@. The first remaining entry | |
| -- that is not strictly dominated by a permanent path at @v@ (nor by | |
| -- @cpStar@) ends the scan of its list, and later entries are left as they | |
| -- are. The smallest of these picks ('Min' over the (cost, path) pairs) is | |
| -- returned, or 'Nothing' if no list produced one. @s@ is unused. | |
| -- at the very start this one returns nothing which is correct | |
| nextQueuePath s t beta pis (cpStar, v:_) neighs = do | |
| permVal <- use perm | |
| let f (cpv, _) = not $ any (`dlt` (cpv + pis IM.! v)) | |
| (beta : | |
| permVal ^.. ix t . folded . _1) | |
| g pv@(cpv, _) = do | |
| let cpsvle = anyOf (ix v . folded . _1) (`dlt` cpv) permVal | |
| break = not $ cpsvle || cpStar `dlt` cpv | |
| when break $ tell $ Just $ Min pv | |
| return break | |
| -- NOTE(review): filterBreak keeps the selected entry in its list, and the | |
| -- check against permanent paths at v uses the strict 'dlt'. So after that | |
| -- entry becomes permanent, its equal-cost copy is not rejected. Confirm | |
| -- it cannot be selected again. | |
| p' <- nqp . ixes (neighs ^.. folded . _1) . ix v %%= swap . runWriter . filterBreak f g | |
| return (getMin <$> p') | |
| -- reached when the path argument has an empty node list | |
| nextQueuePath _ _ _ _ _ _ = undefined | |
-- | A traversal that visits the value at each of the given indexes in turn,
-- or nothing ('ignored') for the empty list.
--
-- It is built with 'adjoin' from "Control.Lens.Unsound", so it is a lawful
-- traversal only when the indexes are distinct. Duplicate indexes make the
-- combined targets overlap.
ixes :: Ixed m => [Common.Index m] -> Traversal' m (IxValue m)
ixes [] = ignored
ixes (k : ks) = ix k `adjoin` ixes ks
-- | Filter a 'Cons'-able container while running a monadic stop check.
--
-- Elements that fail the predicate are dropped. Each element that passes
-- is given to the stop check. On the first 'True', that element and the
-- whole remaining tail are returned as they are: unfiltered and unchecked.
filterBreak :: (AsEmpty as, Cons as as a a, Monad m) => (a -> Bool) -- ^ filter
            -> (a -> m Bool) -- ^ True to stop traversing. Only applied to elements that pass the filter
            -> as
            -> m as
filterBreak keep stop = go
  where
    go (x :< xs)
      | keep x = do
          done <- stop x
          if done
            then return (x :< xs)
            else (x :<) <$> go xs
      | otherwise = go xs
    go Empty = return Empty
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment