Last active: March 16, 2020 15:25
-
-
Save lpraat/3f7f92fb78687b0897294b28d2ee015b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Example showing that backward pass of conv layer can be done using transposed convolution.

A 4x4 input is convolved (stride 1, no padding) with a 3x3 kernel, producing a
2x2 output. The forward convolution is written as a matrix product
``a = m @ a_prev`` on the row-major flattened input. The gradient with respect
to the input is therefore ``m.T @ da``, which is a transposed convolution of the
upstream gradient ``da``. The script checks that this matches the usual
loop-based conv backward pass.
"""
import numpy as np

# Input activation (4x4), flattened row-major into a (16, 1) column vector.
a_prev = np.array([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3],
    [4, 4, 4, 4],
], dtype=np.float32).reshape(16, 1)

# 3x3 convolution kernel.
w = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
], dtype=np.float32)

# Convolution as a (4, 16) matrix. Row r corresponds to output element
# (r // 2, r % 2). Column c corresponds to input element (c // 4, c % 4).
# Each row is the kernel w laid over that output's 3x3 input window and
# flattened. Because w is the identity, only its diagonal 1s show up.
m = np.array([
    [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
], dtype=np.float32)

# Forward pass: 2x2 convolution output.
a = np.dot(m, a_prev).reshape(2, 2)

# Upstream gradient dL/da. It is all ones here; both backward paths below use
# this same array, so they stay comparable for any choice of da.
da = np.ones_like(a)

kernel_size = w.shape[0]
stride = 1

da_prev = np.zeros((4, 4))
# backward pass of convolution: each output element scatters its gradient,
# weighted by the kernel, back onto the input window it was computed from.
for i in range(a.shape[0]):
    v_start = stride * i
    v_end = v_start + kernel_size
    for j in range(a.shape[1]):
        h_start = stride * j
        h_end = h_start + kernel_size
        da_prev[v_start:v_end, h_start:h_end] += w * da[i, j]
print("Using backward:\n", da_prev)

# backward pass using transposed conv: multiply the flattened gradient by m.T
# and reshape back to the input's 4x4 layout.
da_prev_tconv = np.dot(m.T, da.reshape(4, 1)).reshape(4, 4)
print("Using transposed conv:\n", da_prev_tconv)
np.testing.assert_array_almost_equal(da_prev, da_prev_tconv)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.