Last active: June 25, 2021 17:41
-
-
Save Mason-McGough/7c1e31e96d6e8820475481b8d18bd39d to your computer and use it in GitHub Desktop.
GIRAFFE generator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# adapted from https://github.com/autonomousvision/giraffe (MIT License)
class Generator(nn.Module):
    """Latent-code sampling helpers of the GIRAFFE generator (excerpt).

    Only the noise-sampling methods are reproduced here; the remaining
    members of the class are elided (marked ``# ...``). The methods below
    read ``self.z_dim``, ``self.z_dim_bg``, ``self.device`` and call
    ``self.get_n_boxes()``; presumably those are provided by the elided
    parts of the class (e.g. ``__init__``) -- confirm against the full file.
    """

    # ...

    def get_latent_codes(self, batch_size=32, tmp=1.):
        """Draw one batch of object and background latent codes.

        Args:
            batch_size: Leading (batch) dimension of every returned tensor.
            tmp: Multiplier applied to the standard-normal noise; ``1.``
                leaves the samples unscaled.

        Returns:
            Tuple ``(z_shape_obj, z_app_obj, z_shape_bg, z_app_bg)``. The two
            object codes have shape ``(batch_size, n_boxes, z_dim)``, where
            ``n_boxes = self.get_n_boxes()``. The two background codes have
            shape ``(batch_size, z_dim_bg)``.
        """
        obj_dim, bg_dim = self.z_dim, self.z_dim_bg
        n_boxes = self.get_n_boxes()

        obj_shape = (batch_size, n_boxes, obj_dim)
        bg_shape = (batch_size, bg_dim)

        # Draw in this fixed order (object pair first, then background pair)
        # so that, for a given RNG state, each code receives the same values.
        return tuple(
            self.sample_z(shape, tmp=tmp)
            for shape in (obj_shape, obj_shape, bg_shape, bg_shape)
        )

    def sample_z(self, size, to_device=True, tmp=1.):
        """Sample a standard-normal tensor scaled by ``tmp``.

        Args:
            size: Sequence of ints giving the tensor shape; it is unpacked
                into ``torch.randn``.
            to_device: If true, move the result to ``self.device``;
                otherwise it is returned where ``torch.randn`` created it.
            tmp: Scale factor multiplied into the samples.

        Returns:
            The sampled tensor.
        """
        noise = torch.randn(*size)
        noise = noise * tmp
        return noise.to(self.device) if to_device else noise

    # ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.