This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This gist demonstrates the equivalence between the existing CLIP `AttentionPool2d` | |
and the proposed `AttentionPool2dFix`, which only computes attention where needed. | |
""" | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Base: zero, one | |
function uint_type(t) | |
s = sizeof(t) | |
if s == 2 | |
Int16 | |
elseif s == 4 | |
Int32 | |
elseif s == 8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def crop_center(image, crop_size):
    """Crop the central region of ``image`` with shape ``crop_size``.

    Args:
        image: Array indexable with 2-D slicing along its first two axes
            (e.g. a NumPy array of shape (H, W) or (H, W, C)); any trailing
            axes are kept unchanged.
        crop_size: Pair ``(crop_h, crop_w)`` giving the crop extent along the
            first two axes.

    Returns:
        The central ``crop_size`` slice of ``image`` (a view for NumPy
        arrays). If a crop dimension exceeds the corresponding image
        dimension, that axis is returned in full.
    """
    shape = image.shape
    # Clamp at 0: a negative start would be read by Python slicing as an
    # offset from the end of the axis and silently yield a wrong crop.
    start0 = max(0, shape[0] // 2 - crop_size[0] // 2)
    start1 = max(0, shape[1] // 2 - crop_size[1] // 2)
    return image[start0:start0 + crop_size[0], start1:start1 + crop_size[1]]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn.feature_extraction.image import extract_patches | |
def conv2d(inputs, filters): | |
""" | |
Args: | |
inputs (np.ndarray): NHWC | |
filters (np.ndarray): | |
with shape [filter_height, filter_width, in_channels, out_channels] | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
def shuffle(iterator, buffer_size): | |
""" Uses a buffer to randomly shuffle items from an iterator """ | |
# Fill the buffer | |
buffer = []; cnt = 0 | |
item = next(iterator, None) | |
while not item is None and cnt < buffer_size: | |
buffer.append(item) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numpy import linalg as la | |
""" | |
https://www.ltu.se/cms_fs/1.51590!/svd-fitting.pdf | |
""" | |
def get_rigid_transform(X, Y): | |
""" | |
Calculates the rigid transform (translation, rotation) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class Node(object): | |
""" | |
Base class for nodes in the network. | |
Arguments: | |
`inbound_nodes`: A list of nodes with edges into this node. | |
""" | |
def __init__(self, inbound_nodes=[]): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def euclidean(A, B):
    """Return pairwise Euclidean distances between rows of ``A`` and ``B``.

    Assumes ``A`` has shape (n, d) and ``B`` has shape (m, d) — TODO confirm
    with callers. Under that assumption the result has shape (n, m), where
    entry ``[i, j]`` is the distance between ``A[i]`` and ``B[j]``.
    """
    # Append a trailing axis to A: (n, d) -> (n, d, 1). Subtracting B.T,
    # shape (d, m), then broadcasts to (n, d, m).
    deltas = np.expand_dims(A, axis=-1) - B.T
    squared = deltas ** 2
    # Reduce over the feature axis d, leaving an (n, m) distance matrix.
    return np.sqrt(squared.sum(axis=1))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.stats as scs | |
import matplotlib.pyplot as plt | |
def plot(ax, x, y, label):
    """Draw ``y`` against ``x`` on ``ax`` and shade the area down to zero.

    The shaded region reuses the colour of the plotted line, at 20% opacity.
    Returns None.
    """
    first_line = ax.plot(x, y, label=label, lw=2)[0]
    shade_colour = first_line.get_c()
    ax.fill_between(x, 0, y, alpha=0.2, color=shade_colour)
if __name__ == '__main__': | |
# Simulated data for variations a and b |