Skip to content

Instantly share code, notes, and snippets.

@Mason-McGough
Last active June 25, 2021 17:41
Show Gist options
  • Save Mason-McGough/7c1e31e96d6e8820475481b8d18bd39d to your computer and use it in GitHub Desktop.
Save Mason-McGough/7c1e31e96d6e8820475481b8d18bd39d to your computer and use it in GitHub Desktop.
GIRAFFE generator
# adapted from https://github.com/autonomousvision/giraffe (MIT License)
class Generator(nn.Module):
    # ...

    def get_latent_codes(self, batch_size=32, tmp=1.):
        """Draw the four latent code tensors for one batch.

        Args:
            batch_size: Size of the leading dimension of every returned
                tensor.
            tmp: Multiplier applied to the standard-normal samples
                (forwarded unchanged to ``sample_z``).

        Returns:
            Tuple ``(z_shape_obj, z_app_obj, z_shape_bg, z_app_bg)``.
            The two ``*_obj`` tensors have shape
            ``(batch_size, n_boxes, z_dim)`` with ``n_boxes`` taken from
            ``self.get_n_boxes()``; the two ``*_bg`` tensors have shape
            ``(batch_size, z_dim_bg)``.
        """
        z_dim, z_dim_bg = self.z_dim, self.z_dim_bg
        n_boxes = self.get_n_boxes()
        obj_size = (batch_size, n_boxes, z_dim)
        bg_size = (batch_size, z_dim_bg)

        # The draw order below is fixed so that a given RNG seed always
        # produces the same four tensors.
        z_shape_obj = self.sample_z(obj_size, tmp=tmp)
        z_app_obj = self.sample_z(obj_size, tmp=tmp)
        z_shape_bg = self.sample_z(bg_size, tmp=tmp)
        z_app_bg = self.sample_z(bg_size, tmp=tmp)
        return z_shape_obj, z_app_obj, z_shape_bg, z_app_bg

    def sample_z(self, size, to_device=True, tmp=1.):
        """Sample a standard-normal tensor of shape ``size`` scaled by ``tmp``.

        Args:
            size: Sequence of ints; unpacked into ``torch.randn``.
            to_device: When true, the result is moved to ``self.device``
                before being returned.
            tmp: Scalar the sampled values are multiplied by.

        Returns:
            The scaled sample tensor.
        """
        # NOTE(review): the sample is drawn first and moved afterwards;
        # drawing directly on ``self.device`` would presumably use a
        # different RNG stream, so the two-step form is kept as is.
        z = torch.randn(*size) * tmp
        return z.to(self.device) if to_device else z

    # ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment