This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __iter__(self): | |
worker_info = torch.utils.data.get_worker_info() | |
if worker_info is not None: | |
self.random.seed(round(1000 * random.random())) | |
length = ceil(self.length // worker_info.num_workers) | |
if worker_info.id == worker_info.num_workers - 1: | |
length = self.length - length * (worker_info.num_workers - 1) | |
assert length > 0 | |
self.ilength = length |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VGG19Encoder(nn.Module):
    """Encoder built from the modules returned by
    ``extract_vgg19_pretrained_layers``.

    Only the constructor is shown here; the helper it calls is expected to
    be provided by the full class definition.
    """

    def __init__(self):
        """Create ``self.feats`` by chaining the extracted modules in order."""
        super().__init__()
        # The helper's return value is unpacked, so it must be an iterable
        # of nn.Module instances; nn.Sequential keeps them in that order.
        self.feats = nn.Sequential(*self.extract_vgg19_pretrained_layers())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def extract_vgg19_pretrained_layers(self): | |
# get pretrained model | |
features = torchvision.models.vgg19_bn(pretrained=True, progress=True).features | |
# change to reflection | |
for i in features: | |
if isinstance(i, nn.Conv2d): | |
i.padding_mode = 'reflect' | |
# get blocks of layers we want |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, x):
    """Run ``x`` through every entry of ``self.feats`` in sequence.

    Each entry is applied to the result of the previous one, and every
    intermediate result is kept.

    Returns:
        A list with one element per entry of ``self.feats``; the last
        element is the output of the final entry. The list is empty when
        ``self.feats`` is empty.
    """
    outputs = []
    for group in self.feats:
        # Feed the previous result forward and remember it.
        x = group(x)
        outputs.append(x)
    return outputs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main():
    """Smoke-test the encoder: build it, print it, and report tensor shapes.

    A constant all-ones batch of shape (1, 3, 256, 256) is pushed through
    the encoder; the input shape and the shape of every returned tensor are
    printed.
    """
    encoder = VGG19Encoder()
    print(encoder)
    # NOTE(review): nn.Module has no freeze() method, so this relies on
    # VGG19Encoder defining one outside this excerpt — confirm.
    encoder.freeze()
    sample_input = torch.ones((1, 3, 256, 256))
    # The encoder returns a sequence of tensors (iterated over below).
    sample_output = encoder(sample_input)
    print(f"Input shapes: {sample_input.shape}")
    print(f"Output shapes: {[i.shape for i in sample_output]}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Decoder(nn.Module): | |
def __init__(self): | |
super().__init__() | |
features = torchvision.models.vgg19(pretrained=True, progress=True).features[20:None:-1] | |
for i, layer in enumerate(features): | |
if isinstance(layer, nn.MaxPool2d): | |
features[i] = nn.Upsample(scale_factor = (2, 2), mode = 'nearest') | |
elif isinstance(layer, nn.Conv2d): | |
conv2d = nn.Conv2d(layer.out_channels, layer.in_channels, \ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
for i, layer in enumerate(features): | |
if isinstance(layer, nn.MaxPool2d): | |
features[i] = nn.Upsample(scale_factor = (2, 2), mode = 'nearest') | |
elif isinstance(layer, nn.Conv2d): | |
conv2d = nn.Conv2d(layer.out_channels, layer.in_channels, \ | |
kernel_size = layer.kernel_size, stride = layer.stride, \ | |
padding = layer.padding, padding_mode = 'reflect') | |
with torch.no_grad(): | |
conv2d.weight[...] = layer.weight.transpose(0, 1) | |
features[i] = conv2d |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main(): | |
encoder = VGG19Encoder() | |
decoder = Decoder() | |
print(decoder) | |
sample_input = torch.ones((1, 3, 256, 256)) | |
outputs = encoder(sample_input) | |
output = decoder(outputs[-1]) | |
print(f"Input shape: {sample_input.shape}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_epoch_reconstruct(encoder, decoder, dataloader, optimizer, epoch_num, writer, run): | |
encoder.train() | |
decoder.train() | |
total_loss = 0 | |
for i, content_image in tqdm.tqdm(enumerate(dataloader), total = len(dataloader), dynamic_ncols = True): | |
content_image = content_image.to(DEVICE) | |
optimizer.zero_grad() | |
reconstruction = decoder(encoder(content_image)[-1]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# check shapes | |
assert len(target.shape) == 4, "expected 4 dimensions" | |
assert target.shape == source.shape, "source/target shape mismatch" | |
batch_size, channels, width, height = source.shape | |
# calculate target stats | |
target_reshaped = target.view(batch_size, channels, 1, 1, -1) | |
target_variances = target_reshaped.var(-1, unbiased = False) | |
target_means = target_reshaped.mean(-1) |