Created
October 1, 2020 07:36
-
-
Save bougui505/3079d55110a68e1a319ab26fe64f94fd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: UTF8 -*- | |
# Author: Guillaume Bouvier -- [email protected] | |
# https://research.pasteur.fr/en/member/guillaume-bouvier/ | |
# 2020-09-29 15:48:44 (UTC+0200) | |
import pymol.cmd as cmd | |
import torch | |
import sys | |
def print_progress(instr):
    """Overwrite the current terminal line with ``instr``.

    The message ends with a carriage return (no newline), so the next call
    overwrites it in place. stdout is flushed so the update shows up
    immediately.
    """
    stream = sys.stdout
    stream.write(f'{instr}\r')
    stream.flush()
def get_cmap(coords, device, threshold=8.):
    """Return a soft, differentiable contact map for a set of points.

    Entry (i, j) is ``sigmoid(threshold - d_ij)``, where ``d_ij`` is the
    Euclidean distance between points i and j. Pairs closer than
    ``threshold`` score above 0.5 and pairs farther apart score below 0.5.

    Args:
        coords: point coordinates accepted by ``torch.cdist``. In this
            script they are the CA coordinates returned by ``get_coords``.
        device: device the returned map is moved to.
        threshold: distance at which the contact score equals 0.5.

    Returns:
        The pairwise contact-score tensor, on ``device``.
    """
    distances = torch.cdist(coords, coords)
    return torch.sigmoid(threshold - distances).to(device)
def get_coords(pdbfilename, object, device, selection=None):
    """Load a structure into PyMOL and return coordinates as a torch tensor.

    Side effects: loads ``pdbfilename`` into the PyMOL session as
    ``object``, then removes every atom of ``object`` whose name is not CA.
    Later PyMOL calls on ``object`` therefore see a CA-only model.

    Args:
        pdbfilename: path of the structure file given to ``cmd.load``.
        object: PyMOL object name to load into. It shadows the builtin but
            is kept because callers pass it by keyword.
        device: torch device for the returned tensor.
        selection: PyMOL selection whose coordinates are returned. Defaults
            to ``'<object> and name CA'``.

    Returns:
        ``cmd.get_coords`` output converted with ``torch.from_numpy`` and
        moved to ``device``. This is presumably an (N, 3) array of atom
        positions; confirm against the PyMOL docs.
    """
    sel = f'{object} and name CA' if selection is None else selection
    cmd.load(pdbfilename, object=object)
    cmd.remove(f'(not name CA) and {object}')
    xyz = cmd.get_coords(selection=sel)
    return torch.from_numpy(xyz).to(device)
def permute(coords, weights):
    """Mix the rows of ``coords`` with a (soft) permutation matrix.

    The result equals ``weights.T @ coords``: output row j is the
    combination of input rows weighted by column j of ``weights``. With a
    hard 0/1 permutation matrix this simply reorders the points. A softmax
    over ``weights`` was once considered here; the raw matrix is used as is.
    """
    mixed = coords.t().mm(weights)
    return mixed.t()
def build_rotation_matrix(alpha_beta_gamma, device):
    """Build the 3x3 rotation ``Rz(gamma) @ Ry(beta) @ Rx(alpha)``.

    Args:
        alpha_beta_gamma: the three angles (radians), as a length-3 tensor
            or any sequence of three numbers / 0-d tensors.
        device: device on which the matrix is built.

    Returns:
        A (3, 3) tensor that is always part of the autograd graph, so
        gradients reach ``alpha_beta_gamma`` when it requires grad.
    """
    alpha, beta, gamma = alpha_beta_gamma
    zero = torch.zeros(1, device=device)
    one = torch.ones(1, device=device)

    def as_angle(value):
        # Multiplying by a grad-enabled ones tensor turns a float or 0-d
        # tensor into a (1,) tensor that participates in autograd.
        return torch.ones(1, requires_grad=True, device=device) * value

    def assemble(rows):
        # rows is a 3x3 nested list of (1,) tensors.
        return torch.stack([torch.stack(row) for row in rows]).reshape(3, 3)

    a, b, g = as_angle(alpha), as_angle(beta), as_angle(gamma)
    cos_a, sin_a = torch.cos(a), torch.sin(a)
    cos_b, sin_b = torch.cos(b), torch.sin(b)
    cos_g, sin_g = torch.cos(g), torch.sin(g)

    rot_x = assemble([[one, zero, zero],
                      [zero, cos_a, -sin_a],
                      [zero, sin_a, cos_a]])
    rot_y = assemble([[cos_b, zero, sin_b],
                      [zero, one, zero],
                      [-sin_b, zero, cos_b]])
    rot_z = assemble([[cos_g, -sin_g, zero],
                      [sin_g, cos_g, zero],
                      [zero, zero, one]])
    return rot_z.mm(rot_y).mm(rot_x)
def build_reflection_matrix(abc, device):
    """Return the 3x3 Householder reflection across the plane with normal ``abc``.

    The matrix is ``I - 2 n n^T`` with ``n = abc / ||abc||``. This is the
    reflection formula from
    https://en.wikipedia.org/w/index.php?title=Transformation_matrix&oldid=976277111#Reflection,
    which gives a true reflection (orthogonal, determinant -1) only for a
    unit normal, hence the normalization. A zero vector yields the identity,
    because ``torch.nn.functional.normalize`` maps it to zero.

    The matrix is built from differentiable tensor ops. Re-wrapping the
    entries with ``torch.tensor(...)`` would detach them, so that is
    avoided, and gradients flow back to ``abc`` when it requires grad.

    Args:
        abc: the three components (a, b, c) of the plane normal, as a
            length-3 tensor or any sequence of three numbers / 0-d tensors.
        device: device for the returned matrix.

    Returns:
        A (3, 3) tensor of the default floating dtype.

    Raises:
        ValueError: if ``abc`` does not have exactly three components.
    """
    a, b, c = abc
    dtype = torch.get_default_dtype()
    normal = torch.stack([torch.as_tensor(x, dtype=dtype, device=device)
                          for x in (a, b, c)]).reshape(3)
    normal = torch.nn.functional.normalize(normal, dim=0)
    identity = torch.eye(3, dtype=dtype, device=normal.device)
    # Outer product n n^T via broadcasting (works on old torch without torch.outer).
    return identity - 2. * normal.unsqueeze(1) * normal.unsqueeze(0)
def transform(coords, T, device, center=False):
    """Right-multiply row-vector coordinates by the matrix ``T``.

    Args:
        coords: 2-D tensor with one point per row. In this script it is
            (N, 3).
        T: transformation matrix. The result is ``coords @ T``, and ``T``
            is 3x3 for the callers in this file.
        device: unused. It is kept so existing call sites keep working.
        center: if True, subtract the per-axis centroid
            (``coords.mean(axis=0)``) before applying ``T``. The map then
            acts about the centroid and the result is centered at the
            origin. Defaults to False, which applies ``T`` to the raw
            coordinates.

    Returns:
        The transformed tensor, with the same shape as ``coords`` for a
        square ``T``.
    """
    if center:
        # Per-axis centroid: a bare .mean() would average x, y and z together.
        coords = coords - coords.mean(axis=0)
    return coords.mm(T)
def minsum(v, axis=1, n=2., eps=1e-6):
    """Smooth approximation of the minimum of ``v`` along ``axis``.

    Computes ``(sum_i (v_i + eps) ** -n) ** (-1 / n)``. The smallest entries
    dominate the sum of inverse powers, so for positive inputs the result
    lies at or just below ``min(v) + eps``. Larger ``n`` makes it sharper.
    ``eps`` guards against division by zero.
    """
    shifted = v + eps
    inverse_powers = 1 / shifted ** n
    total = inverse_powers.sum(axis=axis)
    return (1 / total) ** (1. / n)
def anchor_loss(coords, anchors):
    """Mean squared distance from each point to its nearest anchor.

    Both point sets are first translated to their own centroid, so the loss
    does not change under a global translation of either set. A smooth
    alternative to the hard minimum would be ``minsum``.
    """
    centered_coords = coords - coords.mean(axis=0)
    centered_anchors = anchors - anchors.mean(axis=0)
    distances = torch.cdist(centered_coords, centered_anchors)
    nearest = distances.min(dim=1).values
    return nearest.pow(2).mean()
def cmap_loss(cmap_pred, cmap_true, w0=0.05):
    """Weighted binary cross-entropy between two contact maps.

    Both maps are flattened. Each element's BCE term is weighted by
    ``cmap_true + w0``: entries with a higher target contact value get more
    weight, and ``w0`` is a floor that keeps non-contacts from being
    ignored. The result is the plain mean of the weighted terms.
    """
    target = cmap_true.flatten()
    weights = target + w0 * torch.ones_like(target)
    return torch.nn.functional.binary_cross_entropy(
        cmap_pred.flatten(), target, weight=weights)
def align_structures(coords, coords_ref, device, n_iter):
    """Fit a matrix ``T = A @ R`` that superposes ``coords`` onto ``coords_ref``.

    Three Euler angles (for ``R``, via ``build_rotation_matrix``) and three
    plane-normal components (for ``A``, via ``build_reflection_matrix``)
    are drawn at random. They are then optimized with Adam (lr=1e-3) for
    ``n_iter`` steps to minimize ``anchor_loss(transform(coords, T), coords_ref)``.
    The loss is printed every 100 steps.

    Args:
        coords: points to move, passed to ``transform`` (row vectors).
        coords_ref: reference points used as anchors.
        device: device for the optimized parameters.
        n_iter: number of optimization steps (0 is allowed).

    Returns:
        ``transform(coords, T)`` with ``T`` rebuilt from the final
        parameters. This is the same composite matrix the loss was
        minimized for, so the reflection is not dropped. It is computed
        under ``torch.no_grad()``.
    """
    alpha_beta_gamma = torch.randn(3, requires_grad=True, device=device)
    abc = torch.randn(3, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([alpha_beta_gamma, abc], lr=1e-3)

    def build_transform():
        # Composite matrix from the current parameter values.
        R = build_rotation_matrix(alpha_beta_gamma, device)
        A = build_reflection_matrix(abc, device)
        return A.mm(R)

    for t in range(n_iter):
        optimizer.zero_grad()
        coords_out = transform(coords, build_transform(), device)
        loss = anchor_loss(coords_out, coords_ref)
        loss.backward()
        optimizer.step()
        if t % 100 == 99:
            print_progress(f'{t+1}/{n_iter}: {loss}')
    sys.stdout.write('\n')
    # Rebuild from the post-step parameters. This also works when n_iter == 0.
    with torch.no_grad():
        return transform(coords, build_transform(), device)
def _solve_lstsq(A, B):
    """Return the least-squares solution X of ``A @ X ≈ B``.

    Uses ``torch.linalg.lstsq`` when available (torch >= 1.9). Otherwise it
    falls back to the legacy ``torch.lstsq``, which takes its arguments in
    swapped order and stores the solution in the first ``A.shape[1]`` rows
    of its output.
    """
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'lstsq'):
        return linalg.lstsq(A, B).solution
    X, _ = torch.lstsq(B, A)
    return X[:A.shape[1]]


def icp(coords, coords_ref, device, n_iter):
    """
    Iterative Closest Point.

    Each iteration:
      1. pairs every point of ``coords`` with its nearest point of
         ``coords_ref``, comparing the two clouds after centering each on
         its own centroid;
      2. solves ``coords @ X ≈ coords_ref[matches]`` in the least-squares
         sense for a general 3x3 linear map ``X`` (no translation term, not
         constrained to a rotation) and applies it.

    The iteration count and the RMS residual of the fit are printed on one
    line, followed by a newline at the end.

    NOTE(review): matching uses centered clouds, but the fit has no
    translation term and runs on the raw coordinates. Confirm this is
    intended when the two centroids differ.

    Args:
        coords: (N, 3) points to move.
        coords_ref: (M, 3) reference points.
        device: unused; kept for interface compatibility.
        n_iter: number of ICP iterations.

    Returns:
        ``coords`` after ``n_iter`` successive linear maps.
    """
    for t in range(n_iter):
        cdist = torch.cdist(coords - coords.mean(axis=0),
                            coords_ref - coords_ref.mean(axis=0))
        _, argmins = torch.min(cdist, axis=1)
        target = coords_ref[argmins]
        X = _solve_lstsq(coords, target)
        coords = coords.mm(X)
        # RMS distance between the fitted points and their matched targets.
        rmsd = torch.sqrt(((coords - target) ** 2).sum(axis=1).mean())
        print_progress(f'{t+1}/{n_iter}: {rmsd}')
    sys.stdout.write('\n')
    return coords
def minimize(coords, cmap_ref, device, n_iter):
    """Learn a soft permutation of ``coords`` whose contact map matches ``cmap_ref``.

    A square mixing matrix ``P`` is initialized to the identity and left
    unconstrained. It is optimized with Adam (lr=1e-3) for ``n_iter`` steps
    to minimize ``cmap_loss(get_cmap(permute(coords, P)), cmap_ref)``, and
    the loss is printed every 100 steps.

    Returns:
        ``permute(coords, P)`` for the final ``P``. The result is still
        attached to the autograd graph.
    """
    size = coords.shape[0]
    mixing = torch.eye(size, requires_grad=True, device=device)
    optimizer = torch.optim.Adam([mixing], lr=1e-3)
    for step in range(1, n_iter + 1):
        optimizer.zero_grad()
        predicted_cmap = get_cmap(permute(coords, mixing), device=device)
        loss = cmap_loss(predicted_cmap, cmap_ref)
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            print_progress(f'{step}/{n_iter}: {loss}')
    sys.stdout.write('\n')
    return permute(coords, mixing)
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Reference structure: its soft contact map is the optimization target.
    coords_ref = get_coords('5v6p_.pdb', 'ref', device=device)
    cmap_ref = get_cmap(coords_ref, device=device)
    # Model whose CA atoms are re-mixed to reproduce the reference contacts.
    coords_in = get_coords('map_to_model_5v6p_8637_.pdb', 'mod', device)
    cmap_in = get_cmap(coords_in, device='cpu')
    coords_out = minimize(coords_in, cmap_ref, device, 10000)
    cmap_out = get_cmap(coords_out, device='cpu').detach().numpy()
    coords_out = coords_out.cpu().detach().numpy()
    # get_coords left only CA atoms in 'mod', matching the rows of coords_out.
    cmd.load_coords(coords_out, 'mod')
    cmd.save('out.pdb', selection='mod')
    for image, filename in ((cmap_in.cpu().numpy(), 'cmap_in.png'),
                            (cmap_ref.cpu().numpy(), 'cmap_ref.png'),
                            (cmap_out, 'cmap_out.png')):
        plt.matshow(image)
        plt.savefig(filename)
        # matshow opens a new figure each time; close it so figures don't pile up.
        plt.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment