HIVEMIND_DISTRIBUTED_TRAINING_12_9_2022_8_37PM
# Install bitsandbytes:
#   `nvcc --version` to get your CUDA version.
#   `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install bitsandbytes for that CUDA version.
# Example usage:
#   python trainer/diffusers_trainer.py --config configuration.yaml
# The command-line flags of the original diffusers trainer (--model, --run_name, --dataset, --batch_size,
# --use_8bit_adam, --gradient_checkpointing, --fp16, --resolution, --use_ema, ...) are now read from the YAML configuration file.
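# A minimal sketch of what configuration.yaml might look like, inferred from the conf.* keys
# accessed in this script (all values are illustrative assumptions, not recommendations; the
# full key set also includes conf.local.networking, conf.local.stats, conf.local.inference and conf.advanced):
#
# everyone:
#   server: "127.0.0.1:8080"             # coordination server, queried over HTTP below
#   project_name: "hivemind-run"
#   model: "CompVis/stable-diffusion-v1-4"
#   fp16: true
#   resolution: 512
#   seed: 42
#   train_text_encoder: false
#   use_ema: true
#   ucg: 0.1
#   lr: 5e-6
#   lr_scheduler: "constant"
#   extended_chunks: 0
#   clip_penultimate: false
#   imgcount: 500
# local:
#   working_path: "/tmp/hivemind"
#   output_path: "./output"
#   batch_size: 1
#   save_steps: 500
#   wandb: false
#   hf_token: "hf_..."
#   gradient_checkpointing: true
#   bit_adam: false
#   xformers: false
#   iden:
#     nickname: "peer-1"
#   image_store:
#     skip: false
#     extended: false
#     resize: true
#     no_migration: true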
import argparse | |
import socket | |
import sys | |
import zipfile | |
import torch | |
import torchvision | |
import transformers | |
import diffusers | |
import os | |
import glob | |
import random | |
import tqdm | |
import resource | |
import psutil | |
import pynvml | |
import wandb | |
import gc | |
import time | |
import itertools | |
import numpy as np | |
import json | |
import re | |
import traceback | |
import shutil | |
import requests | |
import hivemind | |
import ipaddress | |
from typing import Optional | |
from functools import reduce | |
try: | |
pynvml.nvmlInit() | |
except pynvml.nvml.NVMLError_LibraryNotFound: | |
pynvml = None | |
from typing import Iterable | |
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline | |
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | |
from diffusers.optimization import get_scheduler | |
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | |
from PIL import Image, ImageOps | |
from PIL.Image import Image as Img | |
from typing import Generator, Tuple | |
torch.backends.cuda.matmul.allow_tf32 = True | |
from omegaconf import OmegaConf | |
from hivemind import Float16Compression | |
from threading import Thread | |
import logging | |
parser = argparse.ArgumentParser(description="Hivemind Trainer") | |
parser.add_argument('-c', '--config', type=str, default="configuration.yaml", required=True, help="Path to the configuration YAML file") | |
#TODO: change this to integers
parser.add_argument('-l', '--loglevel', type=str, default="INFO", help="Log level for logging. See https://docs.python.org/3/library/logging.html")
args = parser.parse_args()
logging.basicConfig(level=args.loglevel)
conf = OmegaConf.load(args.config) | |
temporary_dataset = os.path.join(conf.local.working_path, "dataset") | |
cookies = { | |
'nickname': conf.local.iden.nickname | |
} | |
# def setup(): | |
# torch.distributed.init_process_group("nccl", init_method="env://") | |
# def cleanup(): | |
# torch.distributed.destroy_process_group() | |
# def get_rank() -> int: | |
# if not torch.distributed.is_initialized(): | |
# return 0 | |
# return torch.distributed.get_rank() | |
# def get_world_size() -> int: | |
# if not torch.distributed.is_initialized(): | |
# return 1 | |
# return torch.distributed.get_world_size() | |
def get_gpu_ram() -> str: | |
""" | |
Returns memory usage statistics for the CPU, GPU, and Torch. | |
:return: | |
""" | |
gpu_str = "" | |
torch_str = "" | |
try: | |
cudadev = torch.cuda.current_device() | |
nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev) | |
gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device) | |
gpu_total = int(gpu_info.total / 1E6) | |
gpu_free = int(gpu_info.free / 1E6) | |
gpu_used = int(gpu_info.used / 1E6) | |
gpu_str = f"GPU: (U: {gpu_used:,}mb F: {gpu_free:,}mb " \ | |
f"T: {gpu_total:,}mb) " | |
torch_reserved_gpu = int(torch.cuda.memory.memory_reserved() / 1E6) | |
torch_reserved_max = int(torch.cuda.memory.max_memory_reserved() / 1E6) | |
torch_used_gpu = int(torch.cuda.memory_allocated() / 1E6) | |
torch_max_used_gpu = int(torch.cuda.max_memory_allocated() / 1E6) | |
torch_str = f"TORCH: (R: {torch_reserved_gpu:,}mb/" \ | |
f"{torch_reserved_max:,}mb, " \ | |
f"A: {torch_used_gpu:,}mb/{torch_max_used_gpu:,}mb)" | |
except AssertionError: | |
pass | |
cpu_maxrss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1E3 + | |
resource.getrusage( | |
resource.RUSAGE_CHILDREN).ru_maxrss / 1E3) | |
cpu_vmem = psutil.virtual_memory() | |
cpu_free = int(cpu_vmem.free / 1E6) | |
return f"CPU: (maxrss: {cpu_maxrss:,}mb F: {cpu_free:,}mb) " \ | |
f"{gpu_str}" \ | |
f"{torch_str}" | |
class Validation(): | |
def __init__(self, is_skipped: bool, is_extended: bool) -> None: | |
if is_skipped: | |
self.validate = self.__no_op | |
return print("Validation: Skipped") | |
if is_extended: | |
self.validate = self.__extended_validate | |
return print("Validation: Extended") | |
self.validate = self.__validate | |
print("Validation: Standard") | |
def __validate(self, fp: str) -> bool: | |
try: | |
Image.open(fp) | |
return True | |
except: | |
print(f'WARNING: Image cannot be opened: {fp}') | |
return False | |
def __extended_validate(self, fp: str) -> bool: | |
try: | |
Image.open(fp).load() | |
return True | |
except (OSError) as error: | |
if 'truncated' in str(error): | |
print(f'WARNING: Image truncated: {error}') | |
return False | |
print(f'WARNING: Image cannot be opened: {error}') | |
return False | |
        except Exception as error:
            print(f'WARNING: Image cannot be opened: {error}')
            return False
def __no_op(self, fp: str) -> bool: | |
return True | |
class Resize(): | |
def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None: | |
if not is_resizing: | |
self.resize = self.__no_op | |
return | |
if not is_not_migrating: | |
self.resize = self.__migration | |
dataset_path = os.path.split(temporary_dataset) | |
self.__directory = os.path.join( | |
dataset_path[0], | |
f'{dataset_path[1]}_cropped' | |
) | |
os.makedirs(self.__directory, exist_ok=True) | |
return print(f"Resizing: Performing migration to '{self.__directory}'.") | |
self.resize = self.__no_migration | |
def __no_migration(self, image_path: str, w: int, h: int) -> Img: | |
return ImageOps.fit( | |
Image.open(image_path), | |
(w, h), | |
bleed=0.0, | |
centering=(0.5, 0.5), | |
method=Image.Resampling.LANCZOS | |
).convert(mode='RGB') | |
def __migration(self, image_path: str, w: int, h: int) -> Img: | |
        filename = re.sub(r'\.[^/.]+$', '', os.path.split(image_path)[1])
image = ImageOps.fit( | |
Image.open(image_path), | |
(w, h), | |
bleed=0.0, | |
centering=(0.5, 0.5), | |
method=Image.Resampling.LANCZOS | |
).convert(mode='RGB') | |
image.save( | |
os.path.join(f'{self.__directory}', f'{filename}.jpg'), | |
optimize=True | |
) | |
try: | |
shutil.copy( | |
os.path.join(temporary_dataset, f'{filename}.txt'), | |
os.path.join(self.__directory, f'{filename}.txt'), | |
follow_symlinks=False | |
) | |
except (FileNotFoundError): | |
f = open( | |
os.path.join(self.__directory, f'{filename}.txt'), | |
'w', | |
encoding='UTF-8' | |
) | |
f.close() | |
return image | |
def __no_op(self, image_path: str, w: int, h: int) -> Img: | |
return Image.open(image_path) | |
class ImageStore: | |
def __init__(self, data_dir: str) -> None: | |
self.data_dir = data_dir | |
self.image_files = [] | |
        for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']:
            self.image_files.extend(glob.glob(f'{data_dir}/*.{e}'))
self.validator = Validation( | |
conf.local.image_store.skip, | |
conf.local.image_store.extended | |
).validate | |
self.resizer = Resize(conf.local.image_store.resize, conf.local.image_store.no_migration).resize | |
self.image_files = [x for x in self.image_files if self.validator(x)] | |
def __len__(self) -> int: | |
return len(self.image_files) | |
# iterator returns images as PIL images and their index in the store | |
def __iter__(self) -> Generator[Tuple[Img, int], None, None]: | |
for i, f in enumerate(self.image_files): | |
yield Image.open(f), i | |
# get image by index | |
def get_image(self, ref: Tuple[int, int, int]) -> Img: | |
return self.resizer( | |
self.image_files[ref[0]], | |
ref[1], | |
ref[2] | |
) | |
# gets caption by removing the extension from the filename and replacing it with .txt | |
def get_caption(self, ref: Tuple[int, int, int]) -> str: | |
        filename = re.sub(r'\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
with open(filename, 'r', encoding='UTF-8') as f: | |
return f.read() | |
# For questions, contact <[email protected]>
# or via Discord <lopho#5445>
class SimpleBucket(torch.utils.data.Sampler): | |
""" | |
Batches samples into buckets of same size. | |
""" | |
def __init__(self, | |
store: ImageStore, | |
batch_size: int, | |
shuffle: bool = True, | |
num_replicas: int = 1, | |
rank: int = 0, | |
resize: bool = False, | |
image_side_divisor: int = 64, | |
max_image_area: int = 512 ** 2, | |
image_side_min: Optional[int] = None, | |
image_side_max: Optional[int] = None, | |
fixed_size: Optional[tuple[int, int]] = None | |
): | |
super().__init__(None) | |
self.batch_size = batch_size | |
self.shuffle = shuffle | |
self.store = store | |
self.buckets = dict() | |
self.ratios = [] | |
if resize: | |
m = image_side_divisor | |
assert (max_image_area // m) == max_image_area / m, "resolution not multiple of divisor" | |
if image_side_max is not None: | |
assert (image_side_max // m) == image_side_max / m, "side not multiple of divisor" | |
if image_side_min is not None: | |
assert (image_side_min // m) == image_side_min / m, "side not multiple of divisor" | |
if fixed_size is not None: | |
assert (fixed_size[0] // m) == fixed_size[0] / m, "side not multiple of divisor" | |
assert (fixed_size[1] // m) == fixed_size[1] / m, "side not multiple of divisor" | |
if image_side_min is None: | |
if image_side_max is None: | |
image_side_min = m | |
else: | |
image_side_min = max((max_image_area // image_side_max) * m, m) | |
if image_side_max is None: | |
image_side_max = max((max_image_area // image_side_min) * m, m) | |
self.fixed_size = fixed_size | |
self.image_side_min = image_side_min | |
self.image_side_max = image_side_max | |
self.image_side_divisor = image_side_divisor | |
self.max_image_area = max_image_area | |
self.dropped_samples = [] | |
self.init_buckets(resize) | |
self.num_replicas = num_replicas | |
self.rank = rank | |
def __iter__(self): | |
# generate batches | |
batches = [] | |
for b in self.buckets: | |
idxs = self.buckets[b] | |
if self.shuffle: | |
random.shuffle(idxs) | |
rest = len(idxs) % self.batch_size | |
idxs = idxs[rest:] | |
batched_idxs = [idxs[i:i + self.batch_size] for i in range(0, len(idxs), self.batch_size)] | |
for bidx in batched_idxs: | |
batches.append([(idx, b[0], b[1]) for idx in bidx]) | |
if self.shuffle: | |
random.shuffle(batches) | |
return iter(batches[self.rank::self.num_replicas]) | |
def __len__(self): | |
return self.get_batch_count() // self.num_replicas | |
def get_batch_count(self) -> int: | |
return reduce(lambda x, y: x + len(y) // self.batch_size, self.buckets.values(), 0) | |
def _fit_image_size(self, w, h): | |
if self.fixed_size is not None: | |
return self.fixed_size | |
max_area = self.max_image_area | |
scale = (max_area / (w * h)) ** 0.5 | |
m = self.image_side_divisor | |
w2 = round((w * scale) / m) * m | |
h2 = round((h * scale) / m) * m | |
if w2 * h2 > max_area: # top end can round over limits | |
w = int((w * scale) / m) * m | |
h = int((h * scale) / m) * m | |
else: | |
w = w2 | |
h = h2 | |
w = min(max(w, self.image_side_min), self.image_side_max) | |
h = min(max(h, self.image_side_min), self.image_side_max) | |
return w, h | |
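    # Worked example for _fit_image_size (assuming the defaults used below:
    # max_image_area = 512 ** 2 = 262144 and image_side_divisor = 64):
    #   a 1024x768 source gives scale = (262144 / 786432) ** 0.5 ≈ 0.577,
    #   so w2 = round(591.2 / 64) * 64 = 576 and h2 = round(443.4 / 64) * 64 = 448;
    #   576 * 448 = 258048 <= 262144, so the image lands in bucket (576, 448).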
def init_buckets(self, resize = False): | |
# create buckets | |
buckets = {} | |
for img, idx in tqdm.tqdm(self.store, desc='Bucketing', dynamic_ncols=True): | |
key = img.size | |
img.close() | |
if resize: | |
key = self._fit_image_size(*key) | |
buckets.setdefault(key, []).append(idx) | |
# fit buckets < batch_size in closest bucket if resizing is enabled | |
if resize: | |
for b in buckets: | |
if len(buckets[b]) < self.batch_size: | |
# find closest bucket | |
best_fit = float('inf') | |
best_bucket = None | |
for ob in buckets: | |
if ob == b or len(buckets[ob]) == 0: | |
continue | |
d = abs(ob[0] - b[0]) + abs(ob[1] - b[1]) | |
if d < best_fit: | |
best_fit = d | |
best_bucket = ob | |
if best_bucket is not None: | |
buckets[best_bucket].extend(buckets[b]) | |
buckets[b].clear() | |
# drop buckets < batch_size | |
for b in list(buckets.keys()): | |
if len(buckets[b]) < self.batch_size: | |
self.dropped_samples += buckets.pop(b) | |
else: | |
self.ratios.append(b[0] / b[1]) | |
self.buckets = buckets | |
def get_bucket_info(self): | |
return json.dumps({ | |
"buckets": list(self.buckets.keys()), | |
"ratios": self.ratios | |
}) | |
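    # get_bucket_info() returns a JSON string shaped like (values are illustrative):
    #   {"buckets": [[576, 448], [512, 512]], "ratios": [1.2857142857142858, 1.0]}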
class AspectDataset(torch.utils.data.Dataset): | |
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, device: torch.device, ucg: float = 0.1): | |
self.store = store | |
self.tokenizer = tokenizer | |
self.text_encoder = text_encoder | |
self.device = device | |
self.ucg = ucg | |
#if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel: | |
# self.text_encoder = self.text_encoder.module | |
self.transforms = torchvision.transforms.Compose([ | |
torchvision.transforms.RandomHorizontalFlip(p=0.5), | |
torchvision.transforms.ToTensor(), | |
torchvision.transforms.Normalize([0.5], [0.5]) | |
]) | |
def __len__(self): | |
return len(self.store) | |
def __getitem__(self, item: Tuple[int, int, int]): | |
return_dict = {'pixel_values': None, 'input_ids': None} | |
image_file = self.store.get_image(item) | |
return_dict['pixel_values'] = self.transforms(image_file) | |
if random.random() > self.ucg: | |
caption_file = self.store.get_caption(item) | |
else: | |
caption_file = '' | |
return_dict['input_ids'] = caption_file | |
return return_dict | |
def collate_fn(self, examples): | |
pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None]) | |
pixel_values.to(memory_format=torch.contiguous_format).float() | |
if conf.everyone.extended_chunks < 2: | |
max_length = self.tokenizer.model_max_length - 2 | |
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=max_length).input_ids for example in examples if example is not None] | |
else: | |
max_length = self.tokenizer.model_max_length | |
max_chunks = conf.everyone.extended_chunks | |
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None] | |
tokens = input_ids | |
if conf.everyone.extended_chunks < 2: | |
for i, x in enumerate(input_ids): | |
for j, y in enumerate(x): | |
input_ids[i][j] = [self.tokenizer.bos_token_id, *y, *np.full((self.tokenizer.model_max_length - len(y) - 1), self.tokenizer.eos_token_id)] | |
if conf.everyone.clip_penultimate: | |
input_ids = [self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True)['hidden_states'][-2])[0] for input_id in input_ids] | |
else: | |
input_ids = [self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True).last_hidden_state[0] for input_id in input_ids] | |
else: | |
max_standard_tokens = max_length - 2 | |
max_chunks = conf.everyone.extended_chunks | |
max_len = np.ceil(max(len(x) for x in input_ids) / max_standard_tokens).astype(int).item() * max_standard_tokens | |
if max_len > max_standard_tokens: | |
z = None | |
for i, x in enumerate(input_ids): | |
if len(x) < max_len: | |
input_ids[i] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)] | |
batch_t = torch.tensor(input_ids) | |
chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)] | |
for chunk in chunks: | |
chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1) | |
if z is None: | |
if conf.everyone.clip_penultimate: | |
z = self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2]) | |
else: | |
z = self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state | |
else: | |
if conf.everyone.clip_penultimate: | |
z = torch.cat((z, self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])), dim=-2) | |
else: | |
z = torch.cat((z, self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state), dim=-2) | |
input_ids = z | |
else: | |
for i, x in enumerate(input_ids): | |
input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.tokenizer.model_max_length - len(x) - 1), self.tokenizer.eos_token_id)] | |
if conf.everyone.clip_penultimate: | |
input_ids = self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True)['hidden_states'][-2]) | |
else: | |
input_ids = self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True).last_hidden_state | |
input_ids = torch.stack(tuple(input_ids)) | |
return { | |
'pixel_values': pixel_values, | |
'input_ids': input_ids, | |
'tokens': tokens | |
} | |
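# Note on collate_fn output: with extended_chunks < 2 each caption is padded/truncated to the
# tokenizer's model_max_length and encoded once, so 'input_ids' comes back as encoder states of
# shape [batch, model_max_length, hidden]; with extended_chunks >= 2 the caption is split into
# model_max_length-sized chunks that are encoded separately and concatenated along the sequence
# dimension (assuming the standard CLIP tokenizer and text encoder, roughly [batch, 77 * num_chunks, 768]).
# 'tokens' keeps the raw token ids so a prompt can be decoded later for image logging.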
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 | |
class EMAModel: | |
""" | |
Exponential Moving Average of models weights | |
""" | |
def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): | |
parameters = list(parameters) | |
self.shadow_params = [p.clone().detach() for p in parameters] | |
self.decay = decay | |
self.optimization_step = 0 | |
    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average,
        warming up towards the configured maximum decay.
        """
        value = (1 + optimization_step) / (10 + optimization_step)
        return min(self.decay, value)
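    # The warm-up term (1 + step) / (10 + step) is ≈ 0.18 at step 1 and only crosses the
    # default maximum decay of 0.9999 after roughly 90,000 steps, so the shadow weights
    # track the live weights closely early in training before the long-horizon average takes over.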
    @torch.no_grad()
    def step(self, parameters):
        parameters = list(parameters)
        self.optimization_step += 1
        decay = self.get_decay(self.optimization_step)
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                # shadow = decay * shadow + (1 - decay) * param
                s_param.sub_((1.0 - decay) * (s_param - param))
            else:
                s_param.copy_(param)
        torch.cuda.empty_cache()
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: | |
""" | |
Copy current averaged parameters into given collection of parameters. | |
Args: | |
parameters: Iterable of `torch.nn.Parameter`; the parameters to be | |
updated with the stored moving averages. If `None`, the | |
parameters with which this `ExponentialMovingAverage` was | |
initialized will be used. | |
""" | |
parameters = list(parameters) | |
for s_param, param in zip(self.shadow_params, parameters): | |
param.data.copy_(s_param.data) | |
# From CompVis LitEMA implementation | |
def store(self, parameters): | |
""" | |
Save the current parameters for restoring later. | |
Args: | |
parameters: Iterable of `torch.nn.Parameter`; the parameters to be | |
temporarily stored. | |
""" | |
self.collected_params = [param.clone() for param in parameters] | |
def restore(self, parameters): | |
""" | |
Restore the parameters stored with the `store` method. | |
Useful to validate the model with EMA parameters without affecting the | |
original optimization process. Store the parameters before the | |
`copy_to` method. After validation (or model saving), use this to | |
restore the former parameters. | |
Args: | |
parameters: Iterable of `torch.nn.Parameter`; the parameters to be | |
updated with the stored parameters. | |
""" | |
for c_param, param in zip(self.collected_params, parameters): | |
param.data.copy_(c_param.data) | |
del self.collected_params | |
gc.collect() | |
def to(self, device=None, dtype=None) -> None: | |
r"""Move internal buffers of the ExponentialMovingAverage to `device`. | |
Args: | |
device: like `device` argument to `torch.Tensor.to` | |
""" | |
# .to() on the tensors handles None correctly | |
self.shadow_params = [ | |
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) | |
for p in self.shadow_params | |
] | |
def backgroundreport(url, data): | |
requests.post(url, json=data) | |
def setuphivemind(): | |
if os.path.exists(conf.local.working_path): | |
shutil.rmtree(conf.local.working_path) | |
os.makedirs(conf.local.working_path) | |
if requests.get('http://' + conf.everyone.server + '/info').status_code == 200: | |
print("Connection Success") | |
serverconfig = json.loads(requests.get('http://' + conf.everyone.server + '/config').content) | |
print(serverconfig) | |
imgs_per_epoch = int(serverconfig["ImagesPerEpoch"]) | |
total_epochs = int(serverconfig["Epochs"]) | |
return(imgs_per_epoch, total_epochs) | |
else: | |
raise ConnectionError("Unable to connect to server") | |
def getchunk(server, amount): | |
if os.path.isdir(temporary_dataset): | |
shutil.rmtree(temporary_dataset) | |
os.mkdir(temporary_dataset) | |
serverdomain = 'http://' + server | |
rtasks_url = serverdomain + '/v1/get/tasks/' + str(amount) | |
rtasks = requests.get(rtasks_url).json() | |
print("Downloading Files") | |
pfiles = requests.post(serverdomain + '/v1/get/files', json=rtasks) | |
tmpZip = conf.local.working_path + '/tmp.zip' | |
    with open(tmpZip, 'wb') as f:
        f.write(pfiles.content)
    with zipfile.ZipFile(tmpZip, 'r') as zf:
        zf.extractall(temporary_dataset)
os.remove(tmpZip) | |
return(rtasks) | |
def report(server, tasks): | |
preport = requests.post('http://' + server + '/v1/post/epochcount', json=tasks) | |
if preport.status_code == 200: | |
return True | |
else: | |
return False | |
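# Summary of the coordination-server HTTP API as used by this script (exact payload formats
# depend on the server implementation and are assumptions beyond what is visible here):
#   GET  /info                   -> liveness check (expects HTTP 200)
#   GET  /config                 -> JSON with "ImagesPerEpoch" and "Epochs"
#   GET  /v1/get/peers           -> JSON list of hivemind multiaddrs used to bootstrap the DHT
#   GET  /v1/get/tasks/<amount>  -> JSON "receipt" describing a chunk of work
#   POST /v1/get/files           -> zip archive with the images/captions for that receipt
#   POST /v1/post/epochcount     -> report the finished receipt back to the server
#   POST /v1/post/stats          -> optional peer statistics (bandwidth, hardware specs)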
def dataloader(tokenizer, text_encoder, device, world_size, rank): | |
# load dataset | |
store = ImageStore(temporary_dataset) | |
dataset = AspectDataset(store, tokenizer, text_encoder, device, ucg=float(conf.everyone.ucg)) | |
sampler = SimpleBucket( | |
store = store, | |
batch_size = conf.local.batch_size, | |
shuffle = conf.advanced.buckets.shuffle, | |
resize = conf.local.image_store.resize, | |
image_side_min = conf.advanced.buckets.side_min, | |
image_side_max = conf.advanced.buckets.side_max, | |
image_side_divisor = 64, | |
max_image_area = conf.everyone.resolution ** 2, | |
num_replicas = world_size, | |
rank = rank | |
) | |
print(f'STORE_LEN: {len(store)}') | |
# if args.output_bucket_info: | |
# print(sampler.get_bucket_info()) | |
train_dataloader = torch.utils.data.DataLoader( | |
dataset, | |
batch_sampler=sampler, | |
num_workers=0, | |
collate_fn=dataset.collate_fn | |
) | |
# # Migrate dataset | |
# if args.resize and not args.no_migration: | |
# for _, batch in enumerate(train_dataloader): | |
# continue | |
# print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.") | |
# exit(0) | |
return train_dataloader | |
def main(): | |
rank = 0 | |
# world_size = get_world_size() | |
torch.cuda.set_device(rank) | |
if rank == 0: | |
os.makedirs(conf.local.output_path, exist_ok=True) | |
mode = 'disabled' | |
if conf.local.wandb: | |
mode = 'online' | |
if conf.local.hf_token is not None: | |
os.environ['HF_API_TOKEN'] = conf.local.hf_token | |
conf.local.hf_token = None | |
run = wandb.init(project=conf.everyone.project_name, name=conf.everyone.project_name, config=vars(args), dir=conf.local.output_path+'/wandb', mode=mode) | |
# Inform the user of host, and various versions -- useful for debugging issues. | |
print("RUN_NAME:", conf.everyone.project_name) | |
print("HOST:", socket.gethostname()) | |
print("CUDA:", torch.version.cuda) | |
print("TORCH:", torch.__version__) | |
print("TRANSFORMERS:", transformers.__version__) | |
print("DIFFUSERS:", diffusers.__version__) | |
print("MODEL:", conf.everyone.model) | |
print("FP16:", conf.everyone.fp16) | |
print("RESOLUTION:", conf.everyone.resolution) | |
if conf.local.hf_token is not None: | |
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.') | |
else: | |
try: | |
conf.local.hf_token = os.environ['HF_API_TOKEN'] | |
print("HF Token set via enviroment variable") | |
except Exception: | |
print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)") | |
conf.local.hf_token = "none" | |
device = torch.device('cuda') | |
print("DEVICE:", device) | |
# setup fp16 stuff | |
scaler = torch.cuda.amp.GradScaler(enabled=conf.everyone.fp16) | |
# Set seed | |
torch.manual_seed(conf.everyone.seed) | |
random.seed(conf.everyone.seed) | |
np.random.seed(conf.everyone.seed) | |
print('RANDOM SEED:', conf.everyone.seed) | |
tokenizer = CLIPTokenizer.from_pretrained(conf.everyone.model, subfolder='tokenizer', use_auth_token=conf.local.hf_token) | |
text_encoder = CLIPTextModel.from_pretrained(conf.everyone.model, subfolder='text_encoder', use_auth_token=conf.local.hf_token) | |
vae = AutoencoderKL.from_pretrained(conf.everyone.model, subfolder='vae', use_auth_token=conf.local.hf_token) | |
unet = UNet2DConditionModel.from_pretrained(conf.everyone.model, subfolder='unet', use_auth_token=conf.local.hf_token) | |
# Freeze vae and text_encoder | |
vae.requires_grad_(False) | |
if not conf.everyone.train_text_encoder: | |
text_encoder.requires_grad_(False) | |
if conf.local.gradient_checkpointing: | |
unet.enable_gradient_checkpointing() | |
if conf.everyone.train_text_encoder: | |
text_encoder.gradient_checkpointing_enable() | |
if conf.local.xformers: | |
unet.set_use_memory_efficient_attention_xformers(True) | |
# "The “safer” approach would be to move the model to the device first and create the optimizer afterwards." | |
weight_dtype = torch.float16 if conf.everyone.fp16 else torch.float32 | |
# move models to device | |
vae = vae.to(device, dtype=weight_dtype) | |
unet = unet.to(device, dtype=torch.float32) | |
text_encoder = text_encoder.to(device, dtype=weight_dtype if not conf.everyone.train_text_encoder else torch.float32) | |
# unet = torch.nn.parallel.DistributedDataParallel( | |
# unet, | |
# device_ids=[rank], | |
# output_device=rank, | |
# gradient_as_bucket_view=True | |
# ) | |
# if conf.everyone.train_text_encoder: | |
# text_encoder = torch.nn.parallel.DistributedDataParallel( | |
# text_encoder, | |
# device_ids=[rank], | |
# output_device=rank, | |
# gradient_as_bucket_view=True | |
# ) | |
if conf.local.bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails. | |
try: | |
import bitsandbytes as bnb | |
optimizer_cls = bnb.optim.AdamW8bit | |
except: | |
print('bitsandbytes not supported, using regular Adam optimizer') | |
optimizer_cls = torch.optim.AdamW | |
else: | |
optimizer_cls = torch.optim.AdamW | |
""" | |
optimizer = optimizer_cls( | |
unet.parameters(), | |
lr=args.lr, | |
betas=(args.adam_beta1, args.adam_beta2), | |
eps=args.adam_epsilon, | |
weight_decay=args.adam_weight_decay, | |
) | |
""" | |
optimizer_parameters = unet.parameters() if not conf.everyone.train_text_encoder else itertools.chain(unet.parameters(), text_encoder.parameters()) | |
# Create distributed optimizer | |
#from torch.distributed.optim import ZeroRedundancyOptimizer | |
#we changed to cls for single gpu training | |
tmp_optimizer = optimizer_cls( | |
optimizer_parameters, | |
# optimizer_class=optimizer_cls, | |
# parameters_as_bucket_view=True, | |
lr=float(conf.everyone.lr), | |
betas=(float(conf.advanced.opt.betas.one), float(conf.advanced.opt.betas.two)), | |
eps=float(conf.advanced.opt.epsilon), | |
weight_decay=float(conf.advanced.opt.weight_decay), | |
) | |
noise_scheduler = DDPMScheduler.from_pretrained( | |
conf.everyone.model, | |
subfolder='scheduler', | |
use_auth_token=conf.local.hf_token, | |
) | |
# Hivemind Setup | |
# get network peers (if mother peer then ignore) | |
rmaddrs_rq = requests.get('http://' + conf.everyone.server + "/v1/get/peers") | |
if rmaddrs_rq.status_code == 200: | |
peer_list = json.loads(rmaddrs_rq.content) | |
else: | |
raise ConnectionError("Unable to obtain peers from server") | |
# set local maddrs ports | |
host_maddrs_tcp = "/ip4/0.0.0.0/tcp/" + str(conf.local.networking.internal.tcp) | |
host_maddrs_udp = "/ip4/0.0.0.0/udp/" + str(conf.local.networking.internal.udp) + "/quic" | |
# set public to-be-announced maddrs | |
# get public ip | |
if conf.local.networking.external.ip == "": | |
conf.local.networking.external.ip = None | |
if conf.local.networking.external.ip == "auto" or conf.local.networking.external.ip is None: | |
completed = False | |
if completed is False: | |
try: | |
ip = requests.get("https://api.ipify.org/", timeout=5).text | |
ipsrc = "online" | |
completed = True | |
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as err: | |
print("Ipfy.org took too long, trying another domain.") | |
if completed is False: | |
try: | |
ip = requests.get("https://ipv4.icanhazip.com/", timeout=5).text | |
ipsrc = "online" | |
completed = True | |
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as err: | |
print("Icanhazip.com took too long, trying another domain.") | |
if completed is False: | |
try: | |
tmpjson = json.loads(requests.get("https://jsonip.com/", timeout=5).content) | |
ip = tmpjson["ip"] | |
ipsrc = "online" | |
completed = True | |
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as err: | |
print("Jsonip.com took too long, ran out of alternatives.") | |
                raise ConnectionError("Could not determine the public IP address from any provider")
else: | |
ip = conf.local.networking.external.ip | |
ipsrc = "config" | |
#check if valid ip | |
try: | |
ip = ipaddress.ip_address(ip) | |
ip = str(ip) | |
except Exception: | |
raise ValueError("Invalid IP, please check the configuration file. IP Source: " + ipsrc) | |
public_maddrs_tcp = "/ip4/" + ip + "/tcp/" + str(conf.local.networking.external.tcp) | |
public_maddrs_udp = "/ip4/" + ip + "/udp/" + str(conf.local.networking.external.udp) + "/quic" | |
#init dht | |
#TODO: add announce_maddrs | |
dht = hivemind.DHT( | |
host_maddrs=[host_maddrs_tcp, host_maddrs_udp], | |
initial_peers=peer_list, | |
start=True, | |
announce_maddrs=[public_maddrs_tcp, public_maddrs_udp] | |
) | |
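    # With e.g. internal tcp/udp port 41337 and public IP 1.2.3.4 (illustrative values), this
    # listens on "/ip4/0.0.0.0/tcp/41337" and "/ip4/0.0.0.0/udp/41337/quic" and announces
    # "/ip4/1.2.3.4/tcp/41337" and "/ip4/1.2.3.4/udp/41337/quic" to the other peers.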
#set compression and optimizer | |
compression = Float16Compression() | |
lr_scheduler = get_scheduler( | |
conf.everyone.lr_scheduler, | |
optimizer=tmp_optimizer, | |
num_warmup_steps=int(float(conf.advanced.lr_scheduler_warmup) * imgs_per_epoch * total_epochs), | |
num_training_steps=total_epochs * imgs_per_epoch, | |
) | |
optimizer = hivemind.Optimizer( | |
dht=dht, | |
run_id="testrun", | |
batch_size_per_step=1, | |
target_batch_size=4000, | |
optimizer=tmp_optimizer, | |
use_local_updates=False, | |
matchmaking_time=260.0, | |
averaging_timeout=1200.0, | |
allreduce_timeout=1200.0, | |
load_state_timeout=1200.0, | |
grad_compression=compression, | |
state_averaging_compression=compression, | |
verbose=True, | |
scheduler=lr_scheduler | |
) | |
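    # With use_local_updates=False, hivemind accumulates gradients locally and counts
    # batch_size_per_step samples per call to optimizer.step(); once the swarm has jointly
    # reached target_batch_size (4000 samples here), peers average gradients/state (subject to
    # matchmaking_time and the *_timeout settings) and apply one global optimizer step.
    # This description follows the hivemind.Optimizer documentation; tune these values to the
    # swarm size and network conditions.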
print('\n'.join(str(addr) for addr in dht.get_visible_maddrs())) | |
print("Global IP:", hivemind.utils.networking.choose_ip_address(dht.get_visible_maddrs())) | |
#statistics | |
if conf.local.stats.enable: | |
statconfig = {"geoaprox": False, "bandwidth": False, "specs": False} | |
bandwidthstats = {} | |
specs_stats = {} | |
print("Stats enabled") | |
if conf.local.stats.geoaprox: | |
statconfig['geoaprox'] = True | |
if conf.local.stats.bandwidth: | |
statconfig["bandwidth"] = True | |
import speedtest | |
session = speedtest.Speedtest() | |
download = session.download() | |
upload = session.upload() | |
ping = session.results.ping | |
bandwidthstats = {"download": str(download), "upload": str(upload), "ping": str(ping)} | |
if conf.local.stats.specs: | |
statconfig["specs"] = True | |
# GPU | |
# https://docs.nvidia.com/deploy/nvml-api/index.html | |
pynvml.nvmlInit() | |
cudadriver_version = pynvml.nvmlSystemGetCudaDriverVersion() | |
driver_version = pynvml.nvmlSystemGetDriverVersion() | |
NVML_version = pynvml.nvmlSystemGetNVMLVersion() | |
#TODO: Assuming one gpu only | |
cudadev = torch.cuda.current_device() | |
nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev) | |
#psu_info = pynvml.nvmlUnitGetPsuInfo(pynvml.c_nvmlPSUInfo_t.) | |
#temperature_info = pynvml.nvmlUnitGetTemperature(nvml_device) | |
#unit_info = pynvml.nvmlUnitGetUnitInfo(nvml_device) | |
arch_info = pynvml.nvmlDeviceGetArchitecture(nvml_device) | |
brand_info = pynvml.nvmlDeviceGetBrand(nvml_device) | |
#clock_info = pynvml.nvmlDeviceGetClock(nvml_device) | |
#clockinfo_info = pynvml.nvmlDeviceGetClockInfo(nvml_device) | |
#maxclock_info = pynvml.nvmlDeviceGetMaxClockInfo(nvml_device) | |
computemode_info = pynvml.nvmlDeviceGetComputeMode(nvml_device) | |
compute_compatability = pynvml.nvmlDeviceGetCudaComputeCapability(nvml_device) | |
pcie_link_gen = pynvml.nvmlDeviceGetCurrPcieLinkGeneration(nvml_device) | |
pcie_width = pynvml.nvmlDeviceGetCurrPcieLinkWidth(nvml_device) | |
display_active_bool = pynvml.nvmlDeviceGetDisplayActive(nvml_device) | |
#memory_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device) | |
gpu_energy_cons = pynvml.nvmlDeviceGetTotalEnergyConsumption(nvml_device) | |
device_name = pynvml.nvmlDeviceGetName(nvml_device) | |
gpusinfo = { | |
"software": { | |
"CUDA_DRIVER_VERSION": str(cudadriver_version), | |
"NVIDIA_DRIVER_VERSION": str(driver_version), | |
"NVML_VERSION": str(NVML_version), | |
}, | |
"hardware": { | |
"energy": { | |
#"PSU_INFO": psu_info, | |
#"TEMPERATURE_INFO": temperature_info, | |
"ENERGY_CONSUMPTION": str(gpu_energy_cons) | |
}, | |
"info": { | |
#"UNIT_INFO": unit_info, | |
"BRAND_INFO": str(brand_info), | |
"DEV_NAME": str(device_name), | |
"DISPLAY_ACTIVE": str(display_active_bool), | |
"ARCH_INFO": str(arch_info) | |
}, | |
"memory": { | |
"PCIE_LINK_GEN": str(pcie_link_gen), | |
"PCIE_WIDTH": str(pcie_width), | |
#"MEMORY_INFO": memory_info, | |
}, | |
"compute": { | |
#"CLOCK": clock_info, | |
#"CLOCK_INFO": clockinfo_info, | |
#"MAX_CLOCK": maxclock_info, | |
"COMPUTE_MODE": str(computemode_info), | |
"COMPUTE_COMPATABILITY": str(compute_compatability) | |
} | |
} | |
} | |
cpuinfo = {} | |
import cpuinfo | |
cpudict = cpuinfo.get_cpu_info() | |
cpuinfo = { | |
'CPU_ARCH': str(cpudict['arch']), | |
"CPU_HZ_AD": str(cpudict["hz_advertised_friendly"]), | |
"CPU_HZ_AC": str(cpudict["hz_actual_friendly"]), | |
"CPU_BITS": str(cpudict["bits"]), | |
"VENDOR_ID": str(cpudict["vendor_id_raw"]), | |
#"HARDWARE_RAW": cpudict["hardware_raw"], | |
"BRAND_RAW": str(cpudict["brand_raw"]) | |
} | |
specs_stats = {'gpu': gpusinfo, 'cpu': cpuinfo} | |
statsjson = { | |
'python_ver': str(sys.version), | |
'config': statconfig, | |
'bandwidth': bandwidthstats, | |
'specs': specs_stats | |
} | |
print(statsjson) | |
pstats = requests.post('http://' + conf.everyone.server + '/v1/post/stats', json=json.dumps(statsjson)) | |
if pstats.status_code != 200: | |
raise ConnectionError("Failed to report stats") | |
# create ema | |
if conf.everyone.use_ema: | |
ema_unet = EMAModel(unet.parameters()) | |
print(get_gpu_ram()) | |
def save_checkpoint(global_step): | |
if rank == 0: | |
if conf.everyone.use_ema: | |
ema_unet.store(unet.parameters()) | |
ema_unet.copy_to(unet.parameters()) | |
pipeline = StableDiffusionPipeline( | |
text_encoder=text_encoder, #if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module, | |
vae=vae, | |
unet=unet, | |
tokenizer=tokenizer, | |
scheduler=PNDMScheduler.from_pretrained(conf.everyone.model, subfolder="scheduler", use_auth_token=conf.local.hf_token), | |
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), | |
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), | |
) | |
print(f'saving checkpoint to: {conf.local.output_path}/{"hivemind"}_{global_step}') | |
pipeline.save_pretrained(f'{conf.local.output_path}/{"hivemind"}_{global_step}') | |
if conf.everyone.use_ema: | |
ema_unet.restore(unet.parameters()) | |
# train! | |
try: | |
already_done_steps = (optimizer.tracker.global_progress.samples_accumulated + (optimizer.tracker.global_progress.epoch * optimizer.target_batch_size)) | |
print("Skipping", already_done_steps, "steps on the LR Scheduler.") | |
for i in range(already_done_steps): | |
lr_scheduler.step() | |
print("Done") | |
loss = torch.tensor(0.0, device=device, dtype=weight_dtype) | |
while True: | |
print(get_gpu_ram()) | |
recipt = getchunk(conf.everyone.server, conf.everyone.imgcount) | |
#Note: we removed worldsize here | |
train_dataloader = dataloader(tokenizer, text_encoder, device, 1, rank) | |
num_steps_per_epoch = len(train_dataloader) | |
progress_bar = tqdm.tqdm(range(num_steps_per_epoch), desc="Total Steps", leave=False) | |
global_step = 0 | |
unet.train() | |
if conf.everyone.train_text_encoder: | |
text_encoder.train() | |
for _, batch in enumerate(train_dataloader): | |
b_start = time.perf_counter() | |
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample() | |
latents = latents * 0.18215 | |
# Sample noise | |
noise = torch.randn_like(latents) | |
bsz = latents.shape[0] | |
# Sample a random timestep for each image | |
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) | |
timesteps = timesteps.long() | |
# Add noise to the latents according to the noise magnitude at each timestep | |
# (this is the forward diffusion process) | |
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | |
# Get the embedding for conditioning | |
encoder_hidden_states = batch['input_ids'] | |
if noise_scheduler.config.prediction_type == "epsilon": | |
target = noise | |
elif noise_scheduler.config.prediction_type == "v_prediction": | |
target = noise_scheduler.get_velocity(latents, noise, timesteps) | |
else: | |
raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}") | |
if not conf.everyone.train_text_encoder: | |
# Predict the noise residual and compute loss | |
with torch.autocast('cuda', enabled=conf.everyone.fp16): | |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | |
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") | |
# backprop and update | |
scaler.scale(loss).backward() | |
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0) | |
scaler.step(optimizer) | |
scaler.update() | |
lr_scheduler.step() | |
optimizer.zero_grad() | |
else: | |
# Predict the noise residual and compute loss | |
with torch.autocast('cuda', enabled=conf.everyone.fp16): | |
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | |
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") | |
# backprop and update | |
scaler.scale(loss).backward() | |
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0) | |
torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), 1.0) | |
scaler.step(optimizer) | |
scaler.update() | |
lr_scheduler.step() | |
optimizer.zero_grad() | |
# Update EMA | |
if conf.everyone.use_ema: | |
ema_unet.step(unet.parameters()) | |
# perf | |
b_end = time.perf_counter() | |
seconds_per_step = b_end - b_start | |
steps_per_second = 1 / seconds_per_step | |
rank_images_per_second = conf.local.batch_size * steps_per_second | |
#world_images_per_second = rank_images_per_second #* world_size | |
samples_seen = global_step * conf.local.batch_size #* world_size | |
# get global loss for logging | |
# torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM) | |
loss = loss #/ world_size | |
if rank == 0: | |
progress_bar.update(1) | |
global_step += 1 | |
logs = { | |
"train/loss": loss.detach().item(), | |
"train/lr": lr_scheduler.get_last_lr()[0], | |
"train/epoch": 1, | |
"train/step": global_step, | |
"train/samples_seen": samples_seen, | |
"perf/rank_samples_per_second": rank_images_per_second, | |
#"perf/global_samples_per_second": world_images_per_second, | |
} | |
progress_bar.set_postfix(logs) | |
run.log(logs, step=global_step) | |
# if counter < 5: | |
# counter += 1 | |
# elif counter >= 5: | |
# data = { | |
# "tracker.global_progress": optimizer.tracker.global_progress, | |
# "tracker.local_progress": optimizer.tracker.local_progress, | |
# } | |
# print(data) | |
# counter = 0 | |
#Thread(target=backgroundreport, args=(("http://" + conf.everyone.server + "/v1/post/ping"), "world_images_per_second")).start() | |
if global_step % conf.local.save_steps == 0 and global_step > 0: | |
save_checkpoint(global_step) | |
if conf.local.inference.enable: | |
if global_step % conf.inference.log_steps == 0 and global_step > 0: | |
if rank == 0: | |
# get prompt from random batch | |
prompt = tokenizer.decode(batch['tokens'][random.randint(0, len(batch['tokens'])-1)]) | |
if conf.inference.image_log_scheduler == 'DDIMScheduler': | |
print('using DDIMScheduler scheduler') | |
scheduler = DDIMScheduler.from_pretrained(conf.everyone.model, subfolder="scheduler", use_auth_token=conf.local.hf_token) | |
else: | |
print('using PNDMScheduler scheduler') | |
scheduler=PNDMScheduler.from_pretrained(conf.everyone.model, subfolder="scheduler", use_auth_token=conf.local.hf_token) | |
pipeline = StableDiffusionPipeline( | |
text_encoder=text_encoder, #if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module, | |
vae=vae, | |
                                unet=unet, # DDP wrapping is commented out above, so use the module directly
tokenizer=tokenizer, | |
scheduler=scheduler, | |
safety_checker=None, # disable safety checker to save memory | |
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), | |
).to(device) | |
# inference | |
if conf.local.wandb: | |
images = [] | |
else: | |
saveInferencePath = conf.local.output_path + "/inference" | |
os.makedirs(saveInferencePath, exist_ok=True) | |
with torch.no_grad(): | |
with torch.autocast('cuda', enabled=conf.everyone.fp16): | |
for _ in range(conf.local.inference.amount): | |
if conf.local.wandb: | |
images.append( | |
wandb.Image(pipeline( | |
prompt, num_inference_steps=conf.local.inference.inference_steps | |
).images[0], | |
caption=prompt) | |
) | |
else: | |
from datetime import datetime | |
images = pipeline(prompt, num_inference_steps=conf.local.inference.inference_steps).images[0] | |
filenameImg = str(time.time_ns()) + ".png" | |
filenameTxt = str(time.time_ns()) + ".txt" | |
images.save(saveInferencePath + "/" + filenameImg) | |
with open(saveInferencePath + "/" + filenameTxt, 'a') as f: | |
f.write('Used prompt: ' + prompt + '\n') | |
f.write('Generated Image Filename: ' + filenameImg + '\n') | |
f.write('Generated at: ' + str(global_step) + ' steps' + '\n') | |
f.write('Generated at: ' + str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))+ '\n') | |
# log images under single caption | |
if conf.local.wandb: | |
run.log({'images': images}, step=global_step) | |
# cleanup so we don't run out of memory | |
del pipeline | |
gc.collect() | |
sreport = report(conf.everyone.server, recipt) | |
if sreport is True: | |
print("Report Success") | |
else: | |
raise ConnectionError("Couldn't report") | |
except Exception as e: | |
print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}') | |
pass | |
save_checkpoint(global_step) | |
#cleanup() | |
print(get_gpu_ram()) | |
print('Done!') | |
if __name__ == "__main__": | |
#setup() | |
imgs_per_epoch, total_epochs = setuphivemind() | |
main() |