# Install bitsandbytes:
# `nvcc --version` to get CUDA version.
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
# Example Usage:
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse
import socket
import torch
import torchvision
import transformers
import diffusers
import os
import glob
import random
import tqdm
import resource
import psutil
import pynvml
import wandb
import gc
import time
import itertools
import numpy as np
import json
import re
import traceback

# Distributed only
import hivemind
import requests
import zipfile
import shutil
from hivemind.optim import power_sgd_averager

try:
    pynvml.nvmlInit()
except pynvml.nvml.NVMLError_LibraryNotFound:
    pynvml = None

from typing import Iterable
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from PIL import Image, ImageOps
from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d

torch.backends.cuda.matmul.allow_tf32 = True
# defaults should be good for everyone
# TODO: add custom VAE support. should be simple with diffusers
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be a HuggingFace ID or a directory.')
parser.add_argument('--resume', type=str, default=None, help='The path to the checkpoint to resume from. If not specified, will create a new run.')
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
#parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--use_ema', type=str, default='False', help='Use EMA for finetuning')
parser.add_argument('--ucg', type=float, default=0.1, help='Chance of dropping out the text condition per batch. Ranges from 0.0 to 1.0, where 1.0 means the text condition is always dropped.') # 10% dropout probability
parser.add_argument('--gradient_checkpointing', dest='gradient_checkpointing', type=str, default='False', help='Enable gradient checkpointing')
parser.add_argument('--use_8bit_adam', dest='use_8bit_adam', type=str, default='False', help='Use 8-bit Adam optimizer')
parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')
parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
parser.add_argument('--lr_scheduler', type=str, default='cosine', help='Learning rate scheduler [`cosine`, `linear`, `constant`]')
parser.add_argument('--lr_scheduler_warmup', type=float, default=0.05, help='Learning rate scheduler warmup, as a fraction of the total number of steps in the training run. 0.1 means 10 percent of the total number of steps.')
parser.add_argument('--seed', type=int, default=42, help='Seed for the random number generator, for reproducibility purposes.')
parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
parser.add_argument('--save_steps', type=int, default=500, help='Number of steps between checkpoint saves.')
parser.add_argument('--resolution', type=int, default=512, help='Image resolution to train against. Lower res images will be scaled up to this resolution and higher res images will be scaled down.')
parser.add_argument('--shuffle', dest='shuffle', type=str, default='True', help='Shuffle dataset')
parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
parser.add_argument('--fp16', dest='fp16', type=str, default='False', help='Train in mixed precision')
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps between image logs.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use when logging images.')
parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Scheduler to use when logging images [`PNDMScheduler`, `DDIMScheduler`]')
parser.add_argument('--clip_penultimate', type=str, default='False', help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=str, default='False', help='Outputs bucket information and exits')
parser.add_argument('--resize', type=str, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=str, default='False', help='Use memory efficient attention')
# Modified
parser.add_argument('--wandb', dest='enablewandb', type=str, default='False', help='Enable Weights & Biases reporting')
parser.add_argument('--inference', dest='enableinference', type=str, default='False', help='Enable inference during training (consumes 2GB of VRAM)')
# Hivemind only
#parser.add_argument('--hivemind', dest='enablehivemind', type=str, default='True', help='Enable Hivemind usage')
parser.add_argument('--peers', type=str, default=None, nargs="*", help='MUST BE PASSED AS A LIST! e.g.: --peers /ipv4/1.1.1.1 /ipv4/2.2.2.2 | Multiaddrs of one or more active DHT peers. If none, a new session will be started.')
# Dataset server
parser.add_argument('--datasetserver', type=str, dest='datasetserver', default=None, help='Address of the dataset server')
parser.add_argument('--wantedimages', type=int, dest='wantedimages', default=None, help='Number of images to request from the dataset server per download')
parser.add_argument('--workingdirectory', type=str, dest='workingdirectory', default="distributed_data", help='Folder where the downloader is going to do its work')
args = parser.parse_args()
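# argparse returns the boolean-style flags above as strings; coerce 'True'/'False' (case-insensitive) into real booleans.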
for arg in vars(args):
    if type(getattr(args, arg)) == str:
        if getattr(args, arg).lower() == 'true':
            setattr(args, arg, True)
        elif getattr(args, arg).lower() == 'false':
            setattr(args, arg, False)
def setup():
    torch.distributed.init_process_group("nccl", init_method="env://")

def cleanup():
    torch.distributed.destroy_process_group()

def get_rank() -> int:
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()

def get_world_size() -> int:
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()

def get_gpu_ram() -> str:
    """
    Returns memory usage statistics for the CPU, GPU, and Torch.
    :return:
    """
    gpu_str = ""
    torch_str = ""
    try:
        cudadev = torch.cuda.current_device()
        nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev)
        gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device)
        gpu_total = int(gpu_info.total / 1E6)
        gpu_free = int(gpu_info.free / 1E6)
        gpu_used = int(gpu_info.used / 1E6)
        gpu_str = f"GPU: (U: {gpu_used:,}mb F: {gpu_free:,}mb " \
                  f"T: {gpu_total:,}mb) "
        torch_reserved_gpu = int(torch.cuda.memory.memory_reserved() / 1E6)
        torch_reserved_max = int(torch.cuda.memory.max_memory_reserved() / 1E6)
        torch_used_gpu = int(torch.cuda.memory_allocated() / 1E6)
        torch_max_used_gpu = int(torch.cuda.max_memory_allocated() / 1E6)
        torch_str = f"TORCH: (R: {torch_reserved_gpu:,}mb/" \
                    f"{torch_reserved_max:,}mb, " \
                    f"A: {torch_used_gpu:,}mb/{torch_max_used_gpu:,}mb)"
    except (AssertionError, AttributeError):
        # AttributeError covers the case where pynvml failed to initialize and was set to None above.
        pass
    cpu_maxrss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1E3 +
                     resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss / 1E3)
    cpu_vmem = psutil.virtual_memory()
    cpu_free = int(cpu_vmem.free / 1E6)
    return f"CPU: (maxrss: {cpu_maxrss:,}mb F: {cpu_free:,}mb) " \
           f"{gpu_str}" \
           f"{torch_str}"
datasetServer = args.datasetserver
wantedImages = args.wantedimages
workingDirectory = args.workingdirectory

if os.path.exists(workingDirectory + "/tmp"):
    print("Warning: the tmp folder will be cleared in 10 seconds")
    time.sleep(10)
    shutil.rmtree(workingDirectory + "/tmp")

if datasetServer is None:
    print("No dataset server chosen.")
    datasetServer = str(input("Dataset Server: "))
else:
    print("Dataset server is: " + datasetServer)

if wantedImages is None:
    wantedImages = int(input("How many images to download each time?: "))
print("Number of images to download each time: " + str(wantedImages))

print("Attempting to get server info...")
#ex.: datasetServer = 127.0.0.1
r = requests.get('http://' + str(datasetServer) + '/info')
if r.status_code == 200:
    data = json.loads(r.text)
    print("Server: " + data['ServerName'])
    print(data['ServerDescription'])
    print("Server Version: " + data['ServerVersion'])
    print("Currently serving " + str(data['FilesBeingServed']) + " files")
    print("Age: " + data['ExecutedAt'])
else:
    print("Unable to get server info")
    exit()

directoryToExtract = workingDirectory + "/tmp/dataset"
print("directoryToExtract: " + directoryToExtract)
print("Working directory: " + workingDirectory)
os.makedirs(workingDirectory, exist_ok=True)
import sys

def download_file(url, inputjson, output):
    link = url
    file_name = output
    with open(file_name, "wb") as f:
        print("Downloading %s" % file_name)
        response = requests.post(link, stream=True, json=inputjson)
        total_length = response.headers.get('content-length')
        if total_length is None: # no content length header
            f.write(response.content)
        else:
            dl = 0
            total_length = int(total_length)
            for data in response.iter_content(chunk_size=4096):
                dl += len(data)
                f.write(data)
                done = int(50 * dl / total_length)
                sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
                sys.stdout.flush()
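# Dataset-server flow: request a batch of image "tasks" from the server, download them as a zip and
# extract them for training, then report the completed pass back using the receipt the server returned.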
def onlineGather(datasetServer, wantedImages, directoryToExtract):
    #ex.: datasetServer = "127.0.0.1" assuming port is 80
    print("Dataset server is: " + str(datasetServer))
    #Info on how this works should be on a md file soon
    urlDomain = 'http://' + datasetServer
    urlGetTasks = urlDomain + '/v1/get/tasks/' + str(wantedImages)
    requestGetTasks = requests.get(urlGetTasks)
    responseAsJson = requestGetTasks.json()
    print("Downloading Files...")
    downloadUrl = urlDomain + "/v1/get/files"
    #TODO: fix memory file
    #print("Saving as BytesIO")
    #memory_file = BytesIO()
    tmpZipFilename = workingDirectory + "/tmp.zip"
    download_file(downloadUrl, responseAsJson, tmpZipFilename)
    #memory_file.seek(0)
    print("Unzipping...")
    with zipfile.ZipFile(tmpZipFilename, 'r') as zip_ref:
        print("Extracting to: " + directoryToExtract)
        zip_ref.extractall(directoryToExtract)
        print("Extracted")
    os.remove(tmpZipFilename)
    responseRecipt = responseAsJson
    return(responseRecipt)

def onlineReport(datasetServer, recipt):
    print("Reporting epoch completion...")
    urlDomain = 'http://' + datasetServer
    urlReport = urlDomain + '/v1/post/epochcount'
    postReportEpoch = requests.post(urlReport, json=recipt)
    if postReportEpoch.status_code == 200:
        return True
    else:
        return False
def _sort_by_ratio(bucket: tuple) -> float:
    return bucket[0] / bucket[1]

def _sort_by_area(bucket: tuple) -> float:
    return bucket[0] * bucket[1]

class ImageStore:
    def __init__(self, data_dir: str) -> None:
        self.data_dir = data_dir
        self.image_files = []
        for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']:
            self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e))
        self.image_files = [x for x in self.image_files if self.__valid_file(x)]

    def __len__(self) -> int:
        return len(self.image_files)

    def __valid_file(self, f) -> bool:
        try:
            Image.open(f)
            return True
        except Exception:
            print(f'WARNING: Unable to open file: {f}')
            return False

    # iterator returns images as PIL images and their index in the store
    def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
        for f in range(len(self)):
            yield Image.open(self.image_files[f]).convert(mode='RGB'), f

    # get image by index
    def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
        return Image.open(self.image_files[ref[0]]).convert(mode='RGB')

    # gets caption by removing the extension from the filename and replacing it with .txt
    def get_caption(self, ref: Tuple[int, int, int]) -> str:
        filename = re.sub(r'\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
        with open(filename, 'r', encoding='UTF-8') as f:
            return f.read()

# ====================================== #
# Bucketing code stolen from hasuwoof:   #
# https://github.com/hasuwoof/huskystack #
# ====================================== #
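# AspectBucket groups the images in an ImageStore into a fixed number of (width, height) buckets of
# similar aspect ratio, so that every batch drawn from a bucket shares the same training resolution.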
class AspectBucket:
    def __init__(self, store: ImageStore,
                 num_buckets: int,
                 batch_size: int,
                 bucket_side_min: int = 256,
                 bucket_side_max: int = 768,
                 bucket_side_increment: int = 64,
                 max_image_area: int = 512 * 768,
                 max_ratio: float = 2):
        self.requested_bucket_count = num_buckets
        self.bucket_length_min = bucket_side_min
        self.bucket_length_max = bucket_side_max
        self.bucket_increment = bucket_side_increment
        self.max_image_area = max_image_area
        self.batch_size = batch_size
        self.total_dropped = 0
        if max_ratio <= 0:
            self.max_ratio = float('inf')
        else:
            self.max_ratio = max_ratio
        self.store = store
        self.buckets = []
        self._bucket_ratios = []
        self._bucket_interp = None
        self.bucket_data: Dict[tuple, List[int]] = dict()
        self.init_buckets()
        self.fill_buckets()

    def init_buckets(self):
        possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
        possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
                                if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)
        buckets_by_ratio = {}
        # group the buckets by their aspect ratios
        for bucket in possible_buckets:
            w, h = bucket
            # use precision to avoid spooky floats messing up your day
            ratio = '{:.4e}'.format(w / h)
            if ratio not in buckets_by_ratio:
                group = set()
                buckets_by_ratio[ratio] = group
            else:
                group = buckets_by_ratio[ratio]
            group.add(bucket)
        # now we take the list of buckets we generated and pick the largest by area for each (the first sorted)
        # then we put all of those in a list, sorted by the aspect ratio
        # the square bucket (LxL) will be the first
        unique_ratio_buckets = sorted([sorted(buckets, key=_sort_by_area)[-1]
                                       for buckets in buckets_by_ratio.values()], key=_sort_by_ratio)
        # how many buckets to create for each side of the distribution
        bucket_count_each = int(np.clip((self.requested_bucket_count + 1) / 2, 1, len(unique_ratio_buckets)))
        # we know that the requested_bucket_count must be an odd number, so the indices we calculate
        # will include the square bucket and some linearly spaced buckets along the distribution
        indices = {*np.linspace(0, len(unique_ratio_buckets) - 1, bucket_count_each, dtype=int)}
        # make the buckets, make sure they are unique (to remove the duplicated square bucket), and sort them by ratio
        # here we add the portrait buckets by reversing the dimensions of the landscape buckets we generated above
        buckets = sorted({*(unique_ratio_buckets[i] for i in indices),
                          *(tuple(reversed(unique_ratio_buckets[i])) for i in indices)}, key=_sort_by_ratio)
        self.buckets = buckets
        # cache the bucket ratios and the interpolator that will be used for calculating the best bucket later
        # the interpolator makes a 1d piecewise interpolation where the input (x-axis) is the bucket ratio,
        # and the output is the bucket index in the self.buckets array
        # to find the best fit we can just round that number to get the index
        self._bucket_ratios = [w / h for w, h in buckets]
        self._bucket_interp = interp1d(self._bucket_ratios, list(range(len(buckets))), assume_sorted=True,
                                       fill_value=None)
        for b in buckets:
            self.bucket_data[b] = []

    def get_batch_count(self):
        return sum(len(b) // self.batch_size for b in self.bucket_data.values())

    def get_bucket_info(self):
        return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })
    def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int, int]], None, None]:
        """
        Generator that provides batches where the images in a batch fall on the same bucket
        Each element generated will be:
            (index, w, h)
        where index is the position of the image in the store
        :return:
        """
        max_bucket_len = max(len(b) for b in self.bucket_data.values())
        index_schedule = list(range(max_bucket_len))
        random.shuffle(index_schedule)
        bucket_len_table = {
            b: len(self.bucket_data[b]) for b in self.buckets
        }
        bucket_schedule = []
        for i, b in enumerate(self.buckets):
            bucket_schedule.extend([i] * (bucket_len_table[b] // self.batch_size))
        random.shuffle(bucket_schedule)
        bucket_pos = {
            b: 0 for b in self.buckets
        }
        total_generated_by_bucket = {
            b: 0 for b in self.buckets
        }
        for bucket_index in bucket_schedule:
            b = self.buckets[bucket_index]
            i = bucket_pos[b]
            bucket_len = bucket_len_table[b]
            batch = []
            while len(batch) != self.batch_size:
                # advance in the schedule until we find an index that is contained in the bucket
                k = index_schedule[i]
                if k < bucket_len:
                    entry = self.bucket_data[b][k]
                    batch.append(entry)
                i += 1
            total_generated_by_bucket[b] += self.batch_size
            bucket_pos[b] = i
            yield [(idx, *b) for idx in batch]

    def fill_buckets(self):
        entries = self.store.entries_iterator()
        total_dropped = 0
        for entry, index in tqdm.tqdm(entries, total=len(self.store)):
            if not self._process_entry(entry, index):
                total_dropped += 1
        for b, values in self.bucket_data.items():
            # shuffle the entries for extra randomness and to make sure dropped elements are also random
            random.shuffle(values)
            # make sure the buckets have an exact number of elements for the batch
            to_drop = len(values) % self.batch_size
            self.bucket_data[b] = list(values[:len(values) - to_drop])
            total_dropped += to_drop
        self.total_dropped = total_dropped

    def _process_entry(self, entry: Image.Image, index: int) -> bool:
        aspect = entry.width / entry.height
        if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
            return False
        best_bucket = self._bucket_interp(aspect)
        if best_bucket is None:
            return False
        bucket = self.buckets[round(float(best_bucket))]
        self.bucket_data[bucket].append(index)
        del entry
        return True
class AspectBucketSampler(torch.utils.data.Sampler):
    def __init__(self, bucket: AspectBucket, num_replicas: int = 1, rank: int = 0):
        super().__init__(None)
        self.bucket = bucket
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # subsample the bucket to only include the elements that are assigned to this rank
        indices = self.bucket.get_batch_iterator()
        indices = list(indices)[self.rank::self.num_replicas]
        return iter(indices)

    def __len__(self):
        return self.bucket.get_batch_count() // self.num_replicas

class AspectDataset(torch.utils.data.Dataset):
    def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
        self.store = store
        self.tokenizer = tokenizer
        self.ucg = ucg
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.RandomHorizontalFlip(p=0.5),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        return len(self.store)

    def __getitem__(self, item: Tuple[int, int, int]):
        return_dict = {'pixel_values': None, 'input_ids': None}
        image_file = self.store.get_image(item)
        if args.resize:
            image_file = ImageOps.fit(
                image_file,
                (item[1], item[2]),
                bleed=0.0,
                centering=(0.5, 0.5),
                method=Image.Resampling.LANCZOS
            )
        return_dict['pixel_values'] = self.transforms(image_file)
        if random.random() > self.ucg:
            caption_file = self.store.get_caption(item)
        else:
            caption_file = ''
        return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
        return return_dict

    def collate_fn(self, examples):
        pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
        # keep a contiguous float copy of the batch (`.to()` returns a new tensor, so assign the result)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = [example['input_ids'] for example in examples if example is not None]
        padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True)
        return {
            'pixel_values': pixel_values,
            'input_ids': padded_tokens.input_ids,
            'attention_mask': padded_tokens.attention_mask,
        }
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """
    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]
        self.decay = decay
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        value = (1 + optimization_step) / (10 + optimization_step)
        return 1 - min(self.decay, value)

    @torch.no_grad()
    def step(self, parameters):
        parameters = list(parameters)
        self.optimization_step += 1
        self.decay = self.get_decay(self.optimization_step)
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                tmp = self.decay * (s_param - param)
                s_param.sub_(tmp)
            else:
                s_param.copy_(param)
        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    # From CompVis LitEMA implementation
    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
        del self.collected_params
        gc.collect()

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]
def hivemindWorker(optimizer, peersArg=None):
    init_peers = peersArg
    if init_peers is not None:
        dht = hivemind.DHT(
            host_maddrs=["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
            initial_peers=init_peers,
            start=True
        )
        print("Type: Relay")
    else:
        dht = hivemind.DHT(
            host_maddrs=["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
            start=True
        )
        print("Type: New")
    print('\n'.join(str(addr) for addr in dht.get_visible_maddrs()))
    print("Global IP:", hivemind.utils.networking.choose_ip_address(dht.get_visible_maddrs()))
    hm_opt = hivemind.Optimizer(
        dht=dht,                    # use a DHT that is connected with other peers
        run_id='test_run',          # unique identifier of this collaborative run
        batch_size_per_step=1,      # each call to opt.step adds this many samples towards the next epoch
        target_batch_size=1000,     # after peers collectively process this many samples, average weights and begin the next epoch
        optimizer=optimizer,        # wrap the optimizer defined above
        use_local_updates=True,     # perform optimizer steps with local gradients, average parameters in background
        matchmaking_time=1500.0,    # when averaging parameters, gather peers in background for up to this many seconds
        averaging_timeout=1500.0,   # give up on averaging if not successful in this many seconds
        verbose=True,               # print logs incessantly
    )
    return(hm_opt)
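# Main training loop: each pass downloads a fresh shard of images from the dataset server, rebuilds the
# aspect-bucketed dataloader, trains the UNet on it (with parameters averaged through the hivemind
# optimizer), then reports the completed pass back to the server and clears the tmp folder.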
def main():
    rank = get_rank()
    world_size = get_world_size()
    torch.cuda.set_device(rank)
    enablewandb = args.enablewandb
    enableinference = args.enableinference
    if rank == 0:
        os.makedirs(args.output_path, exist_ok=True)
        if enablewandb:
            run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
        else:
            run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode="disabled")
        # Inform the user of host, and various versions -- useful for debugging issues.
        print("RUN_NAME:", args.run_name)
        print("HOST:", socket.gethostname())
        print("CUDA:", torch.version.cuda)
        print("TORCH:", torch.__version__)
        print("TRANSFORMERS:", transformers.__version__)
        print("DIFFUSERS:", diffusers.__version__)
        print("MODEL:", args.model)
        print("FP16:", args.fp16)
        print("RESOLUTION:", args.resolution)
    if args.hf_token is None:
        try:
            args.hf_token = os.environ['HF_API_TOKEN']
            print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
        except Exception:
            print("No HF token detected in the arguments or the environment variables, setting it to 'none' (as in the string)")
            args.hf_token = "none"
    device = torch.device('cuda')
    print("DEVICE:", device)
    # setup fp16 stuff
    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
    # Set seed
    torch.manual_seed(args.seed)
    print('RANDOM SEED:', args.seed)
    if args.resume:
        args.model = args.resume
    tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token)
    text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token)
    vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token)
    unet = UNet2DConditionModel.from_pretrained(args.model, subfolder='unet', use_auth_token=args.hf_token)
    # Move the models before initializing the optimizer
    weight_dtype = torch.float16 if args.fp16 else torch.float32
    # move models to device
    vae = vae.to(device, dtype=weight_dtype)
    unet = unet.to(device, dtype=torch.float32)
    text_encoder = text_encoder.to(device, dtype=weight_dtype)
    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
    if args.use_xformers:
        unet.set_use_memory_efficient_attention_xformers(True)
    if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
        # if str2optimizer8bit_blockwise check https://github.com/TimDettmers/bitsandbytes/issues/62
        try:
            import bitsandbytes as bnb
            optimizer_cls = bnb.optim.AdamW8bit
        except Exception:
            print('bitsandbytes not supported, using regular Adam optimizer')
            optimizer_cls = torch.optim.AdamW
    else:
        optimizer_cls = torch.optim.AdamW
    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.lr,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule='scaled_linear',
        num_train_timesteps=1000,
    )
    #TODO: put arguments
    def trainDataloader():
        # load dataset
        store = ImageStore(directoryToExtract)
        dataset = AspectDataset(store, tokenizer)
        bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
        sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
        print(f'STORE_LEN: {len(store)}')
        if args.output_bucket_info:
            print(bucket.get_bucket_info())
            exit(0)
        train_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=0,
            collate_fn=dataset.collate_fn
        )
        return train_dataloader

    #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)

    # create ema
    if args.use_ema:
        ema_unet = EMAModel(unet.parameters())

    print(get_gpu_ram())

    global_step = 0
    if args.resume:
        target_global_step = int(args.resume.split('_')[-1])
        print(f'resuming from {args.resume}...')

    # LR SCHEDULER MOVED TO BE SET IF HIVEMIND DISABLED
    # lr_scheduler = get_scheduler(
    #     args.lr_scheduler,
    #     optimizer=optimizer,
    #     num_warmup_steps=int(args.lr_scheduler_warmup * num_steps_per_epoch * args.epochs),
    #     num_training_steps=args.epochs * num_steps_per_epoch,
    #     #last_epoch=(global_step // num_steps_per_epoch) - 1,
    # )

    # probably unnecessary but ok
    def gt():
        return(time.time_ns())

    def save_checkpoint(global_step):
        if rank == 0:
            if args.use_ema:
                ema_unet.store(unet.parameters())
                ema_unet.copy_to(unet.parameters())
            pipeline = StableDiffusionPipeline(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
                scheduler=PNDMScheduler(
                    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
                ),
                safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
                feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
            )
            print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
            pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
            if args.use_ema:
                ema_unet.restore(unet.parameters())
        # barrier
        torch.distributed.barrier()

    # train!
    # forget about local training, use WD instead
    finalOptimizer = hivemindWorker(optimizer, args.peers)
    datasetRunCount = 0
    try:
        while True:
            recipt = onlineGather(datasetServer=datasetServer, wantedImages=wantedImages, directoryToExtract=directoryToExtract)
            # Reload Dataset
            print("Reloading Dataset...")
            train_dataloader = trainDataloader()
            num_steps_per_epoch = len(train_dataloader)
            progress_bar = tqdm.tqdm(range(num_steps_per_epoch), desc="Total Steps", leave=False)
            loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
            unet.train()
            for _, batch in enumerate(train_dataloader):
                if args.resume and global_step < target_global_step:
                    if rank == 0:
                        progress_bar.update(1)
                    global_step += 1
                    continue
                b_start = time.perf_counter()
                latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215
                # Sample noise
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()
                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
                if args.clip_penultimate:
                    encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
                else:
                    encoder_hidden_states = encoder_hidden_states.last_hidden_state
                # Predict the noise residual and compute loss
                with torch.autocast('cuda', enabled=args.fp16):
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                    loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                # Backprop and all reduce
                scaler.scale(loss).backward()
                # scaler.step() unscales the gradients and invokes finalOptimizer.step() internally,
                # so the optimizer must not be stepped a second time here.
                scaler.step(finalOptimizer)
                scaler.update()
                finalOptimizer.zero_grad()
                # Update EMA
                if args.use_ema:
                    ema_unet.step(unet.parameters())
                # perf
                b_end = time.perf_counter()
                seconds_per_step = b_end - b_start
                steps_per_second = 1 / seconds_per_step
                rank_images_per_second = args.batch_size * steps_per_second
                world_images_per_second = rank_images_per_second * world_size
                samples_seen = global_step * args.batch_size * world_size
                # All reduce loss
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                if rank == 0:
                    progress_bar.update(1)
                    global_step += 1
                    logs = {
                        "train/loss": loss.detach().item() / world_size,
                        "train/datasetRunCount": datasetRunCount,
                        "train/step": global_step,
                        "train/samples_seen": samples_seen,
                        "perf/rank_samples_per_second": rank_images_per_second,
                        "perf/global_samples_per_second": world_images_per_second,
                    }
                    progress_bar.set_postfix(logs)
                    run.log(logs, step=global_step)
                if global_step % args.save_steps == 0:
                    save_checkpoint(global_step)
                if enableinference:
                    if global_step % args.image_log_steps == 0:
                        if rank == 0:
                            # get prompt from random batch
                            prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
                            if args.image_log_scheduler == 'DDIMScheduler':
                                print('using DDIMScheduler scheduler')
                                scheduler = DDIMScheduler(
                                    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
                                )
                            else:
                                print('using PNDMScheduler scheduler')
                                scheduler = PNDMScheduler(
                                    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
                                )
                            pipeline = StableDiffusionPipeline(
                                text_encoder=text_encoder,
                                vae=vae,
                                unet=unet,
                                tokenizer=tokenizer,
                                scheduler=scheduler,
                                safety_checker=None,  # disable safety checker to save memory
                                feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
                            ).to(device)
                            # inference
                            if enablewandb:
                                images = []
                            else:
                                saveInferencePath = args.output_path + "/inference"
                                os.makedirs(saveInferencePath, exist_ok=True)
                            with torch.no_grad():
                                with torch.autocast('cuda', enabled=args.fp16):
                                    for _ in range(args.image_log_amount):
                                        if enablewandb:
                                            images.append(
                                                wandb.Image(pipeline(
                                                    prompt, num_inference_steps=args.image_log_inference_steps
                                                ).images[0],
                                                caption=prompt)
                                            )
                                        else:
                                            from datetime import datetime
                                            images = pipeline(prompt, num_inference_steps=args.image_log_inference_steps).images[0]
                                            filenameImg = str(time.time_ns()) + ".png"
                                            filenameTxt = str(time.time_ns()) + ".txt"
                                            images.save(saveInferencePath + "/" + filenameImg)
                                            with open(saveInferencePath + "/" + filenameTxt, 'a') as f:
                                                f.write('Used prompt: ' + prompt + '\n')
                                                f.write('Generated Image Filename: ' + filenameImg + '\n')
                                                f.write('Generated at: ' + str(global_step) + ' steps' + '\n')
                                                f.write('Generated at: ' + str(datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + '\n')
                            # log images under single caption
                            if enablewandb:
                                run.log({'images': images}, step=global_step)
                            # cleanup so we don't run out of memory
                            del pipeline
                            gc.collect()
            torch.distributed.barrier()
            print('Did one dataset run. Reporting...')
            reportStatus = onlineReport(datasetServer=datasetServer, recipt=recipt)
            if reportStatus is True:
                print("Report Success")
            else:
                print("Report failed, exiting...")
                exit()
            print("Cleaning folder...")
            shutil.rmtree(workingDirectory + "/tmp")
            # count completed dataset passes for the train/datasetRunCount log field
            datasetRunCount += 1
    except Exception as e:
        print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
        save_checkpoint(global_step)
    except KeyboardInterrupt:
        print("Quitting...")
        print("Saving checkpoint...")
        save_checkpoint(global_step)
        print("Checkpoint Saved.")

if __name__ == "__main__":
    setup()
    main()