Code to accompany my blog post at https://bit.ly/2KfmQ76
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"><li><span><a href=\"#Introduction\" data-toc-modified-id=\"Introduction-1\"><span class=\"toc-item-num\">1 </span>Introduction</a></span></li><li><span><a href=\"#Libraries\" data-toc-modified-id=\"Libraries-2\"><span class=\"toc-item-num\">2 </span>Libraries</a></span></li><li><span><a href=\"#Explanation\" data-toc-modified-id=\"Explanation-3\"><span class=\"toc-item-num\">3 </span>Explanation</a></span><ul class=\"toc-item\"><li><span><a href=\"#Example:-Single-channel-image-and-convolution-has-only-1-output-channel\" data-toc-modified-id=\"Example:-Single-channel-image-and-convolution-has-only-1-output-channel-3.1\"><span class=\"toc-item-num\">3.1 </span>Example: Single channel image and convolution has only 1 output channel</a></span></li><li><span><a href=\"#Bigger-Example\" data-toc-modified-id=\"Bigger-Example-3.2\"><span class=\"toc-item-num\">3.2 </span>Bigger Example</a></span></li></ul></li></ul></div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"This notebook describes how to express a 2D Convolution in terms of matrix multiplication:" | |
   ]
  },
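  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea (often called *im2col*) is to unfold the image into a matrix $X_u$ whose columns are the flattened kernel-sized patches, flatten the kernel weights into the rows of a matrix $W_u$, and then multiply and reshape:\n",
    "\n",
    "$$\\text{Conv2d}(X; W) = \\text{reshape}(W_u X_u)$$"
   ]
  },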
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Libraries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"We only need pytorch:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.0'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Explanation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: Single channel image and convolution has only 1 output channel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let $X$ be a $4 \\times 4$ single channel input image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 1.,  2.,  3.,  4.],\n",
       "          [ 5.,  6.,  7.,  8.],\n",
       "          [ 9., 10., 11., 12.],\n",
       "          [13., 14., 15., 16.]]]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.arange(1, 17).view(-1, 1, 4, 4).float()\n",
    "X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define a 2D convolution with the following properties:\n",
    "\n",
    "* kernel size: $2 \\times 2$\n",
    "* padding: 0\n",
    "* stride: 1\n",
    "* bias: 0\n",
    "* output channels: 1\n",
    "* initial weights, $W$: $\\begin{bmatrix}\n",
    "    1 & 2 \\\\\n",
    "    3 & 4 \\\\ \n",
    "\\end{bmatrix}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, stride=1)\n",
    "W = torch.arange(1, 5).view(-1, 1, 2, 2).float()\n",
    "\n",
    "conv.weight.data = W\n",
    "conv.bias.data = torch.zeros([1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Given the dimension of the input image and the 2D convolution, the size of the output (height and wdith) after applying the convolution to the image is:" | |
   ]
  },
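  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$n_{\\text{out}} = \\left\\lfloor \\frac{n_{\\text{in}} - k + 2p}{s} \\right\\rfloor + 1,$$\n",
    "\n",
    "where $n_{\\text{in}}$ is the input height (or width), $k$ the kernel size, $p$ the padding and $s$ the stride."
   ]
  },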
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
"def output_size_after_convolution(image_dim, n_padding, kernel_size, stride):\n", | |
" return (image_dim - kernel_size + 2 * n_padding)/stride + 1" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_output = output_size_after_convolution(image_dim=4, kernel_size=2, n_padding=0, stride=1)\n",
    "n_output = int(n_output)\n",
    "n_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the result of doing the convolution using PyTorch's built-in function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 44.,  54.,  64.],\n",
       "          [ 84.,  94., 104.],\n",
       "          [124., 134., 144.]]]], grad_fn=<MkldnnConvolutionBackward>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_conv2d_torch = conv(X)\n",
    "result_conv2d_torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Now we compute the result of the convolution ourselves using matrix multiplication.\n", | |
"\n", | |
"First, we can express all the image patches that the kernel will get passed to the kernel as a $4 \\times 9$ matrix:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],\n",
       "         [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],\n",
       "         [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],\n",
       "         [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.]]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unfold = torch.nn.Unfold(kernel_size=2, padding=0, stride=1)\n",
    "X_unfold = unfold(X)\n",
    "X_unfold"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the $i$-th column corresponds to the image patch seen by the $i$-th output neuron."
   ]
  },
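  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, column 0 is the flattened top-left $2 \\times 2$ patch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the first patch the kernel sees: 1, 2, 5, 6\n",
    "X_unfold[0, :, 0]"
   ]
  },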
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Similarly, we can express the kernel of the convolution's operator as a $1 \\times 4$ matrix:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 2., 3., 4.]]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W_unfold = unfold(W).transpose(2, 1)\n",
    "W_unfold"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Notice that each row represents the flattened weights of the i-th kernel:" | |
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can do the multiplication:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 44.,  54.,  64.,  84.,  94., 104., 124., 134., 144.]]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_conv2d_matmul = torch.matmul(W_unfold, X_unfold)\n",
    "result_conv2d_matmul"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All that is left is to reshape it to the expected shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 44.,  54.,  64.],\n",
       "          [ 84.,  94., 104.],\n",
       "          [124., 134., 144.]]]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_conv2d_matmul = result_conv2d_matmul.view(-1, conv.out_channels, n_output, n_output)\n",
    "result_conv2d_matmul"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check that results are as expected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert torch.equal(result_conv2d_matmul, result_conv2d_torch)"
   ]
  },
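  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an extra sanity check, here is a minimal sketch that computes the same convolution with explicit Python loops, independent of both `Conv2d` and `Unfold`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# slide the 2 x 2 kernel over X one position at a time\n",
    "manual = torch.zeros(1, 1, n_output, n_output)\n",
    "for i in range(n_output):\n",
    "    for j in range(n_output):\n",
    "        patch = X[0, 0, i:i + 2, j:j + 2]\n",
    "        manual[0, 0, i, j] = (patch * W[0, 0]).sum()\n",
    "\n",
    "assert torch.equal(manual, result_conv2d_torch)"
   ]
  },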
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bigger Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Let's experiment with a $4 \\times 4 \\times 3$ image, $X$:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 1.,  2.,  3.,  4.],\n",
       "          [ 5.,  6.,  7.,  8.],\n",
       "          [ 9., 10., 11., 12.],\n",
       "          [13., 14., 15., 16.]],\n",
       "\n",
       "         [[17., 18., 19., 20.],\n",
       "          [21., 22., 23., 24.],\n",
       "          [25., 26., 27., 28.],\n",
       "          [29., 30., 31., 32.]],\n",
       "\n",
       "         [[33., 34., 35., 36.],\n",
       "          [37., 38., 39., 40.],\n",
       "          [41., 42., 43., 44.],\n",
       "          [45., 46., 47., 48.]]]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.arange(1, 49).view(-1, 3, 4, 4).float()\n",
    "X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the convolution operation to have kernel size $2 \\times 2$, $0$ padding, stride $1$, $0$ bias and $2$ output channels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 1.,  2.],\n",
       "          [ 3.,  4.]],\n",
       "\n",
       "         [[ 5.,  6.],\n",
       "          [ 7.,  8.]],\n",
       "\n",
       "         [[ 9., 10.],\n",
       "          [11., 12.]]],\n",
       "\n",
       "\n",
       "        [[[13., 14.],\n",
       "          [15., 16.]],\n",
       "\n",
       "         [[17., 18.],\n",
       "          [19., 20.]],\n",
       "\n",
       "         [[21., 22.],\n",
       "          [23., 24.]]]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=2, stride=1)\n",
    "W = torch.arange(1, 25).view(-1, conv.in_channels, 2, 2).float()\n",
    "\n",
    "conv.weight.data = W\n",
    "conv.bias.data = torch.zeros([conv.out_channels])\n",
    "conv.weight.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perform the convolution with PyTorch's built-in functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[2060., 2138., 2216.],\n",
       "          [2372., 2450., 2528.],\n",
       "          [2684., 2762., 2840.]],\n",
       "\n",
       "         [[4868., 5090., 5312.],\n",
       "          [5756., 5978., 6200.],\n",
       "          [6644., 6866., 7088.]]]], grad_fn=<MkldnnConvolutionBackward>)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_conv2d_torch = conv(X)\n",
    "result_conv2d_torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the convolution using matrix multiplication:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],\n",
       "         [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],\n",
       "         [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],\n",
       "         [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],\n",
       "         [17., 18., 19., 21., 22., 23., 25., 26., 27.],\n",
       "         [18., 19., 20., 22., 23., 24., 26., 27., 28.],\n",
       "         [21., 22., 23., 25., 26., 27., 29., 30., 31.],\n",
       "         [22., 23., 24., 26., 27., 28., 30., 31., 32.],\n",
       "         [33., 34., 35., 37., 38., 39., 41., 42., 43.],\n",
       "         [34., 35., 36., 38., 39., 40., 42., 43., 44.],\n",
       "         [37., 38., 39., 41., 42., 43., 45., 46., 47.],\n",
       "         [38., 39., 40., 42., 43., 44., 46., 47., 48.]]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_unfold = unfold(X)\n",
    "X_unfold"
   ]
  },
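  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In general, `Unfold` returns a tensor of shape $(N, C \\cdot k_h \\cdot k_w, L)$, where $L$ is the number of patches; here that is $(1, 3 \\cdot 2 \\cdot 2, 9) = (1, 12, 9)$."
   ]
  },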
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.]],\n",
       "\n",
       "        [[13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W_unfold = unfold(W).transpose(2, 1)\n",
    "W_unfold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[2060., 2138., 2216., 2372., 2450., 2528., 2684., 2762., 2840.]],\n",
       "\n",
       "        [[4868., 5090., 5312., 5756., 5978., 6200., 6644., 6866., 7088.]]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_conv2d_matmul = torch.matmul(W_unfold, X_unfold)\n",
    "result_conv2d_matmul"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[2060., 2138., 2216.],\n",
       "          [2372., 2450., 2528.],\n",
       "          [2684., 2762., 2840.]],\n",
       "\n",
       "         [[4868., 5090., 5312.],\n",
       "          [5756., 5978., 6200.],\n",
       "          [6644., 6866., 7088.]]]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_conv2d_matmul = result_conv2d_matmul.view(-1, conv.out_channels, n_output, n_output)\n",
    "result_conv2d_matmul"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert torch.equal(result_conv2d_torch, result_conv2d_matmul)"
   ]
  },
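  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To summarize the recipe, here is a minimal helper that wraps the unfold-multiply-reshape steps. This is a sketch assuming a square input, stride 1 and no padding; `conv2d_as_matmul` is our own name, not a PyTorch function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv2d_as_matmul(X, conv):\n",
    "    # assumes a square input, stride 1 and no padding\n",
    "    n, c, h, w = X.shape\n",
    "    k = conv.kernel_size[0]\n",
    "    out = output_size_after_convolution(image_dim=h, n_padding=0, kernel_size=k, stride=1)\n",
    "    X_unf = torch.nn.functional.unfold(X, kernel_size=k)\n",
    "    # the weight tensor is (out_channels, in_channels, k, k); flatten each kernel into a row\n",
    "    W_unf = conv.weight.data.view(conv.out_channels, -1)\n",
    "    return torch.matmul(W_unf, X_unf).view(n, conv.out_channels, out, out)\n",
    "\n",
    "\n",
    "assert torch.equal(conv2d_as_matmul(X, conv), conv(X))"
   ]
  },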
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Appendix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1D Convolution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Specify an input vector:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.]],\n",
       "\n",
       "         [[13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]],\n",
       "\n",
       "         [[25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36.]]]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (batch size, channel, height, width)\n",
    "X = torch.arange(1, 37).view(-1, 3, 1, 12).float()\n",
    "X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Params for the convolution operation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_channels = X.shape[1]\n",
    "out_channels = 2\n",
    "kernel_size = (1, 4)\n",
    "stride = 2\n",
    "padding = 0\n",
    "bias = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the expected output size of the convolution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 5)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, _, image_h, image_w = X.shape\n",
    "kernel_h, kernel_w = kernel_size\n",
    "\n",
    "output_h = output_size_after_convolution(image_dim=image_h, n_padding=padding, kernel_size=kernel_h, stride=stride)\n",
    "output_w = output_size_after_convolution(image_dim=image_w, n_padding=padding, kernel_size=kernel_w, stride=stride)\n",
    "\n",
    "output_h = int(output_h)\n",
    "output_w = int(output_w)\n",
    "\n",
    "output_h, output_w"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Define and perform the convolution operation using pytorch:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
"conv = torch.nn.Conv1d(in_channels=in_channels,\n", | |
" out_channels=out_channels,\n", | |
" kernel_size=kernel_size,\n", | |
" stride=stride,\n", | |
" padding=padding,\n", | |
" bias=bias)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"initial weights in the convolution operation:\n", | |
"tensor([[[[ 1., 2., 3., 4.]],\n", | |
"\n", | |
" [[ 5., 6., 7., 8.]],\n", | |
"\n", | |
" [[ 9., 10., 11., 12.]]],\n", | |
"\n", | |
"\n", | |
" [[[13., 14., 15., 16.]],\n", | |
"\n", | |
" [[17., 18., 19., 20.]],\n", | |
"\n", | |
" [[21., 22., 23., 24.]]]])\n" | |
] | |
} | |
], | |
"source": [ | |
"W = torch.arange(1, kernel_w * in_channels * out_channels + 1).view(out_channels, in_channels, 1, kernel_w).float()\n", | |
"print(f'initial weights in the convolution operation:\\n{W}')\n", | |
"\n", | |
"conv.weight.data = W" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[[1530., 1686., 1842., 1998., 2154.]],\n", | |
"\n", | |
" [[3618., 4062., 4506., 4950., 5394.]]]],\n", | |
" grad_fn=<MkldnnConvolutionBackward>)" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"result_conv1d_torch = conv(X)\n", | |
"result_conv1d_torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 2, 1, 5])" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"result_conv1d_torch.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now we do the convolution ourselves using matrix multiplication:" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Create the matrix of image patches:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"unfold = torch.nn.Unfold(kernel_size=kernel_size,\n", | |
" padding=padding,\n", | |
" stride=stride)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[ 1., 3., 5., 7., 9.],\n", | |
" [ 2., 4., 6., 8., 10.],\n", | |
" [ 3., 5., 7., 9., 11.],\n", | |
" [ 4., 6., 8., 10., 12.],\n", | |
" [13., 15., 17., 19., 21.],\n", | |
" [14., 16., 18., 20., 22.],\n", | |
" [15., 17., 19., 21., 23.],\n", | |
" [16., 18., 20., 22., 24.],\n", | |
" [25., 27., 29., 31., 33.],\n", | |
" [26., 28., 30., 32., 34.],\n", | |
" [27., 29., 31., 33., 35.],\n", | |
" [28., 30., 32., 34., 36.]]])" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_unfold = unfold(X)\n", | |
"X_unfold" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 12, 5])" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_unfold.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Also unfold the parameters in the convolution operation:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.],\n",
       "        [13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W_unfold = W.view(-1, kernel_h * kernel_w * in_channels)\n",
    "W_unfold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 12])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W_unfold.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perform the matrix multiplication and reshape to the correct output:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1530., 1686., 1842., 1998., 2154.],\n",
       "         [3618., 4062., 4506., 4950., 5394.]]])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_conv1d_matmul = torch.matmul(W_unfold, X_unfold)\n",
    "result_conv1d_matmul"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[1530., 1686., 1842., 1998., 2154.]],\n",
       "\n",
       "         [[3618., 4062., 4506., 4950., 5394.]]]])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_conv1d_matmul = result_conv1d_matmul.view(-1, out_channels, output_h, output_w)\n",
    "result_conv1d_matmul"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert torch.equal(result_conv1d_torch, result_conv1d_matmul)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "fastai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |