-
-
Save gglin001/9c35d2bbd8e5d995584855415980f99a to your computer and use it in GitHub Desktop.
Code for the Blog at:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from time import perf_counter, time

import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage.util.shape import view_as_windows
from torch import nn
from tqdm import tqdm
# Fix the PyTorch RNG seed so the random inputs and kernels drawn in the
# benchmark below are reproducible from run to run.
torch.manual_seed(42)
def conv_2d(kernel, bias, x):
    """Slide ``kernel`` over ``x`` and return the 'valid' 2-D result plus bias.

    No padding is applied and the stride is 1. The kernel is not flipped:
    each output element is the sum of the element-wise product of the
    kernel and the input window under it.

    Args:
        kernel: 2-D array of shape (k_rows, k_cols).
        bias: Scalar (or broadcastable array) added to every output element.
        x: 2-D input array of shape (rows, cols).

    Returns:
        Array of shape (rows - k_rows + 1, cols - k_cols + 1).
    """
    kernel_rows, kernel_cols = kernel.shape
    # Output extent is derived per axis so that any kernel size and
    # non-square inputs are handled, not only 2x2 kernels on square inputs.
    out_rows = x.shape[0] - kernel_rows + 1
    out_cols = x.shape[1] - kernel_cols + 1
    result = np.zeros((out_rows, out_cols))
    for row in range(out_rows):
        for col in range(out_cols):
            window = x[row:row + kernel_rows, col:col + kernel_cols]
            result[row, col] = np.sum(np.multiply(kernel, window))
    return result + bias
def memory_strided_im2col(x, kernel):
    """Build the im2col matrix of ``x`` from a strided window view.

    Assumes padding = 0 and stride = 1.

    Args:
        x: 2-D input array.
        kernel: 2-D kernel array; only its shape is used.

    Returns:
        Array of shape (k_rows * k_cols, n_windows) whose columns are the
        flattened kernel-sized windows of ``x`` in row-major window order.
    """
    # view_as_windows yields shape (out_rows, out_cols, k_rows, k_cols).
    windows = view_as_windows(x, kernel.shape)
    out_rows, out_cols = windows.shape[:2]
    patch_size = kernel.shape[0] * kernel.shape[1]
    # Flatten to one window per row first, then transpose. Reshaping the
    # 4-D view directly to (patch_size, n_windows) would mix elements of
    # different windows into the same column.
    return windows.reshape(out_rows * out_cols, patch_size).T
def naive_im2col(x, kernel):
    """Build the im2col matrix of ``x`` by explicitly slicing every window.

    Args:
        x: 2-D input array.
        kernel: 2-D kernel array; only its shape is used.

    Returns:
        Array of shape (k_rows * k_cols, n_windows) whose columns are the
        flattened kernel-sized windows of ``x`` in row-major window order.
    """
    kernel_rows, kernel_cols = kernel.shape
    # Assuming Padding = 0, stride = 1
    out_rows = x.shape[0] - kernel_rows + 1
    out_cols = x.shape[1] - kernel_cols + 1
    rows = [
        x[row:row + kernel_rows, col:col + kernel_cols].flatten()
        for row in range(out_rows)
        for col in range(out_cols)
    ]
    # The explicit reshape keeps the result 2-D even when there are no windows.
    return np.transpose(np.array(rows).reshape(-1, kernel_rows * kernel_cols))
if __name__ == "__main__":
    # Benchmark four ways of applying one 2x2 kernel to a square input and
    # plot the mean execution time of each against the input size.

    # Mean time per input size, one list per implementation.
    naive_time_log = []
    torch_time_log = []
    strided_time_log = []
    naive_im2col_time_log = []

    MAX_INPUT_SIZE = 300
    NUM_RUNS = 20
    input_size_list = list(range(3, MAX_INPUT_SIZE, 5))

    for input_size in tqdm(input_size_list):
        # Per-run timings for the current input size.
        torch_time = np.zeros(NUM_RUNS)
        naive_time = np.zeros(NUM_RUNS)
        im2col_time = np.zeros(NUM_RUNS)
        strided_time = np.zeros(NUM_RUNS)

        for run in range(NUM_RUNS):
            # Fresh layer, input and integer-valued weights for every run.
            conv = nn.Conv2d(1, 1, 2)
            ip = torch.randint(low=0, high=10, size=(
                1, 1, input_size, input_size), dtype=torch.float32)
            conv.weight = nn.Parameter(torch.randint(
                low=0, high=10, size=(1, 1, 2, 2), dtype=torch.float32))

            # NumPy views of the same data for the hand-written variants.
            ip_np = ip.numpy().reshape(-1, ip.shape[-1])
            kernel_np = conv.weight.detach().squeeze().numpy()
            bias_np = conv.bias.detach().squeeze().numpy()

            # perf_counter is monotonic and high-resolution, so short runs
            # are not rounded to zero and clock adjustments cannot skew them.
            start = perf_counter()
            naive_conv = conv_2d(kernel_np, bias_np, ip_np)
            naive_time[run] = perf_counter() - start

            start = perf_counter()
            np.dot(kernel_np.flatten(), naive_im2col(ip_np, kernel_np)) + bias_np
            im2col_time[run] = perf_counter() - start

            start = perf_counter()
            torch_conv = conv(ip)
            torch_time[run] = perf_counter() - start

            start = perf_counter()
            np.dot(kernel_np.flatten(), memory_strided_im2col(ip_np, kernel_np)) + bias_np
            strided_time[run] = perf_counter() - start

        naive_time_log.append(naive_time.mean())
        torch_time_log.append(torch_time.mean())
        strided_time_log.append(strided_time.mean())
        naive_im2col_time_log.append(im2col_time.mean())

    plt.plot(input_size_list, naive_time_log,
             label='Naive Conv 2D', color='red')
    plt.plot(input_size_list, torch_time_log,
             label='PyTorch Conv 2D', color='blue')
    plt.plot(input_size_list, naive_im2col_time_log,
             label='Im2Col Conv 2D', color='green')
    plt.plot(input_size_list, strided_time_log,
             label='Mem Strided Im2Col Conv 2D', color='purple')
    plt.xlabel('Size of Input (n x n)')
    plt.ylabel('Execution Time (secs) - Log Scale')
    plt.yscale("log")
    plt.legend()
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment