Skip to content

Instantly share code, notes, and snippets.

@proger
Last active April 17, 2020 02:57
Show Gist options
  • Save proger/e8fdd3ab4d45a8a34a2354fa3eac9928 to your computer and use it in GitHub Desktop.
Save proger/e8fdd3ab4d45a8a34a2354fa3eac9928 to your computer and use it in GitHub Desktop.
# Transition score matrix, indexed as pi[to, from_]:
#   pi[to]        scores of every transition INTO tag `to` (same as pi[to, :])
#   pi[:, from_]  scores of every transition OUT OF tag `from_`
pi = nn.Parameter(torch.randn(len(zk), len(zk)))
# Row 0 is the start tag: nothing may ever transition into it ...
pi.data[0].fill_(-10000)
# ... and the last column is the stop tag: nothing may transition out of it.
pi.data[:, -1].fill_(-10000)
pi  # notebook cell: show the initialized parameter
def score(obs, zs):
    """Unnormalized score of one tag path ``zs`` under transitions ``pi``.

    The score is the sum of:
      * the transition from the start tag (index 0) into ``zs[0]``,
      * every transition between consecutive tags, read as pi[to, from_],
      * the transition from ``zs[-1]`` into ``zk[STOP_TAG]``,
      * the emission score ``obs[t, zs[t]]`` at every position ``t``.

    Args:
        obs: emission scores, indexed as ``obs[t, tag]`` through
            ``gather(1, ...)`` -- presumably (seq_len, n_tags); TODO confirm.
        zs: 1-D integer (int64) tensor of tag indices, one per position.

    Returns:
        A 0-dim tensor, differentiable with respect to ``pi`` and ``obs``.
    """
    # pi[zs[1:]] selects the "to" rows; gathering column zs[:-1] from each
    # row yields pi[zs[t + 1], zs[t]] for every consecutive pair.
    to = pi[zs[1:]]
    transitions = to.gather(1, zs[:-1].unsqueeze(1))
    emissions = obs.gather(1, zs.unsqueeze(1))
    # The per-step terms are column vectors of lengths len(zs)-1 and
    # len(zs); they must be reduced to scalars before being added together,
    # otherwise they do not broadcast against each other.
    return (pi[zs[0], 0]
            + transitions.sum()
            + pi[zk[STOP_TAG], zs[-1]]
            + emissions.sum())
def decode(obs, pi=pi):
    """Viterbi decoding: the highest-scoring tag path for ``obs``.

    Args:
        obs: iterable of per-position emission-score vectors with one entry
            per tag (presumably a (seq_len, n_tags) tensor -- TODO confirm).
        pi: transition scores indexed pi[to, from_]; tag 0 is the start tag
            and tag len(zk) - 1 is the stop tag.

    Returns:
        ``(best_score, path)``: ``best_score`` is a 0-dim tensor and ``path``
        is a list of tag indices, one per element of ``obs`` (start tag
        removed).
    """
    n_tags = len(zk)
    stop = n_tags - 1
    # Initial scores: zero for the start tag, pi[0, 0] for every other tag.
    # NOTE(review): this relies on pi[0, 0] holding the large negative
    # "never enter start" constant; if pi is trained, that entry can drift.
    h_ = pi[0, 0] * (torch.ones(n_tags) - torch.eye(n_tags)[0])
    backpointers = []
    for o in obs:
        # cand[to, from_] = h_[from_] + pi[to, from_]; keep the best from_
        # for every `to` and remember which one it was.
        h, best_prev = (h_.expand_as(pi) + pi).max(1)
        backpointers.append(best_prev)
        h_ = h + o
    # Close the path with the transition into the stop tag.
    s, i = (h_ + pi[stop]).max(0)
    path = [i.item()]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]].item())
    path.reverse()
    # Pop outside the assert: `assert path.pop(0) == 0` loses the pop under
    # `python -O`, which would leave the start tag at the front of the path.
    start = path.pop(0)
    assert start == 0
    return s, path
# Notebook cell: run `decode` and the reference `viterbi_decode` on the same
# sentence so their (score, path) results can be compared side by side.
# NOTE(review): `lstm_embed_sentence`, `s` and `viterbi_decode` are not defined
# in this snippet; presumably the first return value holds per-token emission
# scores with singleton dimensions that `squeeze()` removes -- TODO confirm.
xs, _ = lstm_embed_sentence(s)
xs = xs.squeeze()
xs.shape
decode(xs), viterbi_decode(xs, pi=pi)
{-# LANGUAGE FlexibleInstances, UndecidableInstances, GeneralizedNewtypeDeriving #-}
-- stack runhaskell hmm.hs
-- enumeration posterior inference for HMMs (aka HMM evaluation or scoring)
-- thanks, robots: https://www.robots.ox.ac.uk/~vgg/rg/slides/hmm.pdf
-- | Types whose inhabitants can be listed exhaustively, so distributions
-- over them can be summed or tabulated by brute-force enumeration.
class Finite a where
  -- | Every inhabitant of the type.
  every :: [a]

-- | Apply a function to every inhabitant of its argument type, in 'every'
-- order.
forevery :: Finite a => (a -> b) -> [b]
forevery f = f <$> every

-- | Any bounded enumeration is finite: list it from 'minBound' to 'maxBound'.
instance (Enum a, Bounded a) => Finite a where
  every = enumFromTo minBound maxBound

-- | A pair of finite types is finite: the cartesian product, with the first
-- component varying slowest. OVERLAPPING so it is chosen over the catch-all
-- instance above for tuple types.
instance {-# OVERLAPPING #-} (Finite a, Finite b) => Finite (a, b) where
  every = (,) <$> every <*> every
-- | A probability (mass) as a plain 'Double'.
type P = Double
-- | A discrete distribution, represented by its probability mass function.
type Dist a = a -> P
-- | Tabulate a distribution: pair every value of its finite support type
-- with that value's mass, in 'every' order.
pmf :: Finite a => Dist a -> [(a, P)]
pmf mass = [(a, mass a) | a <- every]
-- * Chain rule
-- | Combine a prior over @θ@ with a conditional @θ -> Dist x@ into the joint
-- over @(x, θ)@: P(x,θ) = P(x|θ) P(θ). The newly drawn value comes first.
(>>>) :: Dist θ -> (θ -> Dist x) -> Dist (x, θ)
(prior >>> conditional) (x, θ) = conditional θ x * prior θ
-- * P(x,θ) = P(x)P(θ)
-- | Independent product of two distributions: P(x,θ) = P(x) P(θ). It is the
-- chain rule '(>>>)' with a conditional that ignores what it is given.
-- Being right-associative, @a *:* b *:* c@ parses as @a *:* (b *:* c)@ and
-- is a distribution over values nested as @((c, b), a)@.
(*:*) :: Dist θ -> Dist x -> Dist (x, θ)
q *:* p = q >>> const p
infixr 1 *:*
-- * Marginalization
-- | Sum the second component out of a joint: P(a) = Σ_b P(a,b).
sum_ :: Finite b => Dist (a, b) -> Dist a
sum_ joint a = sum [joint (a, b) | b <- every]
-- * P(b|a) = P(a,b)/P(a)
-- | Condition a joint on an observed first component (Bayes' rule):
-- P(b|a) = P(a,b) / P(a), where P(a) comes from marginalizing out @b@.
--
-- The evidence P(a) is computed once per partial application @bayes joint a@
-- and shared by every query of the returned distribution, instead of being
-- re-summed for each @b@. If P(a) is 0 (and masses are non-negative) every
-- query yields NaN (0/0).
bayes :: Finite b => Dist (a, b) -> a -> Dist b
bayes joint a = \b -> joint (a, b) / evidence
  where
    evidence = sum_ joint a
-- | An observed binary symbol. 'Enum' and 'Bounded' come from the wrapped
-- 'Bool' via GeneralizedNewtypeDeriving, which makes the type 'Finite'.
newtype Obs = X Bool deriving (Show, Enum, Bounded)
-- | A hidden binary state; derived the same way, so it is 'Finite' too.
newtype Latent = Z Bool deriving (Show, Enum, Bounded)
-- | Initial-state distribution: the chain always starts in @Z True@.
init_ :: Dist Latent
init_ (Z b)
  | b = 1
  | otherwise = 0
-- | Transition distribution: keep the current state with probability 0.7,
-- switch to the other one with probability 0.3.
trans :: Latent -> Dist Latent
trans (Z from) (Z to)
  | from == to = 0.7
  | otherwise = 0.3
-- | Emission distribution: a state emits its own bit with probability 0.9
-- and the opposite bit with probability 0.1.
emit :: Latent -> Dist Obs
emit (Z z) (X x)
  | z == x = 0.9
  | otherwise = 0.1
-- | Emission model for the 3-step chain: each latent state independently
-- emits one observation.
--
-- The latent triple arrives in the nesting 'pSequence' produces, which is
-- reverse time order: @((s3, s2), s1)@ with @s1@ the initial state. Since
-- '(*:*)' is right-associative and puts its right operand's value first,
-- @emit s3 *:* emit s2 *:* emit s1@ is a distribution over @((o1, o2), o3)@:
-- observations back in forward time order, each @oN@ emitted by @sN@.
pObservation :: ((Latent, Latent), Latent) -> Dist ((Obs, Obs), Obs)
pObservation ((s3, s2), s1) = emit s3 *:* emit s2 *:* emit s1
-- | Prior over a 3-step latent path: @s1@ from 'init_', @s2@ from
-- @trans s1@, @s3@ from @trans s2@. Because '(>>>)' puts each newly drawn
-- value in front of what it was conditioned on, the path is nested in
-- reverse time order: @((s3, s2), s1)@.
pSequence :: Dist ((Latent, Latent), Latent)
pSequence = init_ >>> \s1 -> trans s1 >>> trans
-- | Joint distribution over (observations, latent path) for the 3-step HMM:
-- P(o, s) = P(o | s) P(s).
joint :: Dist (((Obs, Obs), Obs), ((Latent, Latent), Latent))
joint (obs, path) = pObservation path obs * pSequence path
-- | Posterior over latent paths given that all three observations were
-- @X True@, obtained from 'joint' by Bayes' rule.
score :: Dist ((Latent, Latent), Latent)
score = bayes joint ((X True, X True), X True)
-- | Print the posterior mass of every latent path, one line per path.
main :: IO ()
main = mapM_ print (pmf score)
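{-
Sample output. Latent paths print in the nesting pSequence builds, i.e. in
reverse time order ((z3, z2), z1); every row whose last component (the
initial state) is Z False has probability 0 because init_ always starts in Z True.
*Main> :main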
(((Z False,Z False),Z False),0.0)
(((Z False,Z False),Z True),4.929577464788734e-3)
(((Z False,Z True),Z False),0.0)
(((Z False,Z True),Z True),4.436619718309861e-2)
(((Z True,Z False),Z False),0.0)
(((Z True,Z False),Z True),1.901408450704226e-2)
(((Z True,Z True),Z False),0.0)
(((Z True,Z True),Z True),0.9316901408450704)
-}
// npm install -g webppl
// webppl hmm.js
// or paste it into http://webppl.org (see the note at the end)
// Transition kernel over boolean states: keep the current state with
// probability 0.7, switch with probability 0.3.
var transition = function(s) {
  var pTrue = s ? 0.7 : 0.3;
  return flip(pTrue);
};
// Emission model: report the hidden state correctly with probability 0.9,
// flipped with probability 0.1.
// NOTE(review): this name may shadow WebPPL's built-in observe(dist, value).
var observe = function(s) {
  var pTrue = s ? 0.9 : 0.1;
  return flip(pTrue);
};
// Unroll the chain for n steps from a fixed initial state `true` that emits
// nothing. Returns {states, observations} with n + 1 states (the initial one
// included) and n observations; observations[i] is emitted by states[i + 1].
// The base case is n <= 0, so hmm(0) returns just the initial state instead
// of recursing forever through hmm(-1), hmm(-2), ...
var hmm = function(n) {
  if (n <= 0) {
    return {states: [true], observations: []};
  } else {
    var prev = hmm(n - 1);
    var newState = transition(prev.states[prev.states.length - 1]);
    var newObs = observe(newState);
    return {
      states: prev.states.concat([newState]),
      observations: prev.observations.concat([newObs])
    };
  }
};
var trueObservations = [true, true];
// Exact posterior over state paths given the observed sequence.
// The chain is unrolled for exactly as many steps as there are observations,
// so editing trueObservations cannot leave every execution with zero weight.
// condition(c) is WebPPL's shorthand for factor(c ? 0 : -Infinity).
var dist = Infer({method: 'enumerate'}, function() {
  var r = hmm(trueObservations.length);
  condition(_.isEqual(r.observations, trueObservations));
  return r.states;
});
dist.getDist()
// If you paste this into webppl.org, visualize the result with:
// viz.table(dist)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment