Skip to content

Instantly share code, notes, and snippets.

View J3698's full-sized avatar
๐Ÿ›
Inch Worm

Anti J3698

๐Ÿ›
Inch Worm
  • Working
  • Working
View GitHub Profile
class Decoder(nn.Module):
def __init__(self):
super().__init__()
features = torchvision.models.vgg19(pretrained=True, progress=True).features[20:None:-1]
for i, layer in enumerate(features):
if isinstance(layer, nn.MaxPool2d):
features[i] = nn.Upsample(scale_factor = (2, 2), mode = 'nearest')
elif isinstance(layer, nn.Conv2d):
conv2d = nn.Conv2d(layer.out_channels, layer.in_channels, \
def main():
    """Construct a VGG19Encoder, freeze it, and report shapes for one dummy batch.

    Prints the encoder module, then runs a single all-ones input of shape
    (1, 3, 256, 256) through it and prints the input shape followed by the
    shape of every element the encoder returns.
    """
    model = VGG19Encoder()
    print(model)
    model.freeze()

    dummy_batch = torch.ones((1, 3, 256, 256))
    feature_maps = model(dummy_batch)

    print(f"Input shapes: {dummy_batch.shape}")
    shapes = [fmap.shape for fmap in feature_maps]
    print(f"Output shapes: {shapes}")
def forward(self, x):
    """Run ``x`` through each entry of ``self.feats`` in order.

    Every entry receives the previous entry's output. The intermediate result
    after each entry is collected, so the returned list has one element per
    entry of ``self.feats`` (an empty list if there are none).
    """
    activations = []
    current = x
    for stage in self.feats:
        current = stage(current)
        # Keep every intermediate result, not just the final one.
        activations.append(current)
    return activations
def extract_vgg19_pretrained_layers(self):
# get pretrained model
features = torchvision.models.vgg19_bn(pretrained=True, progress=True).features
# change to reflection
for i in features:
if isinstance(i, nn.Conv2d):
i.padding_mode = 'reflect'
# get blocks of layers we want
class VGG19Encoder(nn.Module):
    """Encoder whose layers come from ``extract_vgg19_pretrained_layers``.

    NOTE(review): the layer-extraction method is defined outside this view.
    It presumably loads pretrained VGG19 features; confirm there.
    """

    def __init__(self):
        super().__init__()
        layer_groups = self.extract_vgg19_pretrained_layers()
        # Each item returned above becomes one child of the Sequential container.
        self.feats = nn.Sequential(*layer_groups)
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
self.random.seed(round(1000 * random.random()))
length = ceil(self.length // worker_info.num_workers)
if worker_info.id == worker_info.num_workers - 1:
length = self.length - length * (worker_info.num_workers - 1)
assert length > 0
self.ilength = length
def main():
    """Smoke-test StyleTransferDataset by loading one (content, style) pair.

    Builds the dataset from the local COCO train2017 images and captions
    annotations plus the ``datasets/wikiart`` folder. It then prints the
    overall, WikiArt and COCO lengths, followed by the type and shape of the
    first content and style images.
    """
    # get_transforms() takes a required positional `crop` flag. The previous
    # bare call raised TypeError before the dataset was ever built.
    # crop=True selects the Resize(512) + RandomCrop(256) + ToTensor pipeline.
    transforms = get_transforms(crop=True)
    dataset = StyleTransferDataset(
        "datasets/coco/train2017",
        "datasets/coco/annotations/captions_train2017.json",
        "datasets/wikiart",
        transform=transforms,
    )
    print(f"Dataset length: {len(dataset)}, Wiki length: {len(dataset.wiki)}, COCO length: {len(dataset.coco)}")
    content, style = dataset[0]
    print(f"1st content img: {type(content)}, {content.shape}")
    print(f"1st style img: {type(style)}, {style.shape}")
def get_transforms(crop):
    """Return the image preprocessing pipeline used for content/style images.

    Args:
        crop: if truthy, apply ``Resize(512)`` followed by ``RandomCrop(256)``.
            Otherwise resize straight to ``(256, 256)``.

    Both pipelines finish with ``ToTensor()``. (Compose/Resize/RandomCrop/
    ToTensor are presumably torchvision.transforms; the import is not visible.)
    """
    if not crop:
        return Compose([Resize((256, 256)), ToTensor()])
    return Compose([Resize(512), RandomCrop(256), ToTensor()])
class IterableStyleTransferDataset(IterableDataset):
def __init__(self, coco_path, coco_annotations, \
wiki_path, length = 100000, transform = None, rng_seed = 1, exclude_style = False):
self.wiki = ImageFolder(wiki_path, transform = transform)
self.coco = CocoCaptions(coco_path, coco_annotations, transform = transform)
self.length = length
self.exclude_style = exclude_style
self.seed = rng_seed
@J3698
J3698 / login.py
Last active February 2, 2021 19:00
Duo
from selenium.webdriver.support import expected_conditions as EC
from selenium import webdriver
from selenium.common.exceptions import NoSuchElementException
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
def login_to_autolab(wd, username, password):
wd.get("https://autolab.andrew.cmu.edu")