Created
April 30, 2023 15:54
-
-
Save vhxs/e62450185daf5c8ea5be1178b4c7dc65 to your computer and use it in GitHub Desktop.
Visualize a convolution operator's matrix representation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import matplotlib.pyplot as plt | |
def generate_basis(num_channels: int, height: int, width: int):
    """Yield the standard basis tensors of shape (num_channels, height, width).

    Each yielded tensor is all zeros except for a single 1. Elements are
    produced in row-major order: the channel index varies slowest and the
    width index varies fastest, so the n-th element has its 1 at flat
    position n.

    Args:
        num_channels: Size of the first (channel) dimension.
        height: Size of the second dimension.
        width: Size of the third dimension.

    Yields:
        A freshly allocated tensor (default dtype) per basis element, so
        callers may mutate what they receive without affecting later elements.
    """
    shape = (num_channels, height, width)
    for i in range(num_channels):
        for j in range(height):
            for k in range(width):
                basis_element = torch.zeros(shape)
                basis_element[i, j, k] = 1
                yield basis_element
def compute_matrix_from_filters(filters: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Return the matrix of the linear map ``x -> conv2d(x, filters, padding="same")``.

    The map acts on inputs of shape (in_channels, height, width). Column n of
    the result is the flattened output produced by the n-th standard basis
    tensor of the input space (row-major order over channel, height, width),
    so ``matrix @ x.flatten()`` equals the flattened convolution of ``x``.

    Args:
        filters: Convolution weights of shape
            (out_channels, in_channels, kernel_height, kernel_width).
        height: Height of the input images.
        width: Width of the input images.

    Returns:
        A tensor of shape
        (out_channels * height * width, in_channels * height * width)
        with the same dtype and device as ``filters``.

    Raises:
        ValueError: If ``filters`` is not 4-dimensional.
    """
    if filters.dim() != 4:
        raise ValueError(
            f"filters must be 4-dimensional (out, in, kh, kw), got {filters.dim()} dimensions"
        )

    num_input_channels = filters.shape[1]
    input_dim = num_input_channels * height * width

    # Row n of the identity matrix, reshaped to (C, H, W), is exactly the n-th
    # standard basis tensor. Treating the rows as a batch lets a single conv2d
    # call replace one call per basis element. Matching the filters' dtype and
    # device avoids a mismatch error for e.g. float64 or GPU filters.
    basis_batch = torch.eye(
        input_dim, dtype=filters.dtype, device=filters.device
    ).reshape(input_dim, num_input_channels, height, width)

    images = torch.nn.functional.conv2d(basis_batch, filters, padding="same")

    # images[n] is the image of basis element n; flattened images become the
    # columns of the matrix, hence the transpose.
    return torch.flatten(images, start_dim=1).T.contiguous()
if __name__ == "__main__":
    # Demo: build the matrix of a small all-ones convolution and show it as a
    # heat map.
    num_input_channels = 4
    num_output_channels = 4
    height = 4
    width = 4
    ker_size = 2

    # Weights are laid out as (out_channels, in_channels, kernel_h, kernel_w).
    filters = torch.ones((num_output_channels, num_input_channels, ker_size, ker_size))
    matrix = compute_matrix_from_filters(filters, height, width)

    plt.imshow(matrix, cmap="hot")
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment