This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as fnn | |
import torch.optim as onn | |
class Transpose(nn.Module):
    """Parameter-free ``nn.Module`` named ``Transpose``.

    Only the constructor is visible in this chunk. It performs the standard
    ``nn.Module`` initialisation and registers no parameters, buffers, or
    submodules of its own.

    NOTE(review): no ``forward`` method is defined here, so calling an
    instance would hit ``nn.Module``'s unimplemented-forward error. The
    forward pass was presumably lost when this source was captured. Confirm
    against the original file before relying on this class.
    """

    def __init__(self):
        # Zero-argument super() is equivalent to super(Transpose, self) on
        # Python 3 and keeps working if the class is ever renamed.
        super().__init__()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def distance_permutation(array, d): | |
""" | |
This function returns a staggered permutation of a sequence an the corresponding | |
permuted indices, such that no item is allowed to move more than d steps from | |
it's original position. | |
""" | |
rng = range(len(array)) | |
for i in rng: | |
swap = randint(0, d) | |
if i - swap < rng[i] < i + swap: |