This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def __iter__(self): | |
| worker_info = torch.utils.data.get_worker_info() | |
| if worker_info is not None: | |
| self.random.seed(round(1000 * random.random())) | |
| length = ceil(self.length // worker_info.num_workers) | |
| if worker_info.id == worker_info.num_workers - 1: | |
| length = self.length - length * (worker_info.num_workers - 1) | |
| assert length > 0 | |
| self.ilength = length |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class VGG19Encoder(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.feats = nn.Sequential(*self.extract_vgg19_pretrained_layers()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def extract_vgg19_pretrained_layers(self): | |
| # get pretrained model | |
| features = torchvision.models.vgg19_bn(pretrained=True, progress=True).features | |
| # change to reflection | |
| for i in features: | |
| if isinstance(i, nn.Conv2d): | |
| i.padding_mode = 'reflect' | |
| # get blocks of layers we want |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, x):
    """Pass ``x`` through every stage of ``self.feats`` in sequence.

    Returns a list with one entry per stage: the activation produced right
    after that stage. The final element is therefore the output of the last
    stage, and an empty ``self.feats`` yields an empty list.
    """
    activations = []
    current = x
    for stage in self.feats:
        # Each stage consumes the previous stage's output.
        current = stage(current)
        activations.append(current)
    return activations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main():
    """Smoke-test VGG19Encoder on a dummy all-ones 256x256 RGB batch.

    Builds the encoder, prints its module tree, calls its ``freeze()`` method
    (defined elsewhere — presumably disables gradient updates; confirm), then
    prints the input shape and the shape of every activation the encoder
    returns.
    """
    model = VGG19Encoder()
    print(model)
    model.freeze()
    dummy_batch = torch.ones((1, 3, 256, 256))
    feature_maps = model(dummy_batch)
    print(f"Input shapes: {dummy_batch.shape}")
    shapes = [fmap.shape for fmap in feature_maps]
    print(f"Output shapes: {shapes}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class Decoder(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| features = torchvision.models.vgg19(pretrained=True, progress=True).features[20:None:-1] | |
| for i, layer in enumerate(features): | |
| if isinstance(layer, nn.MaxPool2d): | |
| features[i] = nn.Upsample(scale_factor = (2, 2), mode = 'nearest') | |
| elif isinstance(layer, nn.Conv2d): | |
| conv2d = nn.Conv2d(layer.out_channels, layer.in_channels, \ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| for i, layer in enumerate(features): | |
| if isinstance(layer, nn.MaxPool2d): | |
| features[i] = nn.Upsample(scale_factor = (2, 2), mode = 'nearest') | |
| elif isinstance(layer, nn.Conv2d): | |
| conv2d = nn.Conv2d(layer.out_channels, layer.in_channels, \ | |
| kernel_size = layer.kernel_size, stride = layer.stride, \ | |
| padding = layer.padding, padding_mode = 'reflect') | |
| with torch.no_grad(): | |
| conv2d.weight[...] = layer.weight.transpose(0, 1) | |
| features[i] = conv2d |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main():
    """Smoke-test Decoder by feeding it the encoder's final activation.

    Encodes a dummy all-ones 256x256 RGB batch with VGG19Encoder, decodes the
    last activation in the returned list, and prints the decoder's module tree
    and the input shape.
    """
    feature_extractor = VGG19Encoder()
    reconstructor = Decoder()
    print(reconstructor)
    dummy_batch = torch.ones((1, 3, 256, 256))
    feature_maps = feature_extractor(dummy_batch)
    # Running the decoder checks that it accepts the deepest encoder output;
    # the decoded tensor is not otherwise used in this function.
    reconstruction = reconstructor(feature_maps[-1])
    print(f"Input shape: {dummy_batch.shape}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def train_epoch_reconstruct(encoder, decoder, dataloader, optimizer, epoch_num, writer, run): | |
| encoder.train() | |
| decoder.train() | |
| total_loss = 0 | |
| for i, content_image in tqdm.tqdm(enumerate(dataloader), total = len(dataloader), dynamic_ncols = True): | |
| content_image = content_image.to(DEVICE) | |
| optimizer.zero_grad() | |
| reconstruction = decoder(encoder(content_image)[-1]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # check shapes | |
| assert len(target.shape) == 4, "expected 4 dimensions" | |
| assert target.shape == source.shape, "source/target shape mismatch" | |
| batch_size, channels, width, height = source.shape | |
| # calculate target stats | |
| target_reshaped = target.view(batch_size, channels, 1, 1, -1) | |
| target_variances = target_reshaped.var(-1, unbiased = False) | |
| target_means = target_reshaped.mean(-1) |