Skip to content

Instantly share code, notes, and snippets.

@vhxs
Created April 30, 2023 15:54
Show Gist options
  • Save vhxs/e62450185daf5c8ea5be1178b4c7dc65 to your computer and use it in GitHub Desktop.
Save vhxs/e62450185daf5c8ea5be1178b4c7dc65 to your computer and use it in GitHub Desktop.
Visualize a convolution operator's matrix representation
import torch
import matplotlib.pyplot as plt
def generate_basis(num_channels, height, width):
    """Lazily yield every standard basis tensor of shape (num_channels, height, width).

    Each yielded tensor is an independent copy of a zero tensor (default dtype
    and device) with exactly one entry set to 1. Positions are visited in
    row-major (C) order — channel slowest, then row, then column — so the
    n-th tensor yielded has its 1 at flat index n of the flattened tensor.
    """
    template = torch.zeros((num_channels, height, width))
    for flat_index in range(template.numel()):
        unit = template.clone()
        # A freshly cloned zeros tensor is contiguous, so view(-1) aliases its
        # storage and writing through it sets the entry in `unit` itself.
        unit.view(-1)[flat_index] = 1
        yield unit
def compute_matrix_from_filters(
    filters: torch.Tensor, height: int, width: int, padding="same"
) -> torch.Tensor:
    """Return the dense matrix of the linear map ``x -> conv2d(x, filters)``.

    Every standard basis image of shape ``(filters.shape[1], height, width)`` is
    pushed through a single batched ``conv2d`` call (rather than one call per
    basis element). Column ``n`` of the result is the flattened output for the
    basis image whose single 1 sits at flat, row-major position ``n``, so
    ``matrix @ x.flatten()`` equals
    ``conv2d(x.unsqueeze(0), filters, padding=padding).flatten()``.

    Args:
        filters: Convolution weight in the layout ``conv2d`` expects,
            ``(out_channels, in_channels, kH, kW)``. Its dtype and device are
            used for the basis images, so e.g. float64 or non-CPU filters work.
        height: Height of the input images.
        width: Width of the input images.
        padding: Forwarded to ``torch.nn.functional.conv2d``. Defaults to
            ``"same"``, the value this function always used previously.

    Returns:
        A contiguous tensor of shape
        ``(out_channels * H_out * W_out, in_channels * height * width)``;
        with ``padding="same"``, ``H_out == height`` and ``W_out == width``.
    """
    in_channels = filters.shape[1]
    num_inputs = in_channels * height * width
    # Row n of the identity, reshaped, is the n-th basis image in row-major
    # order. Matching the filters' dtype/device avoids conv2d type mismatches.
    basis = torch.eye(num_inputs, dtype=filters.dtype, device=filters.device)
    basis = basis.reshape(num_inputs, in_channels, height, width)
    images = torch.nn.functional.conv2d(basis, filters, padding=padding)
    # images[n] is the image of basis element n. Flatten each image into a row,
    # then transpose so the images become the matrix's columns.
    return images.flatten(start_dim=1).T.contiguous()
if __name__ == "__main__":
    # Demo: a 4-input/4-output-channel convolution with an all-ones 2x2 kernel
    # acting on 4x4 images, rendered as a heat map of its matrix.
    in_channels, out_channels = 4, 4
    image_height = image_width = 4
    kernel_side = 2
    weight = torch.ones((out_channels, in_channels, kernel_side, kernel_side))
    conv_matrix = compute_matrix_from_filters(weight, image_height, image_width)
    plt.imshow(conv_matrix, cmap="hot")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment