Fine-tune GLIDE (small, filtered) from OpenAI. WIP.
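Assuming the script is saved as finetune_glide.py next to a ./glide-text2im checkout and a directory of image/caption pairs (both names are placeholders), one possible invocation using the flags defined in create_argparser below is:

python finetune_glide.py --data_dir /path/to/image_text_pairs --batch_size 1 --lr 1e-4 --save_interval 10000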
import argparse
import sys

sys.path.append("./glide-text2im")

import torch as th

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (create_model_and_diffusion,
                                          model_and_diffusion_defaults)
from guided_diffusion import dist_util, logger
from guided_diffusion.image_text_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import add_dict_to_argparser
from guided_diffusion.train_util import TrainLoop
def main():
    args = create_argparser().parse_args()
    dist_util.setup_dist()
    logger.configure()
    _device = dist_util.dev()

    # Create the base model and load the released GLIDE 'base' weights
    # as the starting point for fine-tuning.
    glide_options = model_and_diffusion_defaults()
    glide_model, diffusion = create_model_and_diffusion(**glide_options)
    glide_model.convert_to_fp16()
    glide_model.to(_device)
    glide_model.load_state_dict(load_checkpoint('base', _device))
    logger.log('total base parameters', sum(x.numel() for x in glide_model.parameters()))

    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    data = load_latent_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        model=glide_model,
        options=glide_options,
        device=_device,
    )
logger.log("training...") | |
TrainLoop( | |
model=glide_model, | |
diffusion=diffusion, | |
data=data, | |
batch_size=args.batch_size, | |
microbatch=args.microbatch, | |
lr=args.lr, | |
ema_rate=args.ema_rate, | |
log_interval=args.log_interval, | |
save_interval=args.save_interval, | |
resume_checkpoint=args.resume_checkpoint, | |
use_fp16=args.use_fp16, | |
fp16_scale_growth=args.fp16_scale_growth, | |
schedule_sampler=schedule_sampler, | |
weight_decay=args.weight_decay, | |
lr_anneal_steps=args.lr_anneal_steps, | |
).run_loop() | |
def load_latent_data(model, options, data_dir, batch_size, device):
    data = load_data(
        data_dir=data_dir,
        batch_size=batch_size,
        image_size=64,
        class_cond=False,
    )
    for batch, model_kwargs, text in data:
        # Tokenize the first caption in the batch and pad it to the model's
        # text context length.
        tokens = model.tokenizer.encode(text[0])
        tokens, mask = model.tokenizer.padded_tokens_and_mask(
            tokens, options['text_ctx'])
        # Empty prompt used as the unconditional (classifier-free) input.
        uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
            [], options['text_ctx'])
        # Stack conditional tokens followed by unconditional tokens.
        # Token ids must stay integer-typed for the embedding lookup.
        tokens = th.tensor([tokens] * batch_size + [uncond_tokens] * batch_size,
                           device=device, dtype=th.long)
        mask = th.tensor([mask] * batch_size + [uncond_mask] * batch_size,
                         dtype=th.bool, device=device)
        # model_kwargs["xf_proj"] = tokens
        # model_kwargs["xf_out"] = uncond_tokens
        model_kwargs["tokens"] = tokens
        model_kwargs["mask"] = mask
        batch = batch.to(device)
        yield batch, model_kwargs
def create_argparser():
    defaults = dict(
        data_dir="",
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=1,
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=10,
        save_interval=10000,
        resume_checkpoint="",
        use_fp16=True,
        fp16_scale_growth=1e-3,
    )
    defaults.update(model_and_diffusion_defaults())
    defaults['encoder_channels'] = 512
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser
if __name__ == "__main__":
    main()
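A quick way to inspect the text-conditioning tensors that load_latent_data feeds the model is to exercise the tokenizer on its own; this is a minimal sketch assuming the same ./glide-text2im checkout as above, with a placeholder caption:

# Sketch: check the padded token/mask pair built for each caption.
import sys
sys.path.append("./glide-text2im")

from glide_text2im.model_creation import (create_model_and_diffusion,
                                          model_and_diffusion_defaults)

options = model_and_diffusion_defaults()
model, _ = create_model_and_diffusion(**options)
tokens = model.tokenizer.encode("a placeholder caption")
tokens, mask = model.tokenizer.padded_tokens_and_mask(tokens, options["text_ctx"])
print(len(tokens), len(mask))  # both equal options["text_ctx"] (128 by default)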