Skip to content

Instantly share code, notes, and snippets.

@GongXinyuu
Created January 19, 2021 19:37
Show Gist options
  • Select an option

  • Save GongXinyuu/4d542b5f6358ed922df75002b8cb8038 to your computer and use it in GitHub Desktop.

Select an option

Save GongXinyuu/4d542b5f6358ed922df75002b8cb8038 to your computer and use it in GitHub Desktop.
# Learning towards Minimum Hyperspherical Energy (MHE), NeurIPS 2018.
# https://github.com/wy1iu/MHE/blob/master/code/architecture.py
# https://github.com/rmlin/CoMHE/blob/master/adversarial_projection/architecture.py
@torch.no_grad()
def hypersphereenergy(filt, filt_target=None, paired=False, model=None,
                      power='0', device='cuda:0', mem_efficient=False):
    """Hyperspherical-energy regularization loss between sets of filters.

    Filters are flattened to rows and compared by cosine similarity on the
    unit hypersphere; a pairwise repulsion "energy" (large when two filters
    point in similar directions) is summed and averaged over the pairs.

    Args:
        filt: weight tensor; dim 0 indexes filters (PyTorch layout
            ``[n_filt, n_input, ksize, ksize]``), the rest is flattened.
        filt_target: optional second weight tensor. When ``None``, the energy
            of ``filt`` against itself is computed (self-pairs excluded);
            otherwise every ``filt`` row against every ``filt_target`` row.
        paired: when ``filt_target`` is given with the same number of rows,
            compare only matching rows (diagonal of the energy matrix).
        model: unused; kept for interface compatibility (the 'half_mhe'
            variant is not implemented — see TODO below).
        power: energy variant —
            ``'0'``/``'1'``/``'2'``: chord-distance based ``-log(d^2)``,
            ``d^-1``, ``d^-2``;
            ``'a0'``/``'a1'``/``'a2'``: angle based ``-log(theta)``,
            ``theta^-1``, ``theta^-2`` (the angular losses are scaled by 0.1).
        device: unused for computation (nothing is allocated on an explicit
            device anymore); kept for interface compatibility.
        mem_efficient: free intermediate tensors eagerly and empty the CUDA
            cache (helps on tight GPU memory budgets).

    Returns:
        Scalar tensor holding the mean pairwise energy (gradients disabled
        by the ``torch.no_grad`` decorator).

    Raises:
        ValueError: if ``power`` is not one of the six supported variants.
    """
    # TODO: 'half_mhe' (virtual negative filters) not implemented:
    #   filt = torch.cat([filt, -filt], dim=0); n_filt *= 2
    n_filt = filt.size(0)
    filt = filt.view(n_filt, -1)
    n_filt_target = None
    if filt_target is not None:
        n_filt_target = filt_target.size(0)
        filt_target = filt_target.view(n_filt_target, -1)
        assert filt.size(1) == filt_target.size(1), \
            str(filt.size()) + " v.s. " + str(filt_target.size())
        if paired:
            assert n_filt == n_filt_target

    # Cosine-similarity matrix. The +1e-4 inside the sqrt avoids a zero norm
    # and makes |inner_pro| < 1 strictly, so torch.acos below cannot NaN.
    filt_norm = torch.sqrt(torch.sum(filt * filt, 1, keepdim=True) + 1e-4)
    if filt_target is None:
        other, other_norm = filt, filt_norm
    else:
        other = filt_target
        other_norm = torch.sqrt(
            torch.sum(filt_target * filt_target, 1, keepdim=True) + 1e-4)
    norm_mat = torch.einsum('nc,mc->nm', [filt_norm, other_norm])
    inner_pro = torch.einsum('nc,mc->nm', [filt, other])
    if mem_efficient:
        del filt_norm, other_norm
        torch.cuda.empty_cache()
    inner_pro /= norm_mat
    if mem_efficient:
        del norm_mat
        torch.cuda.empty_cache()

    if power in ('0', '1', '2'):
        # Squared chord distance on the sphere: ||u - v||^2 = 2 - 2*cos.
        # Clamped away from 0 so the log / inverse powers stay finite.
        cross_terms = torch.clamp(2.0 - 2.0 * inner_pro, min=1e-4)
        if mem_efficient:
            del inner_pro
            torch.cuda.empty_cache()
        if power == '0':
            final = -torch.log(cross_terms)
        elif power == '1':
            final = cross_terms.pow(-0.5)
        else:
            final = cross_terms.pow(-1.0)
        scale = 1.0
        if filt_target is None:
            # Self energy: count each unordered pair once (strict lower
            # triangle). The diagonal (self-pairs) is discarded here, so its
            # clamped value never contributes to the loss.
            final = torch.tril(final, diagonal=-1)
            cnt = n_filt * (n_filt - 1) / 2.0
        else:
            cnt = n_filt * n_filt_target
            if paired:
                assert final.size(0) == final.size(1)
                final = torch.diagonal(final)
                cnt = n_filt
    elif power in ('a0', 'a1', 'a2'):
        # Angular distance normalized to [0, 1]; +eps keeps -log / inverse
        # powers finite when the angle is 0 (identical directions).
        acos = torch.acos(inner_pro) / math.pi + 1e-4
        if mem_efficient:
            del inner_pro
            torch.cuda.empty_cache()
        if power == 'a0':
            final = -torch.log(acos)
            scale = 1.0
        elif power == 'a1':
            final = acos.pow(-1.0)
            scale = 1e-1
        else:
            final = acos.pow(-2.0)
            scale = 1e-1
        if filt_target is None:
            # Count each unordered pair once (strict upper triangle here,
            # obtained by subtracting the lower triangle incl. the diagonal).
            final = final - torch.tril(final)
            cnt = n_filt * (n_filt - 1) / 2.0
        elif paired:
            assert final.size(0) == final.size(1)
            final = torch.diagonal(final)
            cnt = n_filt
        else:
            cnt = n_filt * n_filt_target
    else:
        raise ValueError("unknown power: " + repr(power))

    loss = scale * final.sum() / cnt
    return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment