def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        # re-seed this worker's RNG so workers do not yield identical samples
        self.random.seed(round(1000 * random.random()))

        # split the dataset evenly across workers; the last worker takes the remainder
        length = ceil(self.length / worker_info.num_workers)
        if worker_info.id == worker_info.num_workers - 1:
            length = self.length - length * (worker_info.num_workers - 1)

        assert length > 0
        self.ilength = length
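A minimal sketch of how this sharding pattern is typically driven (the dataset class, attribute names, and sample values below are assumptions, not part of the gist): each DataLoader worker runs __iter__ on its own copy of the dataset, so the per-worker length computed above keeps workers from emitting duplicate samples.

# Hypothetical usage sketch -- class and attribute names are assumptions.
import random
from math import ceil

import torch
from torch.utils.data import IterableDataset, DataLoader


class ShardedRandomDataset(IterableDataset):
    def __init__(self, length):
        super().__init__()
        self.length = length            # total samples across all workers
        self.ilength = length           # per-worker share, set in __iter__
        self.random = random.Random(0)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self.random.seed(round(1000 * random.random()))
            length = ceil(self.length / worker_info.num_workers)
            if worker_info.id == worker_info.num_workers - 1:
                length = self.length - length * (worker_info.num_workers - 1)
            assert length > 0
            self.ilength = length
        for _ in range(self.ilength):
            yield torch.tensor([self.random.random()])


loader = DataLoader(ShardedRandomDataset(10), batch_size=2, num_workers=2)
print(sum(batch.shape[0] for batch in loader))  # 10 samples total, no duplicates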
import torch
import torch.nn as nn
import torchvision


class VGG19Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.feats = nn.Sequential(*self.extract_vgg19_pretrained_layers())

    def extract_vgg19_pretrained_layers(self):
        # get pretrained model
        features = torchvision.models.vgg19_bn(pretrained=True, progress=True).features

        # change zero padding to reflection padding
        for i in features:
            if isinstance(i, nn.Conv2d):
                i.padding_mode = 'reflect'

        # get blocks of layers we want (the gist is truncated here; these split
        # points are an assumption, ending each block at relu1_1 ... relu4_1)
        bounds = [0, 3, 10, 17, 30]
        return [nn.Sequential(*features[bounds[i]:bounds[i + 1]])
                for i in range(len(bounds) - 1)]

    def forward(self, x):
        # return the activations after every block, not just the final one
        outputs = []
        for group in self.feats:
            x = group(x)
            outputs.append(x)
        return outputs
def main():
    encoder = VGG19Encoder()
    print(encoder)
    encoder.freeze()

    sample_input = torch.ones((1, 3, 256, 256))
    sample_output = encoder(sample_input)
    print(f"Input shape: {sample_input.shape}")
    print(f"Output shapes: {[i.shape for i in sample_output]}")
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()

        # take the vgg19 feature layers up to relu4_1 and reverse their order
        features = torchvision.models.vgg19(pretrained=True, progress=True).features[20:None:-1]

        for i, layer in enumerate(features):
            if isinstance(layer, nn.MaxPool2d):
                # replace each max pool with nearest-neighbour upsampling
                features[i] = nn.Upsample(scale_factor=(2, 2), mode='nearest')
            elif isinstance(layer, nn.Conv2d):
                # mirror each conv: swap in/out channels, keep reflection padding
                conv2d = nn.Conv2d(layer.out_channels, layer.in_channels,
                                   kernel_size=layer.kernel_size, stride=layer.stride,
                                   padding=layer.padding, padding_mode='reflect')
                with torch.no_grad():
                    conv2d.weight[...] = layer.weight.transpose(0, 1)
                features[i] = conv2d

        # the gist is truncated here; storing the mirrored layers is an assumption
        self.feats = features

    def forward(self, x):
        # assumption: a plain pass through the mirrored layer stack
        return self.feats(x)
def main():
    encoder = VGG19Encoder()
    decoder = Decoder()
    print(decoder)

    sample_input = torch.ones((1, 3, 256, 256))
    outputs = encoder(sample_input)
    output = decoder(outputs[-1])
    print(f"Input shape: {sample_input.shape}")
def train_epoch_reconstruct(encoder, decoder, dataloader, optimizer, epoch_num, writer, run):
    encoder.train()
    decoder.train()
    total_loss = 0

    for i, content_image in tqdm.tqdm(enumerate(dataloader), total=len(dataloader), dynamic_ncols=True):
        content_image = content_image.to(DEVICE)
        optimizer.zero_grad()
        reconstruction = decoder(encoder(content_image)[-1])

        # the gist is truncated here; a typical reconstruction step (assumption):
        loss = torch.nn.functional.mse_loss(reconstruction, content_image)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
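A sketch of how this epoch function might be driven (the data, hyperparameters, and the writer/run arguments below are assumptions; the gist does not show the caller):

# Hypothetical driver for the reconstruction training loop above.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

encoder = VGG19Encoder().to(DEVICE)
decoder = Decoder().to(DEVICE)
encoder.freeze()  # only the decoder is optimized for reconstruction

dummy_images = torch.rand(16, 3, 256, 256)            # stand-in for a content dataset
dataloader = torch.utils.data.DataLoader(dummy_images, batch_size=4)

optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
for epoch in range(5):
    train_epoch_reconstruct(encoder, decoder, dataloader, optimizer,
                            epoch, writer=None, run=None)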
# the function header is not shown in this excerpt; the signature below is an
# assumption based on the source/target arguments used in the body
def adain(source, target):
    # check shapes
    assert len(target.shape) == 4, "expected 4 dimensions"
    assert target.shape == source.shape, "source/target shape mismatch"
    batch_size, channels, width, height = source.shape

    # calculate target stats
    target_reshaped = target.view(batch_size, channels, 1, 1, -1)
    target_variances = target_reshaped.var(-1, unbiased=False)
    target_means = target_reshaped.mean(-1)
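The excerpt stops after the target statistics; a minimal sketch of how adaptive instance normalization typically finishes from here (an assumption, not the gist's code): compute the same per-channel statistics for the source features, normalize the source, then rescale and shift it with the target statistics.

# Hypothetical continuation -- the _sketch name and epsilon value are not from the gist.
import torch


def adain_sketch(source, target, eps=1e-5):
    batch_size, channels = source.shape[:2]

    # per-channel mean/std of the target (style) features, shaped (N, C, 1, 1)
    target_flat = target.view(batch_size, channels, 1, 1, -1)
    target_std = (target_flat.var(-1, unbiased=False) + eps).sqrt()
    target_mean = target_flat.mean(-1)

    # per-channel mean/std of the source (content) features
    source_flat = source.view(batch_size, channels, 1, 1, -1)
    source_std = (source_flat.var(-1, unbiased=False) + eps).sqrt()
    source_mean = source_flat.mean(-1)

    # normalize the source, then shift/scale it to match the target statistics
    return target_std * (source - source_mean) / source_std + target_mean


content = torch.rand(1, 512, 32, 32)
style = torch.rand(1, 512, 32, 32)
print(adain_sketch(content, style).shape)  # torch.Size([1, 512, 32, 32])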