Created
          April 30, 2022 09:07 
        
      - 
      
 - 
        
Save dvgodoy/e28238f718c38fb5270c8c4f027c45c8 to your computer and use it in GitHub Desktop.  
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class EncoderVar(nn.Module): | |
| def __init__(self, input_shape, z_size, base_model): | |
| super().__init__() | |
| self.z_size = z_size | |
| self.input_shape = input_shape | |
| self.base_model = base_model | |
| output_size = self.get_output_size() | |
| self.lin_mu = nn.Linear(output_size, z_size) | |
| self.lin_var = nn.Linear(output_size, z_size) | |
| def get_output_size(self): | |
| device = next(self.base_model.parameters()).device.type | |
| size = self.base_model(torch.zeros(1, *self.input_shape, device=device)).size(1) | |
| return size | |
| def kl_loss(self): | |
| kl_loss = -0.5*(1 + self.log_var - self.mu**2 - torch.exp(self.log_var)) | |
| return kl_loss | |
| def forward(self, x): | |
| # the base model, same as the traditional AE | |
| base_out = self.base_model(x) | |
| # now the encoder produces means (mu) using the lin_mu output layer | |
| # and log variances (log_var) using the lin_var output layer | |
| # we compute the standard deviation (std) from the log variance | |
| self.mu = self.lin_mu(base_out) | |
| self.log_var = self.lin_var(base_out) | |
| std = torch.exp(self.log_var/2) | |
| # that's the internal random input (epsilon) | |
| eps = torch.randn_like(self.mu) | |
| # and that's the z vector | |
| z = self.mu + eps * std | |
| return z | 
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment