This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import matplotlib.pyplot as plt | |
from torch import nn, Tensor | |
from sklearn.datasets import make_moons | |
# Select the compute device once at import time: use CUDA when PyTorch reports
# an available GPU, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class DiscreteFlow(nn.Module): | |
def __init__(self, dim: int = 2, h: int = 128, v: int = 128): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2023 Meta Platforms, Inc. and affiliates
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by Xinyou Wang on Jul 21, 2024
#
# Original file was released under MIT, with the full license text
# available at https://github.com/facebookresearch/esm/blob/main/LICENSE
#
# This modified file is released under the same license.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def renum_pdb_str(pdb_str, Ls=None, renum=True, offset=1): | |
from string import ascii_uppercase, ascii_lowercase | |
assert len(Ls) == 2 | |
Ls = [int(i) for i in Ls] | |
alphabet_list = list(ascii_uppercase+ascii_lowercase) | |
if Ls is not None: | |
L_init = 0 | |
new_chain = {} | |
for L,c in zip(Ls, alphabet_list): | |
new_chain.update({i:c for i in range(L_init,L_init+L)}) |