@cavaunpeu
Created February 10, 2019 04:40
Constrain a symmetric transition matrix so that all rows sum to 1
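A short derivation of why the construction yields rows that sum to 1. Here $W$ stands for the raw parameter (`self._A`) and $U$ for the stacked upper-triangular matrix (`A_triu`). Both symbols are introduced only for this sketch, and the softmax is assumed to run over the entries on and above the diagonal (0-based row index $i$):

```latex
\begin{aligned}
r_i &= 1 - \sum_{j<i} U_{ji}, \qquad
U_{ij} = r_i \,\frac{\exp(W_{ij})}{\sum_{k \ge i} \exp(W_{ik})} \quad (j \ge i), \\
A &= U + U^\top - \operatorname{diag}(U), \\
\sum_{j} A_{ij} &= \sum_{j<i} U_{ji} + \sum_{j \ge i} U_{ij} = (1 - r_i) + r_i = 1 .
\end{aligned}
```

For $i = 0$ the sum is empty, so $r_0 = 1$ and the first row is a plain softmax. Because $A$ is symmetric, its columns also sum to 1. Nothing bounds $r_i$ from below, however, so for some $W$ later rows receive negative mass and $A$ is then not a valid stochastic matrix.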
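```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Ensure that p(~ | z_k) = 1, i.e. the probability of jumping from a given state
# to some state (itself included) is 1: every row of the transition matrix sums to 1.
self._A = nn.Parameter(A)
# Build the upper triangle row by row; the lower triangle follows by symmetry.
# .detach() means self.A below does not propagate gradients back to self._A.
_A_triu = self._A.triu().detach()
A_triu = []
for i, row in enumerate(_A_triu):
    upper = row[i:]  # entries on and above the diagonal
    if i == 0:
        # First row: nothing is fixed yet, so a plain softmax sums to 1.
        probs = F.softmax(upper, dim=-1)
    else:
        # By symmetry A[i, j] = A_triu[j, i] for j < i; spread the leftover mass
        # over the upper entries. Nothing keeps `remaining` non-negative.
        remaining = 1 - torch.stack(A_triu)[:, i].sum()
        probs = F.softmax(upper, dim=-1) * remaining
    A_triu.append(torch.cat([row.new_zeros(i), probs]))
A_triu = torch.stack(A_triu)
# Mirror into a symmetric matrix, counting the diagonal only once.
self.A = A_triu.t() + A_triu - torch.diag(A_triu.diag())
# Exact float equality is brittle; compare row sums with a tolerance instead.
assert torch.allclose(self.A.sum(1), torch.ones_like(self.A.sum(1)))
```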
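The snippet refers to `self`, so it only runs inside its original class. Below is a standalone sketch of the same row-by-row construction as a plain function. `symmetric_row_stochastic` is a hypothetical name, and unlike the snippet this version drops `.detach()`, so gradients reach `W`; treat it as an illustration under those assumptions, not the author's code.

```python
import torch
import torch.nn.functional as F


def symmetric_row_stochastic(W: torch.Tensor) -> torch.Tensor:
    """Map a square matrix W to a symmetric matrix whose rows sum to 1.

    Hypothetical helper mirroring the gist's row-by-row construction,
    without .detach(), so gradients flow back to W.
    """
    n = W.size(0)
    rows = []
    for i in range(n):
        # Mass already fixed in row i by symmetry: U[j, i] for every j < i.
        fixed = sum(r[i] for r in rows)  # plain 0 when i == 0
        remaining = 1 - fixed  # not guaranteed to be non-negative
        # Spread the leftover mass over the entries on/above the diagonal.
        probs = F.softmax(W[i, i:], dim=-1) * remaining
        rows.append(torch.cat([W.new_zeros(i), probs]))
    U = torch.stack(rows)
    # Mirror the upper triangle; subtract the diagonal so it is counted once.
    return U + U.t() - torch.diag(U.diagonal())


if __name__ == "__main__":
    torch.manual_seed(0)
    W = torch.randn(4, 4, requires_grad=True)
    A = symmetric_row_stochastic(W)
    print(torch.allclose(A.sum(dim=1), torch.ones(4), atol=1e-6))  # True
    print(torch.allclose(A, A.t()))                                 # True
    (A ** 2).sum().backward()
    print(W.grad is not None)                                       # True
```

Because the leftover mass can go negative, check `(A >= 0).all()` if valid probabilities are required. If symmetry is not needed at all, a row-wise `F.softmax(W, dim=1)` already makes every row sum to 1 with non-negative entries.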