Last active
April 17, 2020 02:57
-
-
Save proger/e8fdd3ab4d45a8a34a2354fa3eac9928 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Learnable transition scores, one row and one column per tag in `zk`.
# Indexing convention: pi[to, from_], so
#   pi[to]       is the row of scores for moving INTO `to`   (== pi[to, :])
#   pi[:, from_] is the column of scores for moving OUT OF `from_`
pi = nn.Parameter(torch.randn(len(zk), len(zk)))

# Hard constraints, written straight into the underlying data:
# nothing may transfer to the start tag (row 0) ...
pi.data[0] = -10000
# ... and nothing may transfer from the stop tag (last column).
pi.data[:, -1] = -10000

# Bare expression: shows the parameter when run as a notebook cell.
pi
def score(obs, zs):
    """Score one tag sequence under the transition matrix `pi`.

    The score is the sum of
      * the start transition  pi[zs[0], 0]            (start tag is index 0),
      * every inner transition pi[zs[t + 1], zs[t]],
      * the stop transition   pi[zk[STOP_TAG], zs[-1]],
      * the emission score    obs[t, zs[t]] of each position.

    Args:
        obs: 2-D tensor of per-position tag scores; it is indexed with
            ``gather(1, ...)``, so tags live on dim 1
            (presumably shape (len(zs), len(zk)) -- confirm against caller).
        zs: 1-D integer tensor of tag indices, one per position.

    Returns:
        A 0-dim tensor holding the total sequence score.
    """
    # pi is indexed [to, from_]: select the row of every "to" tag, then pick
    # the column of the tag that precedes it.
    to = pi[zs[1:]]
    transitions = to.gather(1, zs[:-1].unsqueeze(1)).sum()
    emissions = obs.gather(1, zs.unsqueeze(1)).sum()
    # The (T-1, 1) transition scores and (T, 1) emission scores cannot be
    # added element-wise, so each part is reduced to a scalar first.
    return pi[zs[0], 0] + transitions + pi[zk[STOP_TAG], zs[-1]] + emissions
def decode(obs, pi=pi):
    """Viterbi-decode the best tag path for a sequence of emission scores.

    Args:
        obs: iterable of 1-D score vectors, one per position, each with one
            entry per tag (presumably a (T, len(zk)) tensor -- confirm).
        pi: transition scores indexed [to, from_]; defaults to the
            module-level matrix as it was bound when this function was defined.

    Returns:
        (s, path): `s` is the score of the best path including the final
        transition into the stop tag (the last index); `path` is a list of
        tag indices, one per element of `obs`, with the start tag removed.

    Raises:
        AssertionError: if backtracking does not end at the start tag (0).
    """
    n_tags = len(zk)
    # Initial scores: 0 for the start tag (index 0) and pi[0, 0] for every
    # other tag. This relies on pi[0, 0] holding the large negative
    # "forbidden" score written into row 0 of the transition matrix.
    h0 = pi[0, 0] * (torch.ones(n_tags) - torch.eye(n_tags)[0])
    prev = h0
    back = []
    for o in obs:
        # Every row of the expanded `prev` is the previous score vector, so
        # entry [to, from_] is prev[from_] + pi[to, from_]; maximise over from_.
        h, indices = (prev.expand_as(pi) + pi).max(1)
        back.append(indices)
        prev = h + o
    # Close the path with the transition into the stop tag (last row of pi).
    s, i = (prev + pi[n_tags - 1]).max(0)
    path = [i.item()]
    # Follow the back-pointers from the last position to the first.
    for pointers in reversed(back):
        path.append(pointers[path[-1]].item())
    path.reverse()
    # Drop the leading start tag. The pop is kept outside the assert so the
    # path is still trimmed when assertions are disabled (python -O).
    start = path.pop(0)
    assert start == 0, "decoded path does not begin at the start tag"
    return s, path
# Embed the sentence; only the first element of the returned pair is used.
xs, _ = lstm_embed_sentence(s)
# Drop every size-1 dimension.
# NOTE(review): presumably this removes a batch dimension; a dim-less squeeze
# would also collapse a length-1 sequence -- confirm the expected shape.
xs = torch.squeeze(xs)
# Bare expression: shows the shape when run as a notebook cell.
xs.shape
# Run both decoders on the same scores so their results can be compared.
decode(xs), viterbi_decode(xs, pi=pi)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleInstances, UndecidableInstances, GeneralizedNewtypeDeriving #-} | |
-- stack runhaskell hmm.hs | |
-- enumeration posterior inference for HMMs (aka HMM evaluation or scoring) | |
-- thanks, robots: https://www.robots.ox.ac.uk/~vgg/rg/slides/hmm.pdf | |
-- | Types whose values can all be listed.
class Finite a where
  -- | Every value of the type.
  every :: [a]

-- | Apply a function to each value of a 'Finite' type, in 'every' order.
forevery :: Finite a => (a -> b) -> [b]
forevery f = [f x | x <- every]
-- | Any bounded enumeration is finite: list it from 'minBound' to 'maxBound'.
-- The bare type variable in the instance head is what requires
-- FlexibleInstances and UndecidableInstances.
instance (Enum a, Bounded a) => Finite a where
  every = enumFromTo minBound maxBound

-- | Pairs of finite types are finite. The first component varies slowest.
-- This instance is chosen over the catch-all one above for tuple types.
instance {-# OVERLAPPING #-} (Finite a, Finite b) => Finite (a, b) where
  every = do
    x <- every
    y <- every
    return (x, y)
-- | A probability.
type P = Double

-- | A distribution, represented by its mass function.
type Dist a = a -> P

-- | Tabulate a mass function: each value of the domain paired with its mass,
-- in 'every' order.
pmf :: Finite a => Dist a -> [(a, P)]
pmf mass = [(a, mass a) | a <- every]
-- | Chain rule: P(x,θ) = P(x|θ)P(θ).
--
-- The left operand is the marginal over θ, the right operand the conditional
-- of x given θ. In the resulting joint the conditioned variable @x@ is the
-- first tuple component and @θ@ the second.
(>>>) :: Dist θ -> (θ -> Dist x) -> Dist (x, θ)
(prior >>> likelihood) (x, θ) = likelihood θ x * prior θ
infixr 1 *:*

-- | Product of two independent distributions: P(x,θ) = P(x)P(θ).
--
-- Mind the order: in @q *:* p@ the /first/ tuple component is distributed as
-- the right operand @p@ and the /second/ as the left operand @q@. Because the
-- operator is right-associative, @a *:* b *:* c@ has domain @((c, b), a)@.
(*:*) :: Dist θ -> Dist x -> Dist (x, θ)
q *:* p = q >>> const p
-- | Marginalisation: P(a) = Σ_b P(a,b).
--
-- Sums the joint mass over every value of the second component.
sum_ :: Finite b => Dist (a, b) -> Dist a
sum_ dist a = sum [dist (a, b) | b <- every]
-- | Conditioning: P(b|a) = P(a,b)/P(a).
--
-- @bayes dist a@ is the distribution of the second component given that the
-- first component equals @a@. If @a@ has zero mass under the joint, the
-- division is 0/0 and every result is NaN.
bayes :: Finite b => Dist (a, b) -> a -> Dist b
bayes dist a = \b -> dist (a, b) / evidence
  where
    -- P(a) does not depend on the queried b, so it is bound outside the
    -- lambda: a partially applied @bayes dist a@ computes it only once
    -- instead of re-summing the whole joint for every b.
    evidence = sum_ dist a
-- | An observed symbol. The 'Enum' and 'Bounded' instances make it 'Finite'.
newtype Obs
  = X Bool
  deriving (Show, Enum, Bounded)

-- | A hidden state. The 'Enum' and 'Bounded' instances make it 'Finite'.
newtype Latent
  = Z Bool
  deriving (Show, Enum, Bounded)
-- | Initial state distribution: the chain always starts in @Z True@.
init_ :: Dist Latent
init_ (Z start)
  | start     = 1
  | otherwise = 0
-- | Transition model, @trans previous next@: the state is kept with
-- probability 0.7 and flipped with probability 0.3.
trans :: Latent -> Dist Latent
trans (Z previous) (Z next)
  | previous == next = 0.7
  | otherwise        = 0.3
-- | Emission model, @emit state observation@: the observation equals the
-- hidden state with probability 0.9 and differs with probability 0.1.
emit :: Latent -> Dist Obs
emit (Z state) (X observed)
  | state == observed = 0.9
  | otherwise         = 0.1
-- | Likelihood of three observations given three hidden states, each
-- observation depending only on its own state.
--
-- Tuple orders are mirrored: the hidden state in position @((a, _), _)@ emits
-- the observation in position @((_, _), o)@, and vice versa, because '*:*'
-- puts its right operand first in the tuple. With the latent triple built by
-- @pSequence@, which is nested @((z3, z2), z1)@, the observation triple
-- therefore reads @((o1, o2), o3)@ in time order.
pObservation :: ((Latent, Latent), Latent) -> Dist ((Obs, Obs), Obs)
pObservation ((zc, zb), za) = emit zc *:* (emit zb *:* emit za)
-- | Prior over three consecutive hidden states: the first drawn from 'init_',
-- each following one from 'trans' applied to its predecessor.
--
-- Because '>>>' puts the conditioned variable first, the triple is nested in
-- reverse time order: @((z3, z2), z1)@.
pSequence :: Dist ((Latent, Latent), Latent)
pSequence = init_ >>> (\z1 -> trans z1 >>> trans)
-- | Joint distribution of observations and hidden states:
-- P(obs, states) = P(obs | states) P(states).
joint :: Dist (((Obs, Obs), Obs), ((Latent, Latent), Latent))
joint (observations, states) =
  pObservation states observations * pSequence states
-- | Posterior over the hidden-state triple given that all three observations
-- are @X True@.
score :: Dist ((Latent, Latent), Latent)
score = bayes joint observed
  where
    -- The evidence being conditioned on.
    observed = ((X True, X True), X True)
-- | Print the posterior mass of every hidden-state triple, one per line.
main :: IO ()
main = mapM_ print (pmf score)
{- | |
*Main> :main | |
(((Z False,Z False),Z False),0.0) | |
(((Z False,Z False),Z True),4.929577464788734e-3) | |
(((Z False,Z True),Z False),0.0) | |
(((Z False,Z True),Z True),4.436619718309861e-2) | |
(((Z True,Z False),Z False),0.0) | |
(((Z True,Z False),Z True),1.901408450704226e-2) | |
(((Z True,Z True),Z False),0.0) | |
(((Z True,Z True),Z True),0.9316901408450704) | |
-} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// npm install -g webppl | |
// webppl hmm.js | |
// or paste to http://webppl.org, see below | |
// Sample the next hidden state given the current one `s`:
// true with probability 0.7 when `s` is true, 0.3 otherwise.
var transition = function(s) {
  return flip(s ? 0.7 : 0.3);
};
// Sample an observation for hidden state `s`:
// true with probability 0.9 when `s` is true, 0.1 otherwise.
var observe = function(s) {
  return flip(s ? 0.9 : 0.1);
};
// Sample an HMM trajectory with `n` transition steps (n >= 1).
// The chain starts in state `true`, which emits nothing; every step then
// appends one new state and the observation emitted from it. The result
// holds n + 1 states and n observations.
var hmm = function(n) {
  var history = (n == 1) ? {states: [true], observations: []} : hmm(n - 1);
  var lastState = history.states[history.states.length - 1];
  var nextState = transition(lastState);
  return {
    states: history.states.concat([nextState]),
    observations: history.observations.concat([observe(nextState)])
  };
};
// The evidence: both observations came out true.
var trueObservations = [true, true];

// Posterior over hidden-state sequences by exhaustive enumeration, keeping
// only the executions whose observations match the evidence exactly.
var dist = Infer({method: 'enumerate'}, function() {
  var r = hmm(2);
  condition(_.isEqual(r.observations, trueObservations));
  return r.states;
});

// Last expression of the program: the posterior's support and probabilities.
dist.getDist()
/// if you paste in webppl.org use:
// viz.table(dist)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment