This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Decoder(nn.Module): | |
def __init__(self): | |
super().__init__() | |
features = torchvision.models.vgg19(pretrained=True, progress=True).features[20:None:-1] | |
for i, layer in enumerate(features): | |
if isinstance(layer, nn.MaxPool2d): | |
features[i] = nn.Upsample(scale_factor = (2, 2), mode = 'nearest') | |
elif isinstance(layer, nn.Conv2d): | |
conv2d = nn.Conv2d(layer.out_channels, layer.in_channels, \ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main():
    """Smoke-test VGG19Encoder on a dummy input.

    Prints the model, freezes it, pushes a single all-ones image of size
    (1, 3, 256, 256) through it, and prints the input shape and the shape of
    every tensor the encoder returns.
    """
    model = VGG19Encoder()
    print(model)
    model.freeze()

    probe = torch.ones((1, 3, 256, 256))
    feature_maps = model(probe)

    print(f"Input shapes: {probe.shape}")
    shape_list = [fmap.shape for fmap in feature_maps]
    print(f"Output shapes: {shape_list}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, x):
    """Feed ``x`` through each group in ``self.feats`` in order.

    Each group receives the previous group's output. Every intermediate result
    is kept, so the returned list has one entry per group and its last element
    is the final output. An empty ``self.feats`` yields an empty list.
    """
    collected = []
    for stage in self.feats:
        x = stage(x)
        collected.append(x)
    return collected
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def extract_vgg19_pretrained_layers(self): | |
# get pretrained model | |
features = torchvision.models.vgg19_bn(pretrained=True, progress=True).features | |
# change to reflection | |
for i in features: | |
if isinstance(i, nn.Conv2d): | |
i.padding_mode = 'reflect' | |
# get blocks of layers we want |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VGG19Encoder(nn.Module):
    """Encoder whose layer groups come from ``extract_vgg19_pretrained_layers()``."""

    def __init__(self):
        super().__init__()
        # Each item returned by the extractor becomes one child of the Sequential.
        stages = self.extract_vgg19_pretrained_layers()
        self.feats = nn.Sequential(*stages)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __iter__(self): | |
worker_info = torch.utils.data.get_worker_info() | |
if worker_info is not None: | |
self.random.seed(round(1000 * random.random())) | |
length = ceil(self.length // worker_info.num_workers) | |
if worker_info.id == worker_info.num_workers - 1: | |
length = self.length - length * (worker_info.num_workers - 1) | |
assert length > 0 | |
self.ilength = length |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main():
    """Smoke-test StyleTransferDataset on the local COCO / WikiArt folders.

    Prints the dataset's total length and the lengths of its ``wiki`` and
    ``coco`` members, then the type and shape of the first content and style
    items.
    """
    # NOTE(review): get_transforms is declared elsewhere in this paste as
    # get_transforms(crop); this no-argument call only works if `crop` has a
    # default value — confirm.
    tfm = get_transforms()
    dataset = StyleTransferDataset(
        "datasets/coco/train2017",
        "datasets/coco/annotations/captions_train2017.json",
        "datasets/wikiart",
        transform=tfm,
    )
    print(f"Dataset length: {len(dataset)}, Wiki length: {len(dataset.wiki)}, COCO length: {len(dataset.coco)}")

    first_content, first_style = dataset[0]
    print(f"1st content img: {type(first_content)}, {first_content.shape}")
    print(f"1st style img: {type(first_style)}, {first_style.shape}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_transforms(crop=False):
    """Build the image preprocessing pipeline.

    Uses torchvision-style ``Compose``/``Resize``/``RandomCrop``/``ToTensor``
    (presumably imported from ``torchvision.transforms`` — confirm at file top).

    Args:
        crop: If truthy, apply ``Resize(512)`` and then ``RandomCrop(256)``.
            Per torchvision semantics, an int size matches the shorter edge to
            512 px, and the crop is a random 256x256 window. If falsy, resize
            directly to 256x256 without preserving aspect ratio. Defaults to
            ``False`` so that callers which omit the argument, such as the
            dataset smoke test's ``get_transforms()``, get the deterministic
            pipeline instead of raising ``TypeError``.

    Returns:
        A ``Compose`` pipeline that ends in ``ToTensor()``.
    """
    if crop:
        return Compose([Resize(512), RandomCrop(256), ToTensor()])
    return Compose([Resize((256, 256)), ToTensor()])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class IterableStyleTransferDataset(IterableDataset):
    """Iterable dataset backed by a WikiArt ``ImageFolder`` and a COCO captions set.

    Presumably COCO supplies content images and WikiArt supplies style images;
    confirm against the iteration logic.
    """

    def __init__(self, coco_path, coco_annotations,
                 wiki_path, length=100000, transform=None, rng_seed=1, exclude_style=False):
        # Plain configuration values. `length` is presumably the nominal number
        # of samples per pass, and `exclude_style` is only stored here; its
        # meaning is defined by code not shown.
        self.length = length
        self.exclude_style = exclude_style
        self.seed = rng_seed

        # Underlying image sources; both receive the same transform.
        self.wiki = ImageFolder(wiki_path, transform=transform)
        self.coco = CocoCaptions(coco_path, coco_annotations, transform=transform)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from selenium.webdriver.support import expected_conditions as EC | |
from selenium import webdriver | |
from selenium.common.exceptions import NoSuchElementException | |
from selenium.webdriver.chrome.options import Options | |
from selenium.webdriver.support import expected_conditions as EC | |
from selenium.webdriver.common.by import By | |
from selenium.webdriver.support.ui import WebDriverWait | |
def login_to_autolab(wd, username, password): | |
wd.get("https://autolab.andrew.cmu.edu") |