Skip to content

Instantly share code, notes, and snippets.

@dyd1234
Created November 5, 2024 02:34
Show Gist options
  • Save dyd1234/44a00924c3517a2203a298af16d70a49 to your computer and use it in GitHub Desktop.
Save dyd1234/44a00924c3517a2203a298af16d70a49 to your computer and use it in GitHub Desktop.
Repair for the rotation and scaling
# --------------------------------------------------------------------
# ... some modules and imports are not shown here
# --------------------------------------------------------------------
# REPAIR implementations
def is_batchnorm(module):
    """Return True iff *module* is one of torch's batch-normalization layers."""
    bn_base = nn.modules.batchnorm._BatchNorm
    return isinstance(module, bn_base)
# TODO: check whether this baseline version also needs to handle rotation
# basic implementation
def reset_bn_stats(model, loader, device, epochs=1, max_data=1000):
    """Recompute BatchNorm running statistics (the REPAIR step).

    Resets every BN layer's running mean/var, then streams up to
    ``max_data`` images from ``loader`` through the model in train mode
    so the statistics are re-estimated with cumulative averaging
    (``momentum=None``), which is more stable than the default EMA.

    Args:
        model: network whose BN layers are repaired (modified in place).
        loader: iterable yielding ``(images, labels)`` batches.
        device: device the forward pass runs on.
        epochs: maximum number of passes over ``loader`` (default 1).
        max_data: stop after this many images have been seen
            (default 1000 — the value previously hard-coded).

    Returns:
        The same ``model``, switched to eval mode.
    """
    num_data = 0
    # Resetting stats to baseline first is necessary for stability.
    for m in model.modules():
        if is_batchnorm(m):
            m.momentum = None  # None -> exact cumulative average
            m.reset_running_stats()
    model.train()
    with torch.no_grad():
        for _ in range(epochs):
            for images, _ in loader:
                # Forward pass only — BN stats update as a side effect.
                model(images.to(device))
                num_data += len(images)
                if num_data >= max_data:
                    print("Enough data for REPAIR")
                    model.eval()
                    return model
    model.eval()
    return model
# SCN variant below — note the model is an SCN YOLO.
# Please verify that this actually fixes the BatchNorm statistics problem.
def reset_bn_stats_scn_rotate(model, train_loader, device, if_half=True, epochs=1):
    """REPAIR pass for the SCN YOLO model under random rotations.

    Resets all BatchNorm running statistics, then forwards randomly
    rotated images — together with the matching hyper-input produced by
    ``transform_angle`` — through the model in train mode so the stats
    are re-estimated. No loss is computed; the forward pass alone
    updates the BN buffers.

    Args:
        model: SCN-type network, called as ``model(images, hyper_x)``.
        train_loader: yields 4-tuples whose first item is an image batch.
        device: target device for images and the hyper input.
        if_half: cast images and hyper input to fp16 when True, fp32 otherwise.
        epochs: number of passes over ``train_loader``.

    Returns:
        The same ``model``, switched to eval mode.
    """
    num_data = 0  # images seen so far (kept for debugging/inspection)
    # Resetting stats to baseline first is necessary for stability;
    # momentum=None -> exact cumulative averaging.
    for m in model.modules():
        if is_batchnorm(m):
            m.momentum = None
            m.reset_running_stats()
    model.train()
    for _ in range(epochs):
        pbar = enumerate(train_loader)
        if RANK in {-1, 0}:
            pbar = tqdm(pbar, total=len(train_loader), bar_format=TQDM_BAR_FORMAT)  # progress bar
        with torch.no_grad():
            for i, (images, _, _, _) in pbar:
                # TF.rotate is counter-clockwise; negate for clockwise.
                angle = random.uniform(0, 360)
                images = TF.rotate(images, -angle)
                images = images.to(device, non_blocking=True)
                images = images.half() if if_half else images.float()  # uint8 -> fp16/32
                # Hyper-network input encoding the rotation angle.
                Hyper_X = transform_angle(angle).to(device)
                if if_half:
                    Hyper_X = Hyper_X.half()
                # Images are already on `device`; the original called
                # .to(device) a second time here, which was redundant.
                model(images, Hyper_X)
                num_data += len(images)
        print("Enough data for REPAIR")
    model.eval()
    return model
# for scaling case
def reset_bn_stats_scn_scaling(model, train_loader, device, epochs=1):
    """REPAIR pass for the SCN model under random spatial scaling.

    Resets all BatchNorm running statistics, then forwards randomly
    rescaled images — with the scale factor passed as the SCN
    hyper-input — through the model in train mode so the stats are
    re-estimated. No loss is computed.

    Args:
        model: SCN-type network, called as ``model(images, hyper_x)``.
        train_loader: yields ``(images, labels)`` batches.
        device: target device for images and the hyper input.
        epochs: number of passes over ``train_loader``.

    Returns:
        The same ``model``, switched to eval mode.
    """
    num_data = 0  # images seen so far (kept for debugging/inspection)
    # Resetting stats to baseline first is necessary for stability;
    # momentum=None -> exact cumulative averaging.
    for m in model.modules():
        if is_batchnorm(m):
            m.momentum = None
            m.reset_running_stats()
    model.train()
    with torch.no_grad():
        for _ in range(epochs):
            for images, _ in train_loader:
                scale = random.uniform(0.2, 1.8)
                # BUG FIX: the original wrote `imgs = TF.affine(imgs, ...)`,
                # referencing an undefined name (NameError) and never
                # applying the scaling to `images`.
                images = TF.affine(images, scale=scale, angle=0, translate=(0, 0), shear=0.0)
                Hyper_X = Tensor([scale]).to(device)  # scale factor as hyper input
                model(images.to(device), Hyper_X)
                num_data += len(images)
    model.eval()
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment