@twobob
Created July 2, 2022 18:27
This is directly borrowed from Yannick Kilcher's work, amended to run on Colab and to accept some alternate engine sizes.
from utils import train, Pars, create_image, create_outputfolder, init_textfile
from dall_e import map_pixels, unmap_pixels, load_model
from stylegan import g_synthesis
from biggan import BigGAN
from tqdm import tqdm
import create_video
import tempfile
import argparse
import torch
import clip
import glob
import os
import math
# argparse for command-line options
parser = argparse.ArgumentParser(description='BigGan_Clip')
parser.add_argument('--epochs',
                    default=100,
                    type=int,
                    help='Number of epochs to optimize each line for')
parser.add_argument('--generator',
                    default='biggan',
                    type=str,
                    choices=['biggan', 'biggan128', 'biggan256', 'dall-e', 'stylegan'],
                    help='Which generator to use: BigGAN (512/256/128), DALL-E or StyleGAN')
parser.add_argument('--textfile',
                    type=str,
                    required=True,
                    help='Path to the text file of timestamped lines')
parser.add_argument('--audiofile',
                    type=str,
                    required=True,  # required, so the dead `default=None` is dropped
                    help='Path to the mp3 file')
parser.add_argument('--lyrics',
                    default=True,
                    # argparse's type=bool is a trap: bool('False') is True,
                    # so parse the string explicitly instead
                    type=lambda s: str(s).lower() not in ('false', '0', 'no'),
                    help='Overlay the lyrics on the frames')
parser.add_argument('--interpolation',
                    default=10,
                    type=int,
                    help='Number of frames to interpolate per second and feed to the model')
args = parser.parse_args()
epochs = args.epochs
generator = args.generator
textfile = args.textfile
audiofile = args.audiofile
interpol = args.interpolation
lyrics = args.lyrics
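
# Example invocation (a sketch; the script name and file paths here are
# assumptions, not part of the gist):
#   python this_script.py --textfile lyrics.txt --audiofile song.mp3 \
#       --generator biggan256 --epochs 100 --interpolation 10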
sideX = 512
sideY = 512

def main():
    # Automatically creates the 'output' folder
    create_outputfolder()
    # Initialize CLIP
    perceptor, preprocess = clip.load('ViT-B/32')  # 'ViT-L/14' and 'ViT-B/16' also work
    perceptor = perceptor.eval()
    # Load the chosen generator
    if generator == 'biggan':
        model = BigGAN.from_pretrained('biggan-deep-512')
        model = model.cuda().eval()
    elif generator == 'biggan128':
        model = BigGAN.from_pretrained('biggan-deep-128')
        model = model.cuda().eval()
    elif generator == 'biggan256':
        model = BigGAN.from_pretrained('biggan-deep-256')
        model = model.cuda().eval()
    elif generator == 'dall-e':
        model = load_model("decoder.pkl", 'cuda')
    elif generator == 'stylegan':
        model = g_synthesis.eval().cuda()
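    # NOTE (assumption): "decoder.pkl" must already be on disk. The openai
    # dall_e package's load_model can also fetch the released weights by URL,
    # e.g. model = load_model("https://cdn.openai.com/dall-e/decoder.pkl", 'cuda')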
    # Read the text file
    # descs - list of (timestamp, description) tuples
    descs = init_textfile(textfile)
    # list of temporary pt files, one per line
    templist = []
    # Loop over the description list
    for d in tqdm(descs):
        timestamp = d[0]
        line = d[1]
        # stamps_descs_list.append((timestamp, line))
        # Init the generator's latents
        lats = Pars(gen=generator).cuda()
        if generator == 'biggan' or generator == 'biggan128' or generator == 'biggan256':
            par = lats.parameters()
            lr = 0.1  # .07
        elif generator == 'stylegan':
            par = [lats.normu]
            lr = .02
        elif generator == 'dall-e':
            par = [lats.normu]
            lr = .1
        # Init optimizer
        optimizer = torch.optim.Adam(par, lr)
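        # Only the latents are handed to Adam: CLIP and the generator stay
        # frozen, and each lyric line gets a fresh latent that is pushed
        # towards that line's CLIP text embedding.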
        # tokenize the current line with CLIP and encode the text
        txt = clip.tokenize(line)
        percep = perceptor.encode_text(txt.cuda()).detach().clone()
        # Training loop
        for i in range(epochs):
            zs = train(i, model, lats, sideX, sideY, perceptor, percep, optimizer, line, txt, epochs=epochs, gen=generator)
        # save this line's final latent to a temporary torch file
        latent_temp = tempfile.NamedTemporaryFile()
        torch.save(zs, latent_temp)  # f'./output/pt_folder/{line}.pt'
        latent_temp.seek(0)
        # append it to templist so it can be accessed later; keeping the handle
        # referenced also stops the temporary file from being deleted early
        templist.append(latent_temp)
    return templist, descs, model

def sigmoid(x):
    x = x * 2. - 1.
    return math.tanh(1.5 * x / (math.sqrt(1. - math.pow(x, 2.)) + 1e-6)) / 2 + .5
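
# The curve above is an ease-in-out: it maps [0, 1] onto [0, 1], pinned at the
# endpoints (sigmoid(0.0) ~ 0.0, sigmoid(0.5) = 0.5, sigmoid(1.0) ~ 1.0) and
# slightly steeper than linear in the middle, so each crossfade lingers on its
# keyframes before committing to the next latent. A quick way to see the shape
# (a sketch, not part of the pipeline):
#   >>> [round(sigmoid(t / 10), 3) for t in range(11)]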

def interpolate(templist, descs, model, audiofile):
    video_temp_list = []
    # interpolate elements between each pair of consecutive images
    for idx1, pt in enumerate(descs):
        # get the next index of the descs list;
        # if z1_idx is out of range, break the loop
        z1_idx = idx1 + 1
        if z1_idx >= len(descs):
            break
        current_lyric = pt[1]
        # get the interval between the 2 lines/elements in seconds, `ttime`
        d1 = pt[0]
        d2 = descs[z1_idx][0]
        ttime = d2 - d1
        # if it is the very first index, load the first pt temp file;
        # if not, reuse the previous pair's end latent (z1) as the start
        if idx1 == 0:
            zs = torch.load(templist[idx1])
        else:
            zs = z1
        # compute the number of frames to insert between the 2 elements
        N = round(ttime * interpol)
        print(z1_idx)  # progress: which pair is being rendered
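        # For example, with the default --interpolation of 10, two lines that
        # are 2.5 seconds apart get N = round(2.5 * 10) = 25 in-between frames.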
        # BigGAN's latent is a list (noise vector + class vector); wrap the
        # single-tensor generators in a list so both cases interpolate alike
        if not isinstance(zs, list):
            z0 = [zs]
            z1 = [torch.load(templist[z1_idx])]
        else:
            z0 = zs
            z1 = torch.load(templist[z1_idx])
        # loop over the range of frames and generate the images
        image_temp_list = []
        for t in range(N):
            azs = []
            for r in zip(z0, z1):
                z_diff = r[1] - r[0]
                # ease between the two latents (max() guards against N == 1)
                inter_zs = r[0] + sigmoid(t / max(N - 1, 1)) * z_diff
                azs.append(inter_zs)
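            # At t = 0 the eased weight is ~0 (the frame sits on the previous
            # latent); at t = N - 1 it is ~1 (the next latent), so every clip
            # lands exactly on its keyframes.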
            # Generate the frame
            with torch.no_grad():
                if generator == 'biggan' or generator == 'biggan128' or generator == 'biggan256':
                    # pytorch-pretrained-biggan forward is (noise, class vector, truncation)
                    img = model(azs[0], azs[1], 1).cpu().numpy()
                    img = img[0]
                elif generator == 'dall-e':
                    img = unmap_pixels(torch.sigmoid(model(azs[0])[:, :3]).cpu().float()).numpy()
                    img = img[0]
                elif generator == 'stylegan':
                    img = model(azs[0])
            image_temp = create_image(img, t, current_lyric, generator)
            image_temp_list.append(image_temp)
        # stitch this line's frames into a clip whose total length matches ttime
        video_temp = create_video.createvid(f'{current_lyric}', image_temp_list, duration=ttime / N)
        video_temp_list.append(video_temp)
    # Finally, concatenate the clips with the audio track and save to the output folder
    create_video.concatvids(descs, video_temp_list, audiofile, lyrics=lyrics)

if __name__ == '__main__':
    templist, descs, model = main()
    interpolate(templist, descs, model, audiofile)