Created
          February 10, 2019 04:40 
        
      - 
      
- 
        Save cavaunpeu/cb570b9464f2a300f8430891e9af373f to your computer and use it in GitHub Desktop. 
    Constrain the transition matrix so that every row sums to 1
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  # Ensure that sum_j p(z_j | z_k) = 1 for every state k, i.e. every row of the
# symmetric transition matrix self.A is a distribution over next states.
#
# Construction: build the upper triangle (diagonal included) one row at a time.
# Row i's entries left of the diagonal are fixed by symmetry: they are column i
# of the rows already built. The remaining mass 1 - sum_{j<i} A[j, i] is spread
# over columns i..n-1 with a softmax of the corresponding logits. Mirroring the
# upper triangle then gives a symmetric matrix whose rows each sum to 1.
#
# NOTE(review): if the mass already fixed in column i exceeds 1, the remaining
# mass is negative and row i gets negative entries; nothing here guards against
# that -- confirm it cannot happen for the expected A.
# NOTE(review): ``.detach()`` cuts self.A off from the autograd graph, so no
# gradient reaches the ``self._A`` parameter through self.A. Confirm that
# self.A is meant to stay fixed once computed.
self._A = nn.Parameter(A)
_A_triu = self._A.triu().detach()
A_triu = []
for i, row in enumerate(_A_triu):
    # Logits for columns i..n-1. Slicing (instead of ``row.nonzero()``) keeps the
    # row length right even when a logit on/above the diagonal is exactly 0.
    logits = row[i:]
    # Mass already assigned to row i via symmetry; 0 for the first row.
    mass_above = sum(prev[i] for prev in A_triu)
    # softmax == exp / sum(exp), but max-stabilized so large logits can't overflow.
    probs = F.softmax(logits, dim=-1) * (1 - mass_above)
    # new_zeros follows the logits' dtype/device (torch.zeros defaults to CPU float32).
    row = torch.cat([logits.new_zeros(i), probs])
    A_triu.append(row)
A_triu = torch.stack(A_triu)
# Mirror the upper triangle; the diagonal appears in both terms, so remove it once.
self.A = A_triu.t() + A_triu - torch.diag(A_triu.diag())
# Exact float equality is too strict after softmax arithmetic; compare with tolerance.
row_sums = self.A.sum(1)
assert torch.allclose(row_sums, torch.ones_like(row_sums))
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment