Created June 30, 2022 17:23
Save TrentBrick/9b733da9b1c2d8cfa0bb67921085a335 to your computer and use it in GitHub Desktop.
Potts Model Closed Form Expectation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# for use with a batch of sequences
def energy_torch(self, inp):
    """
    Calculate in pytorch the closed-form Potts-model (Hamiltonian) energy of a batch.

    Each row of the (decoded) input is treated as a flattened per-position
    distribution ``p`` over amino acids. The energy is then evaluated on that
    distribution in a vectorized way:

        energy = p @ h + sum(p * (p @ J), dim=-1) / 2

    Parameters
    ----------
    inp : torch.Tensor
        2-D tensor of shape batch_size x D.
        If ``self.is_discrete`` is False, D = self.L * properties. The input
        is reshaped to batch_size x self.L x properties, passed through
        ``self.decode``, and flattened back to batch_size x -1 before scoring.
        If ``self.is_discrete`` is True, ``inp`` is scored as-is. Per the
        original contract it should already be the flattened, softmaxed
        sequences of shape batch_size x (protein_length x 20).

    Returns
    -------
    torch.Tensor
        Tensor of shape batch_size x 1: field term plus halved coupling term
        for each sequence.

    Raises
    ------
    AssertionError
        If ``self.is_discrete`` is False and ``inp`` is not 2-D.

    Notes
    -----
    NOTE(review): the coupling term equals the expected pairwise energy under
    independent per-site distributions only if ``self.J_torch`` is symmetric
    with zero same-position (diagonal) blocks. Confirm how J is built.
    """
    if not self.is_discrete:
        # Validate before reading shape[0], so a 0-d tensor hits the intended
        # check instead of an IndexError.
        assert len(inp.shape) == 2, 'wrong shape!'
        batch_size = inp.shape[0]
        # Decoder assumes a 3D tensor. reshape (not view) so non-contiguous
        # inputs also work; it returns a view whenever possible.
        inp = inp.reshape((batch_size, self.L, -1))
        # Convert to a distribution over the AAs, then flatten back so it can
        # be plugged into the score.
        # NOTE(review): an earlier comment said decode returns log pdfs of the
        # AAs; the closed-form expectation needs probabilities. Confirm.
        inp = self.decode(inp).reshape((batch_size, -1))

    # Field term, one value per sequence. Reshaping to (batch, -1) makes a 1-D
    # h_torch yield (batch, 1) instead of (batch,). The latter would broadcast
    # against j_val into a (batch, batch) matrix. A column-shaped h_torch is
    # unaffected.
    h_val = torch.matmul(inp, self.h_torch).reshape(inp.shape[0], -1)
    # Coupling term: p^T J p per row, halved (presumably because each unordered
    # pair appears twice in a symmetric J), kept as a (batch, 1) column.
    j_val = torch.unsqueeze(
        torch.sum(inp * torch.matmul(inp, self.J_torch), dim=-1) / 2, 1
    )
    return j_val + h_val
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.