Created
November 5, 2024 02:34
-
-
Save dyd1234/44a00924c3517a2203a298af16d70a49 to your computer and use it in GitHub Desktop.
Repair for the rotation and scaling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -------------------------------------------------------------------- | |
# ... some modules and imports are not shown here
# -------------------------------------------------------------------- | |
# REPAIR implementations | |
def is_batchnorm(module) -> bool:
    """Return True if *module* is a BatchNorm layer.

    The check is made against ``nn.modules.batchnorm._BatchNorm``, the base
    class shared by PyTorch's ``BatchNorm1d`` / ``BatchNorm2d`` /
    ``BatchNorm3d`` layers, so all of them are matched with one test.
    """
    return isinstance(module, nn.modules.batchnorm._BatchNorm)
# Open question: does rotation need to be handled here as well?
# Basic implementation
def reset_bn_stats(model, loader, device, epochs=1, max_samples=1000):
    """Recompute the BatchNorm running statistics of *model* from *loader*.

    Every BatchNorm layer first has its running statistics reset and its
    ``momentum`` set to ``None`` (PyTorch then accumulates a cumulative
    average instead of an exponential one). The model is then run in train
    mode, without gradients, over batches from *loader* so the layers
    re-estimate their statistics.

    Args:
        model: Module whose BatchNorm layers are recalibrated in place.
        loader: Iterable yielding ``(images, _)`` pairs; only the images
            are used.
        device: Device the images are moved to before the forward pass.
        epochs: Maximum number of passes over *loader*.
        max_samples: Stop as soon as at least this many samples (counted
            across epochs) have been forwarded. ``None`` disables the cap
            and runs all *epochs* in full. Defaults to 1000.

    Returns:
        The same *model* object, left in eval mode. ``momentum`` of the
        BatchNorm layers is left at ``None``.
    """
    # Resetting the stats to their baseline first is necessary for stability.
    for m in model.modules():
        if is_batchnorm(m):
            m.momentum = None
            m.reset_running_stats()

    # Train mode is required so that BatchNorm updates its running stats.
    model.train()
    seen = 0
    capped = False
    with torch.no_grad():
        for _ in range(epochs):
            for images, _labels in loader:
                # Only the side effect on the BN buffers matters; the
                # network output is discarded.
                model(images.to(device))
                seen += len(images)
                if max_samples is not None and seen >= max_samples:
                    print("Enough data for REPAIR")
                    capped = True
                    break
            if capped:
                break

    model.eval()
    return model
# Variants for the SCN model type, which takes an extra hyper-input.
# Remember that the model here is an SCN YOLO.
# TODO: verify that this is sufficient to fix the BN problem.
def reset_bn_stats_scn_rotate(model, train_loader, device, if_half=True, epochs=1):
    """Recompute BatchNorm running statistics using randomly rotated batches.

    Every BatchNorm layer has its running statistics reset and its
    ``momentum`` set to ``None``. The model is then run in train mode,
    without gradients, over *train_loader*. Each batch is rotated by a
    freshly drawn random angle, and the model receives the rotated images
    together with ``transform_angle(angle)`` as its second input.

    Args:
        model: Module called as ``model(images, hyper_x)``; its BatchNorm
            layers are recalibrated in place.
        train_loader: Sized iterable yielding 4-tuples whose first element
            is the image batch; the other three elements are ignored.
        device: Device the images and the hyper-input are moved to.
        if_half: When truthy, images and hyper-input are cast to fp16;
            otherwise the images are cast to fp32 and the hyper-input is
            left in the dtype ``transform_angle`` returned.
        epochs: Number of full passes over *train_loader*.

    Returns:
        The same *model* object, left in eval mode.
    """
    # Resetting the stats to their baseline first is necessary for stability.
    for m in model.modules():
        if is_batchnorm(m):
            m.momentum = None
            m.reset_running_stats()

    # Train mode is required so that BatchNorm updates its running stats;
    # no loss is computed, only forward passes are needed.
    model.train()
    with torch.no_grad():
        for _ in range(epochs):
            pbar = enumerate(train_loader)
            if RANK in {-1, 0}:
                # Progress bar only on the main process.
                pbar = tqdm(pbar, total=len(train_loader), bar_format=TQDM_BAR_FORMAT)
            for _, (images, _, _, _) in pbar:
                # The drawn angle is treated as counter-clockwise; the
                # images are rotated by its negative (i.e. clockwise).
                angle = random.uniform(0, 360)
                images = TF.rotate(images, -angle)
                images = images.to(device, non_blocking=True)
                # NOTE(review): the loader presumably yields uint8 images;
                # there is no division by 255 here - confirm this matches
                # the preprocessing used when the model was trained.
                images = images.half() if if_half else images.float()
                hyper_x = transform_angle(angle).to(device)
                if if_half:
                    hyper_x = hyper_x.half()
                model(images, hyper_x)

    # NOTE(review): the original indentation of this print was lost; it is
    # emitted once, after all passes have finished.
    print("Enough data for REPAIR")
    model.eval()
    return model
# For the scaling case
def reset_bn_stats_scn_scaling(model, train_loader, device, epochs=1):
    """Recompute BatchNorm running statistics using randomly scaled batches.

    Every BatchNorm layer has its running statistics reset and its
    ``momentum`` set to ``None``. The model is then run in train mode,
    without gradients, over *train_loader*. Each batch is rescaled by a
    factor drawn uniformly from [0.2, 1.8], and the model receives the
    scaled images together with a one-element tensor holding that factor.

    Args:
        model: Module called as ``model(images, hyper_x)``; its BatchNorm
            layers are recalibrated in place.
        train_loader: Iterable yielding tuples whose first element is the
            image batch; any further elements are ignored.
        device: Device the images and the hyper-input are moved to.
        epochs: Number of full passes over *train_loader*.

    Returns:
        The same *model* object, left in eval mode.
    """
    # Resetting the stats to their baseline first is necessary for stability.
    for m in model.modules():
        if is_batchnorm(m):
            m.momentum = None
            m.reset_running_stats()

    # Train mode is required so that BatchNorm updates its running stats;
    # no loss is computed, only forward passes are needed.
    model.train()
    with torch.no_grad():
        for _ in range(epochs):
            for images, *_ in train_loader:
                scale = random.uniform(0.2, 1.8)
                # The transformed batch must be bound to the same name that
                # is fed to the model, otherwise the scaling is never applied.
                images = TF.affine(images, scale=scale, angle=0, translate=(0, 0), shear=0.0)
                hyper_x = torch.tensor([scale]).to(device)
                model(images.to(device), hyper_x)

    model.eval()
    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment