in dev. Benchmarks a threaded GPU prefetcher (CUDA streams + on-GPU batched augmentation) against a plain, non-prefetched PyTorch DataLoader, then prints the average seconds per step for each and the resulting speedup factor.
import os
import time
import queue
import threading
import random

import torch
import torchvision
from torchvision.transforms import v2
#############################################
#          Testing Hyperparameters          #
#############################################

demo_batchsize = 96
num_demo_layers = 18
image_height = 14800
image_width = 18400
input_image_size = 512

num_warmup_iters = 20
num_timing_iters = 40

dtype = torch.half
data_device = 'cpu'
device = 'cuda'
#############################################
#             Class Definitions             #
#############################################

class BatchedPreprocessingFunction():
    def __init__(self, *args, image_size, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO TODO TODO TODO: Potential #BUG: We need to double check whether this samples a unique angle for each batch item, or rotates them all the same. Could be a very, very subtle #BUG if so :')))) / :'((((
        self.batched_rotate_op = torchvision.transforms.v2.RandomRotation(degrees=180, interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
        self.flip_lr_chance = .5
        self.image_size = image_size

    def __call__(self, stacked_input_targets, cast_pre_rotate=False):
        # Assumes input of shape (batchsize, 2, image_size+rotate_padding, image_size+rotate_padding)
        # Batch rotate on the GPU using the torchvision v2 transforms. Assumes the input is prepadded so that we can crop out the center to avoid extra surrounding grey space from the rotations.
        # CPU rotate does not support half, so we cast up to float for the rotation....
        if cast_pre_rotate:
            stacked_input_targets = stacked_input_targets.float()
        rotated = self.batched_rotate_op(stacked_input_targets)
        if cast_pre_rotate:
            # Cast the rotated result back down to half.
            rotated = rotated.half()
        # Calculate the padding to remove
        padding_to_trim = rotated.shape[2] - self.image_size, rotated.shape[3] - self.image_size
        # Slice the rotated image down to the target size
        rotated = rotated[:, :, padding_to_trim[0]//2:-(padding_to_trim[0]-padding_to_trim[0]//2), padding_to_trim[1]//2:-(padding_to_trim[1]-padding_to_trim[1]//2)]
        # Batch-sample values between 0. and 1. to calculate left-right flip probabilities.
        to_flip = (torch.rand((rotated.shape[0],), device=stacked_input_targets.device) < self.flip_lr_chance).view(-1, 1, 1, 1).half()
        # There may be a more memory-efficient and more straightforward way to do this, but right now we have to flip then select between the flipped values.
        # Note: torch.fliplr always flips dim 1, so applying it to the full (batch, 2, height, width) tensor would flip the channel dim; hence the per-channel slicing and re-stack. On the 3-D (batch, height, width) slices it flips the height axis, so this is effectively an up/down flip; with the full 180-degree random rotation above it is still a valid augmentation.
        result = to_flip * rotated + (1. - to_flip) * torch.stack((torch.fliplr(rotated[:, 0, :, :]), torch.fliplr(rotated[:, 1, :, :])), dim=1)
        return result
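
# Optional sanity check for the rotation TODO above -- a minimal sketch, not part of the timing path (uncomment to run manually).
# Rotating a batch of identical images and comparing the outputs shows whether v2.RandomRotation samples a single
# angle per call (the whole batch rotated the same) or a fresh angle per batch item.
####_probe = torch.arange(64.).reshape(1, 1, 8, 8).repeat(4, 2, 1, 1)
####_probe_rotated = torchvision.transforms.v2.RandomRotation(degrees=180)(_probe)
####print("same angle for the whole batch?", all(torch.equal(_probe_rotated[0], _probe_rotated[i]) for i in range(1, 4)))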
# Currently a slight #hack: we define the data in a scope outside of this class to make the forking easier.
class CustomImageDataset(torch.utils.data.Dataset):
    def __init__(self, rotation_padding, per_image_preprocessing_fn=None):
        # Don't store the data/labels on the dataset object, as the fork() command copies all of the memory in it :'(((((
        # Slightly hacky, as this 'rotation_padding' variable is used differently in different places, # TODO maybe clean it up a bit...
        self.rotation_padding = rotation_padding
        self.per_image_preprocessing_fn = per_image_preprocessing_fn

    def __len__(self):
        # Enough samples for the warmup + timing iterations, with ~4x headroom.
        return 4*demo_batchsize*(num_warmup_iters+num_timing_iters)

    def __getitem__(self, idx):
        # Chooses randomly based on label image, so there's likely going to be uneven sampling based upon the amount of training images paired per label image in da dataset.... :'(((((((((
        label_image, input_images = random.choice(demo_input_label_data_pairs)
        input_image_randomly_selected = random.choice(input_images)
        # Can be improved, # TODO note to self, let's replace with a vectorized version in the future please! <3 :'))))
        inputs_and_targets = slow_cpu_iterative_image_slicer(input_image_randomly_selected, label_image, rotation_padding=self.rotation_padding, chunk_size=input_image_size) ####, batchsize=batchsize)
        if self.per_image_preprocessing_fn is not None:
            inputs_and_targets = self.per_image_preprocessing_fn(inputs_and_targets.unsqueeze(0), cast_pre_rotate=True).squeeze()
        return inputs_and_targets
# Co-written and modified w/ ChatGPT 4.0
class GPUPrefetcher:
    def __init__(self, loader, batched_fn_to_apply, queue_size=12, num_threads=2):
        self.loader = iter(loader)
        self.queue = queue.Queue(maxsize=queue_size) # Adjust size as needed
        self.is_running = True
        self.threads = []
        self.streams = [torch.cuda.Stream() for _ in range(num_threads)]
        self.batched_fn_to_apply = batched_fn_to_apply
        for i in range(num_threads):
            thread = threading.Thread(target=self._prefetch, args=(i,))
            thread.daemon = True
            thread.start()
            self.threads.append(thread)

    def _prefetch(self, thread_id):
        # Each prefetch thread runs on its own CUDA stream: it pulls a CPU batch from the shared loader,
        # copies it to the GPU asynchronously, applies the batched GPU preprocessing, and enqueues the result.
        with torch.cuda.stream(self.streams[thread_id]):
            while self.is_running:
                try:
                    with torch.no_grad():
                        input_target_tensor = next(self.loader)
                        cuda_tensor = input_target_tensor.to('cuda', memory_format=torch.channels_last, non_blocking=True)
                        batch_prepped = self.batched_fn_to_apply(cuda_tensor)
                        self.queue.put(batch_prepped)
                except StopIteration:
                    break

    def __next__(self):
        with torch.no_grad():
            if not self.is_running and self.queue.empty():
                raise StopIteration
            # Can be helpful to uncomment this below line if'n u wanna watch queue health over the course of your training run (might be helpful to log in other ways if you like doing that as well)
            ####print("current queue depth: ", self.queue.qsize())
            # Make the consumer's (default) stream wait on the prefetch streams before using the dequeued batch.
            for stream in self.streams:
                torch.cuda.current_stream().wait_stream(stream)
            next_input_target_tensor = self.queue.get()
            if next_input_target_tensor is None:
                raise StopIteration
            return next_input_target_tensor

    def __del__(self):
        self.is_running = False
        for thread in self.threads:
            thread.join()

    def stop(self):
        self.is_running = False
        for _ in self.threads:
            self.queue.put(None) # Signal threads to stop
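
# Minimal usage sketch for GPUPrefetcher (assumed names; the actual benchmark wiring lives in the Dataloader Setup section below):
####prefetcher = GPUPrefetcher(some_dataloader, batched_fn_to_apply=some_batched_gpu_fn, queue_size=6, num_threads=2)
####batch = next(prefetcher)  # already on the GPU, already preprocessed
####prefetcher.stop()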
# Checks multiple batches of candidate points in a potential image all at once to see if the center point is ~0 (i.e., a very crude test to see if the crop is 'mostly dark' or not), and returns all of the valid points.
# (A multiplier of 16 has been more than enough so far; you may need to increase it for images with more dark points, or maybe for very long training runs...)
def get_valid_image_offsets_batched(image, height_size, width_size, batchsize, mult_to_check=16):
    with torch.no_grad():
        zero_val = 1e-2
        max_height, max_width = image.squeeze().shape # Assuming 1 input image for now.... :'(((((
        height_offsets = torch.randint(max_height-height_size, size=(batchsize*mult_to_check,))
        width_offsets = torch.randint(max_width-width_size, size=(batchsize*mult_to_check,))
        center_pixels = image[height_offsets+height_size//2, width_offsets+width_size//2]
        valid_pixels = ~(center_pixels < zero_val)
        heights_filtered = torch.masked_select(height_offsets, valid_pixels)[:batchsize]
        widths_filtered = torch.masked_select(width_offsets, valid_pixels)[:batchsize]
        if batchsize == 1: # if single batch, we may need to reduce the dimensions hiers <3 :'))))
            heights_filtered = heights_filtered.squeeze()
            widths_filtered = widths_filtered.squeeze()
        return heights_filtered, widths_filtered
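
# Minimal usage sketch (assumed toy shapes; the benchmark itself calls this from slow_cpu_iterative_image_slicer below):
####_hs, _ws = get_valid_image_offsets_batched(torch.rand(2048, 2048), height_size=512, width_size=512, batchsize=8)
####print(_hs.shape, _ws.shape)  # up to 8 valid (height, width) offsets, fewer if too many dark centers were sampled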
# TODO: TODO: TODO: TODO: TODO: Note to self: Gotta fold and fuse a few of these kernels, please! <3 :'))))
# Def gotta eventually v/hmap dis gui
def slow_cpu_iterative_image_slicer(input_image, label_image, chunk_size, rotation_padding, fliplr_chance=.5, rotate_degrees=180):
    with torch.no_grad():
        height_offset, width_offset = get_valid_image_offsets_batched(input_image, chunk_size+rotation_padding, chunk_size+rotation_padding, batchsize=1)
        input_image_batch_item = input_image.squeeze()[height_offset:height_offset+chunk_size+rotation_padding, width_offset:width_offset+chunk_size+rotation_padding]
        label_image_batch_item = label_image.squeeze()[height_offset:height_offset+chunk_size+rotation_padding, width_offset:width_offset+chunk_size+rotation_padding]
        return torch.stack((input_image_batch_item, label_image_batch_item), dim=0)
# This should be on the CPU in order to properly test dis gui
demo_input_data_raw = torch.ones((num_demo_layers, image_height, image_width), dtype=dtype, device=data_device)
demo_label_data_raw = torch.ones((1, image_height, image_width), dtype=dtype, device=data_device)

# Make only one label/image pair for now (one label indexed against a stacked tensor containing all training images for a given dataset)
demo_input_label_data_pairs = ((demo_label_data_raw, demo_input_data_raw),)

# Required for the multiprocessing, to avoid duplicating these source images in each worker process
for label_image, input_images in demo_input_label_data_pairs:
    [input_image.share_memory_() for input_image in input_images]
    label_image.share_memory_()
#############################################
#             Dataloader Setup              #
#############################################

num_cpus = os.cpu_count()
# The multiplier is hardcoded for now; you might want to adjust it, though there may be some memory/processing tradeoff if so....
rotation_padding = round(.35*input_image_size)
print("num cpus available! : <3 :'))))", num_cpus)

batched_preprocessing_fn = BatchedPreprocessingFunction(image_size=input_image_size)
non_prefetched_image_dataset = CustomImageDataset(rotation_padding=rotation_padding, per_image_preprocessing_fn=batched_preprocessing_fn)
prefetched_image_dataset = CustomImageDataset(rotation_padding=rotation_padding)

non_prefetched_train_dataset_gpu_loader = iter(torch.utils.data.DataLoader(non_prefetched_image_dataset, batch_size=demo_batchsize, drop_last=True, shuffle=True, num_workers=num_cpus//2, pin_memory=False, persistent_workers=False, prefetch_factor=2))
train_dataset_gpu_loader = iter(torch.utils.data.DataLoader(prefetched_image_dataset, batch_size=demo_batchsize, drop_last=True, shuffle=True, num_workers=num_cpus//2, pin_memory=False, persistent_workers=False, prefetch_factor=2))
train_dataset_gpu_prefetcher = GPUPrefetcher(train_dataset_gpu_loader, batched_fn_to_apply=batched_preprocessing_fn, queue_size=6, num_threads=2)
# Sorta hacky for now to unroll the speed tests like this, but just keeping a flat-ish structure for flexibility.

# Non-prefetched dataloader
for _ in range(num_warmup_iters):
    inputs, targets = next(non_prefetched_train_dataset_gpu_loader).to(device='cuda', memory_format=torch.channels_last, non_blocking=True).unsqueeze(2).unbind(1)
torch.cuda.synchronize()

non_prefetched_begin = time.time()
for _ in range(num_timing_iters):
    inputs, targets = next(non_prefetched_train_dataset_gpu_loader).to(device='cuda', memory_format=torch.channels_last, non_blocking=True).unsqueeze(2).unbind(1)
torch.cuda.synchronize()
non_prefetched_end = time.time()

# Prefetched dataloader
for _ in range(num_warmup_iters):
    inputs, targets = next(train_dataset_gpu_prefetcher).unsqueeze(2).unbind(1)
torch.cuda.synchronize()

prefetched_begin = time.time()
for _ in range(num_timing_iters):
    inputs, targets = next(train_dataset_gpu_prefetcher).unsqueeze(2).unbind(1)
torch.cuda.synchronize()
prefetched_end = time.time()

non_prefetched_seconds_per_step = (non_prefetched_end-non_prefetched_begin)/num_timing_iters
prefetched_seconds_per_step = (prefetched_end-prefetched_begin)/num_timing_iters

print("-------------------------------------------------------------------------------------------------")
print(f"Avg non-prefetched train dataset gpu loader time per step (in seconds):\t {non_prefetched_seconds_per_step}\t|")
print(f"Avg prefetched train dataset gpu loader time per step (in seconds):\t {prefetched_seconds_per_step}\t|")
print("-------------------------------------------------------------------------------------------------")
print("\nSpeed factor of improvement: ", non_prefetched_seconds_per_step/prefetched_seconds_per_step)