Created
January 19, 2021 19:37
-
-
Save GongXinyuu/4d542b5f6358ed922df75002b8cb8038 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Learning towards Minimum Hyperspherical Energy | |
| # https://github.com/wy1iu/MHE/blob/master/code/architecture.py | |
| # https://github.com/rmlin/CoMHE/blob/master/adversarial_projection/architecture.py | |
@torch.no_grad()
def hypersphereenergy(filt, filt_target=None, paired=False, model=None, power='0', device='cuda:0', mem_efficient=False):
    """Compute the mean hyperspherical energy between filters.

    Each filter is flattened to a vector and compared by cosine similarity.
    The similarity is turned into a distance and then into an energy term,
    and the mean over the selected pairs is returned.

    Args:
        filt: Tensor whose first dimension indexes filters; all remaining
            dimensions are flattened (e.g. a conv weight laid out as
            ``[n_filt, n_input, ksize, ksize]``).
        filt_target: Optional second tensor, flattened the same way. When
            ``None`` the energy of ``filt`` against itself is computed over
            all unordered pairs of distinct filters. Otherwise every filter
            in ``filt`` is compared against every filter in ``filt_target``.
        paired: Only used when ``filt_target`` is given. Compare filter ``i``
            with target ``i`` only (the diagonal); both tensors must then
            hold the same number of filters.
            NOTE(review): a ``paired=True`` call without ``filt_target`` is
            treated as plain self-energy -- confirm this matches callers.
        model: Unused. Reserved for the ``half_mhe`` variant, which is not
            implemented.
        power: Energy kernel selector, applied to ``d = clamp(2 - 2*cos, 1e-4)``
            or to the normalised angle ``a = acos(cos) / pi + 1e-4``:
            ``'0'``: ``-log(d)``; ``'1'``: ``d ** -0.5``; ``'2'``: ``d ** -1``;
            ``'a0'``: ``-log(a)``; ``'a1'``: ``0.1 * a ** -1``;
            ``'a2'``: ``0.1 * a ** -2``.
        device: Kept for backward compatibility and ignored; every
            intermediate is created on the device of the inputs.
        mem_efficient: When true, release intermediates as early as possible
            and call ``torch.cuda.empty_cache()`` between stages.

    Returns:
        A zero-dimensional tensor holding the mean energy. A single filter
        compared against itself has no pairs and yields ``0``.

    Raises:
        ValueError: If ``power`` is not one of the supported selectors.
        AssertionError: If the flattened filter lengths differ, or if
            ``paired`` is set and the filter counts differ.
    """
    valid_powers = ('0', '1', '2', 'a0', 'a1', 'a2')
    if power not in valid_powers:
        # Previously an unknown selector only failed at `return loss` with an
        # UnboundLocalError, after all the expensive work had been done.
        raise ValueError("unsupported power {!r}; expected one of {}".format(power, valid_powers))

    # tensorflow: [ksize, ksize, n_input, n_filt]
    # pytorch: [n_filt, n_input, ksize, ksize]
    n_filt = filt.size(0)
    # reshape (not view) so non-contiguous weights, e.g. permuted ones, work.
    filt = filt.reshape(n_filt, -1)

    self_energy = filt_target is None
    if self_energy:
        other = filt
    else:
        n_filt_target = filt_target.size(0)
        other = filt_target.reshape(n_filt_target, -1)
        assert filt.size(1) == other.size(1), str(filt.size()) + " v.s. " + str(other.size())
        if paired:
            assert n_filt == n_filt_target

    # TODO half_mhe not ready (would append the negated filters to `filt`).

    # Cosine similarity matrix; the 1e-4 under the root guards zero filters.
    filt_norm = torch.sqrt(torch.sum(filt * filt, 1, keepdim=True) + 1e-4)
    if self_energy:
        other_norm = filt_norm
    else:
        other_norm = torch.sqrt(torch.sum(other * other, 1, keepdim=True) + 1e-4)
    norm_mat = torch.einsum('nc,mc->nm', filt_norm, other_norm)
    inner_pro = torch.einsum('nc,mc->nm', filt, other)
    del filt_norm, other_norm
    if mem_efficient:
        torch.cuda.empty_cache()
    inner_pro /= norm_mat
    del norm_mat
    if mem_efficient:
        torch.cuda.empty_cache()

    scale = 1.0
    if power in ('0', '1', '2'):
        # Convert similarity to (squared chordal) distance.
        dist = torch.clamp(2.0 - 2.0 * inner_pro, min=1e-4)
        del inner_pro
        if mem_efficient:
            torch.cuda.empty_cache()
        if power == '0':
            final = -torch.log(dist)
        elif power == '1':
            final = dist.pow(-0.5)
        else:
            final = dist.pow(-1)
        del dist
    else:
        # Rounding can push the cosine marginally outside [-1, 1], where acos
        # returns NaN; clamp first so the energy stays finite.
        angle = torch.acos(torch.clamp(inner_pro, -1.0, 1.0)) / math.pi
        angle += 1e-4
        del inner_pro
        if mem_efficient:
            torch.cuda.empty_cache()
        if power == 'a0':
            final = -torch.log(angle)
        elif power == 'a1':
            final = angle.pow(-1)
            scale = 1e-1
        else:
            final = angle.pow(-2)
            scale = 1e-1
        del angle

    if self_energy:
        # The matrix is symmetric: keep each unordered pair once and drop the
        # self-pairs on the diagonal. Masking with tril (instead of adding an
        # identity matrix or subtracting a triangle) needs no extra tensor on
        # a hard-coded device and cannot propagate inf/NaN from the diagonal.
        final = torch.tril(final, diagonal=-1)
        cnt = n_filt * (n_filt - 1) / 2.0
    elif paired:
        final = torch.diagonal(final)
        cnt = n_filt
    else:
        cnt = n_filt * n_filt_target

    # No pairs (e.g. a single filter against itself): report 0 rather than 0/0.
    loss = scale * final.sum() / max(cnt, 1)
    del final
    if mem_efficient:
        torch.cuda.empty_cache()
    return loss
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment