import numpy as np
import torch
import torch.nn as nn

def set_seed(seed=42):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    np.random.seed(seed)

class Encoder(nn.Module):
    def __init__(self, input_shape, z_size, base_model):
        super().__init__()
        self.input_shape = input_shape
        self.z_size = z_size
        self.base_model = base_model
        # appends the "lin_latent" linear layer to map from the "output_size"
        # given by the base model to the desired size of the representation (z_size)
        output_size = self._get_output_size()
        self.lin_latent = nn.Linear(output_size, z_size)

    def _get_output_size(self):
        # builds a dummy batch containing one dummy tensor
        # full of zeros with the same shape as the inputs
        device = next(self.base_model.parameters()).device.type
        dummy = torch.zeros(1, *self.input_shape, device=device)
        # sends the dummy batch through the base model to get
        # the output size produced by it
        size = self.base_model(dummy).size(1)
        return size

    def forward(self, x):
        # forwards the input through the base model and then the "lin_latent" layer
        # to get the representation (z)
        base_out = self.base_model(x)
        out = self.lin_latent(base_out)
        return out

set_seed(13)

# we defined our representation (z) as a vector of size one
z_size = 1
# our images are 1@28x28
input_shape = (1, 28, 28)  # (C, H, W)
base_model = nn.Sequential(
    # (C, H, W) -> C*H*W
    nn.Flatten(),
    # C*H*W -> 2048
    nn.Linear(np.prod(input_shape), 2048),
    nn.LeakyReLU(),
    # 2048 -> 2048
    nn.Linear(2048, 2048),
    nn.LeakyReLU(),
)
encoder = Encoder(input_shape, z_size, base_model)
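As a quick sanity check, here is a minimal usage sketch (not part of the original gist): it builds a random batch of four 1@28x28 images and confirms the encoder maps it to a (4, 1) tensor, i.e. one z value per image.

# usage sketch (assumed): encode a random batch and inspect the output shape
x = torch.randn(4, *input_shape)  # (N, C, H, W)
z = encoder(x)
print(z.shape)                    # torch.Size([4, 1])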
  