'''
Directly borrowed from Yannic Kilcher's work.
Amended to run on Colab and to accept some alternate engine sizes.
'''
#!/usr/bin/env python3
# Taken from here: https://github.com/huggingface/pytorch-pretrained-BigGAN
# MIT License
#
# Copyright (c) 2019 Thomas Wolf
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch.nn as nn
import torch.nn.functional as F
import torch
import math
import json
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
import copy
import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm
import numpy as np  # needed by one_hot_from_int below
WEIGHTS_NAME = 'pytorch_model.bin'
CONFIG_NAME = 'config.json'
try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                                     os.path.join(os.getcwd(), '.pytorch_pretrained_biggan')))
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                                os.path.join(os.getcwd(), '.pytorch_pretrained_biggan'))
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename
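
# Illustrative note (not in the original gist): the cache key is simply the
# SHA-256 hex digest of the URL, optionally suffixed with a hash of the ETag:
#   >>> url_to_filename('https://example.com/model.bin')
#   '<64 hex chars>'
#   >>> url_to_filename('https://example.com/model.bin', etag='"abc"')
#   '<64 hex chars>.<64 hex chars>'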
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag
def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """
    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper
@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag

@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file):
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path
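
# Illustrative usage (not in the original gist): resolving a remote file through
# the cache. The URL comes from the archive maps defined below; whether it still
# resolves depends on the hosting S3 bucket.
#   >>> local_path = cached_path(PRETRAINED_CONFIG_ARCHIVE_MAP['biggan-deep-128'])
#   >>> url, etag = filename_to_url(os.path.basename(local_path))  # via the .json sidecar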
def read_set_from_file(filename):
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection
def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
}

PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
}
class BigGANConfig(object):
    """ Configuration class to store the configuration of a `BigGAN`.
        Defaults are for the 128x128 model.
        Each tuple in `layers` is (up-sample in this layer?, input channels, output channels).
    """
    def __init__(self,
                 output_dim=128,
                 z_dim=128,
                 class_embed_dim=128,
                 channel_width=128,
                 num_classes=1000,
                 layers=[(False, 16, 16),
                         (True, 16, 16),
                         (False, 16, 16),
                         (True, 16, 8),
                         (False, 8, 8),
                         (True, 8, 4),
                         (False, 4, 4),
                         (True, 4, 2),
                         (False, 2, 2),
                         (True, 2, 1)],
                 attention_layer_position=8,
                 eps=1e-4,
                 n_stats=51):
        """Constructs BigGANConfig. """
        self.output_dim = output_dim
        self.z_dim = z_dim
        self.class_embed_dim = class_embed_dim
        self.channel_width = channel_width
        self.num_classes = num_classes
        self.layers = layers
        self.attention_layer_position = attention_layer_position
        self.eps = eps
        self.n_stats = n_stats

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BigGANConfig` from a Python dictionary of parameters."""
        config = BigGANConfig()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BigGANConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
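
# Illustrative (not in the original gist): from_dict starts from the 128x128
# defaults and overrides only the keys you pass, so configs round-trip through
# plain dicts:
#   >>> cfg = BigGANConfig.from_dict({'output_dim': 256})
#   >>> json.loads(cfg.to_json_string())['output_dim']
#   256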
def snconv2d(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)

def snlinear(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)

def sn_embedding(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)
class SelfAttn(nn.Module):
    """ Self attention Layer"""
    def __init__(self, in_channels, eps=1e-12):
        super(SelfAttn, self).__init__()
        self.in_channels = in_channels
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
                                        kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
                                      kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2,
                                    kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels,
                                         kernel_size=1, bias=False, eps=eps)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        _, ch, h, w = x.size()
        # Theta path
        theta = self.snconv1x1_theta(x)
        theta = theta.view(-1, ch//8, h*w)
        # Phi path
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch//8, h*w//4)
        # Attn map
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)
        # g path
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch//2, h*w//4)
        # Attn_g - o_conv
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch//2, h, w)
        attn_g = self.snconv1x1_o_conv(attn_g)
        # Out
        out = x + self.gamma*attn_g
        return out
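
# Illustrative shape check (not in the original gist): SelfAttn is a residual
# block, so it preserves the input shape exactly, e.g. for a 64-channel map:
#   >>> attn = SelfAttn(64)
#   >>> attn(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 64, 32, 32])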
class BigGANBatchNorm(nn.Module):
    """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
        activation means and variances for various truncation parameters.

        We cannot just rely on torch.batch_norm since it cannot handle
        batched weights (pytorch 1.0.1). We compute batch norm ourselves without updating running means and variances.
        If you want to train this model you should add running means and variances computation logic.
    """
    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
        super(BigGANBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # We use pre-computed statistics for n_stats values of truncation between 0 and 1
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1.0 / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
        else:
            self.weight = torch.nn.Parameter(torch.Tensor(num_features))
            self.bias = torch.nn.Parameter(torch.Tensor(num_features))

    def forward(self, x, truncation, condition_vector=None):
        # Retrieve pre-computed statistics associated with this truncation
        coef, start_idx = math.modf(truncation / self.step_size)
        start_idx = int(start_idx)
        if coef != 0.0:  # Interpolate
            running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef)
            running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef)
        else:
            running_mean = self.running_means[start_idx]
            running_var = self.running_vars[start_idx]

        if self.conditional:
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)

            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0.0, eps=self.eps)

        return out
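
# Worked example (illustrative, not in the original gist): with n_stats=51 the
# pre-computed statistics sit on a grid with step 1/50 = 0.02. For truncation=0.43,
# math.modf(0.43 / 0.02) gives approximately (0.5, 21.0), so the stats at grid
# indices 21 and 22 are blended 50/50 by the interpolation branch above.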
class GenBlock(nn.Module):
    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
                 n_stats=51, eps=1e-12):
        super(GenBlock, self).__init__()
        self.up_sample = up_sample
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor

        self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)

        self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)

        self.relu = nn.ReLU()

    def forward(self, x, cond_vector, truncation):
        x0 = x

        x = self.bn_0(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_0(x)

        x = self.bn_1(x, truncation, cond_vector)
        x = self.relu(x)
        if self.up_sample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv_1(x)

        x = self.bn_2(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_2(x)

        x = self.bn_3(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_3(x)

        if self.drop_channels:
            new_channels = x0.shape[1] // 2
            x0 = x0[:, :new_channels, ...]
        if self.up_sample:
            x0 = F.interpolate(x0, scale_factor=2, mode='nearest')

        out = x + x0
        return out
class Generator(nn.Module):
    def __init__(self, config):
        super(Generator, self).__init__()
        self.config = config
        ch = config.channel_width
        condition_vector_dim = config.z_dim * 2

        self.gen_z = snlinear(in_features=condition_vector_dim,
                              out_features=4 * 4 * 16 * ch, eps=config.eps)

        layers = []
        for i, layer in enumerate(config.layers):
            if i == config.attention_layer_position:
                layers.append(SelfAttn(ch*layer[1], eps=config.eps))
            layers.append(GenBlock(ch*layer[1],
                                   ch*layer[2],
                                   condition_vector_dim,
                                   up_sample=layer[0],
                                   n_stats=config.n_stats,
                                   eps=config.eps))
        self.layers = nn.ModuleList(layers)

        self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps)
        self.tanh = nn.Tanh()

    def forward(self, cond_vector, truncation):
        z = self.gen_z(cond_vector[0].unsqueeze(0))

        # We use this conversion step to be able to use TF weights:
        # TF convention on shape is [batch, height, width, channels]
        # PT convention on shape is [batch, channels, height, width]
        z = z.view(-1, 4, 4, 16 * self.config.channel_width)
        z = z.permute(0, 3, 1, 2).contiguous()

        for i, layer in enumerate(self.layers):
            if isinstance(layer, GenBlock):
                # Amended: each GenBlock reads its own row of cond_vector, so
                # cond_vector must carry one row per module in self.layers plus
                # one row for the initial gen_z projection above.
                z = layer(z, cond_vector[i+1].unsqueeze(0), truncation)
            else:
                z = layer(z)

        z = self.bn(z, truncation)
        z = self.relu(z)
        z = self.conv_to_rgb(z)
        z = z[:, :3, ...]
        z = self.tanh(z)
        return z
class BigGAN(nn.Module):
    """BigGAN Generator."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error("Wrong model name, should be a valid path to a folder containing "
                         "a {} file and a {} file or a model name in {}".format(
                             WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
            raise

        logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))

        # Load config
        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None)
        model.load_state_dict(state_dict, strict=False)
        return model

    def __init__(self, config):
        super(BigGAN, self).__init__()
        self.config = config
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)

    def forward(self, z, class_label, truncation):
        assert 0 < truncation <= 1
        embed = self.embeddings(class_label)
        cond_vector = torch.cat((z, embed), dim=1)
        z = self.generator(cond_vector, truncation)
        return z
def one_hot_from_int(int_or_list, batch_size=1):
    """ Create a one-hot vector from a class index or a list of class indices.
        Params:
            int_or_list: int, or list of int, of the imagenet classes (between 0 and 999)
            batch_size: batch size.
                If int_or_list is an int create a batch of identical classes.
                If int_or_list is a list, we should have `len(int_or_list) == batch_size`
        Output:
            array of shape (batch_size, 1000)
    """
    if isinstance(int_or_list, int):
        int_or_list = [int_or_list]

    if len(int_or_list) == 1 and batch_size > 1:
        int_or_list = [int_or_list[0]] * batch_size

    assert batch_size == len(int_or_list)

    array = np.zeros((batch_size, 1000), dtype=np.float32)
    for i, j in enumerate(int_or_list):
        array[i, j] = 1.0
    return array
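
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). It assumes the S3
# archive URLs above still resolve and that scipy is available for sampling
# truncated noise; the row layout of the condition vector follows the amended
# Generator.forward indexing above (one row per generator module plus one for
# the initial projection).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from scipy.stats import truncnorm

    def truncated_noise_sample(batch_size, dim_z, truncation=1.0, seed=None):
        """Sample z from a standard normal truncated to [-2, 2], scaled by `truncation`."""
        state = None if seed is None else np.random.RandomState(seed)
        values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)
        return truncation * values.astype(np.float32)

    truncation = 0.4
    model = BigGAN.from_pretrained('biggan-deep-128')
    model.eval()

    # One condition row per module in model.generator.layers, plus one extra
    # row consumed by the initial gen_z projection (see Generator.forward).
    n_rows = len(model.generator.layers) + 1
    noise = torch.from_numpy(truncated_noise_sample(n_rows, model.config.z_dim, truncation, seed=0))
    label = torch.from_numpy(one_hot_from_int(207, batch_size=n_rows))  # 207 = golden retriever

    with torch.no_grad():
        image = model(noise, label, truncation)  # (1, 3, 128, 128), values in [-1, 1]
    print(image.shape)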