-
-
Save arch1t3cht/2886dfccd070c50ef77a32a88f9e0ae5 to your computer and use it in GitHub Desktop.
Mathematically solve for convolution kernels
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from vstools import vs, core, get_y, depth, set_output | |
| import numpy as np | |
| from matplotlib import pyplot as plt | |
| from tqdm import tqdm | |
| import itertools | |
# Assumes that the convolution is horizontally and vertically symmetric
# (though it could easily be modified to not do that, it'd just be slower).
clip1 = core.lsmas.LWLibavSource("delowpasssample_1.mkv") # clean clip
clip2 = core.lsmas.LWLibavSource("delowpasssample_2.mkv") # lowpassed clip
# Work on the luma plane only, converted to 32-bit float.
clip1 = depth(get_y(clip1), 32)
clip2 = depth(get_y(clip2), 32)
set_output([clip1, clip2])
w = clip1.width
h = clip1.height
fr = 1027 # pick a frame
# Pull one frame of each clip into an (h, w) float32 array.
# NOTE(review): assumes both clips have identical dimensions — confirm upstream.
frame1 = np.array(clip1.get_frame(fr), dtype=np.float32).reshape((h, w))
frame2 = np.array(clip2.get_frame(fr), dtype=np.float32).reshape((h, w))
# empirically most lowpass kernels have a support < 16
# lower support makes calculations faster but it's probably not safe to go lower than 16
support = 16
# Mirror-pad the clean frame by `support` pixels on every side so the
# convolution window is defined at the image borders.
frame1_padded = frame1
frame1_padded = np.vstack((np.flipud(frame1_padded[:support,:]), frame1_padded, np.flipud(frame1_padded[-support:,:])))
frame1_padded = np.hstack((np.fliplr(frame1_padded[:,:support]), frame1_padded, np.fliplr(frame1_padded[:,-support:])))
# Build one linear equation per output pixel:
#   frame2[y, x] == sum_{dy, dx} kernel[|dy|, |dx|] * frame1[y + dy, x + dx]
# The assumed horizontal/vertical symmetry folds the mirrored offsets
# (+-dy, +-dx) onto the same unknown kernel[|dy|, |dx|], leaving only
# support**2 unknowns.
eqns = np.zeros((h, w, support, support))
for dx, dy in tqdm(list(itertools.product(range(-support + 1, support), range(-support + 1, support)))):
    # Shifted copy of the padded frame cropped back to (h, w); accumulating
    # into [abs(dy), abs(dx)] implements the symmetry folding.
    eqns[:,:,abs(dy),abs(dx)] += frame1_padded[support+dy:,support+dx:][:h,:w]
# Flatten to an (h*w, support**2) overdetermined system and solve it in the
# least-squares sense.
eqns = eqns.reshape((h * w, support ** 2))
kernel, residuals, rank, sings = np.linalg.lstsq(eqns, frame2.reshape((w * h,)), rcond=None)
kernel = kernel.reshape((support, support))
print(kernel)
# plt.stem(list(range(support)), kernel[0,:])
# plt.stem(list(range(support)), kernel[:,0])
# plt.ylim((-1, 1))
# plt.show()
# Residual of the solved 2-D kernel on this frame (sum of squared errors).
print(f"kernel MSE: {np.sum((eqns.dot(kernel.ravel()) - frame2.ravel()) ** 2)}")
| # We've found a 2d convolution kernel. Now, assume that it's separable and arises from a single symmetric 1d kernel, | |
| # and find the best such kernel using gradient descent | |
def kernel_1dto2d(kernel):
    """Lift a 1-D kernel to the separable 2-D kernel k[y] * k[x]."""
    k = np.asarray(kernel)
    return k[:, None] * k[None, :]
def kernel_1dto2d_deriv(kernel_1d, kernel_2d):
    """Gradient of sum((outer(k, k) - kernel_2d) ** 2) with respect to k.

    Derivation (k = kernel_1d, K = kernel_2d, s = len(k)):
      d/dk[c] sum((k[y] * k[x] - K[y,x]) ** 2 for x in range(s) for y in range(s))
        = sum(2 * k[y] * (k[c] * k[y] - K[y,c]) for y in range(s))
        + sum(2 * k[x] * (k[x] * k[c] - K[c,x]) for x in range(s))
    The diagonal term x == y == c is hit once by each sum, which together
    contribute exactly its 4 * k[c] * (k[c]**2 - K[c,c]) derivative.

    In vector form:
      grad = 2 * (2 * k * ||k||^2 - (K + K.T) @ k)
    which replaces the original O(s^2) pair of Python-level sums per
    component with a few vectorized numpy operations.
    """
    k = np.asarray(kernel_1d, dtype=np.float64)
    K = np.asarray(kernel_2d, dtype=np.float64)
    return 2.0 * (2.0 * k * np.dot(k, k) - (K + K.T) @ k)
# Sanity check: how far is the solved 2-D kernel from the outer product of
# its own first row / first column? (Small values suggest separability.)
print(f"1dto2d MSE: {np.sum((kernel_1dto2d(kernel[0,:]) - kernel) ** 2)}")
print(f"1dto2d MSE: {np.sum((kernel_1dto2d(kernel[:,0]) - kernel) ** 2)}")
print("Gradient descent:")
# BUG FIX: kernel[0,:] is a *view* into kernel, so the in-place "-=" in the
# loop below used to mutate the target 2-D kernel while descending on it.
# Copy to decouple the iterate from the objective.
kernel_1d = kernel[0,:].copy()
stepsize = np.sqrt(np.sum((kernel_1dto2d(kernel[0,:]) - kernel) ** 2)) / 2 # idfk
for i in range(100):
    kernel_1d -= stepsize * kernel_1dto2d_deriv(kernel_1d, kernel)
    # BUG FIX: report the MSE of the current iterate kernel_1d; the old code
    # printed kernel_1dto2d(kernel[:,0]), which never changes with i.
    print(f"MSE: {np.sum((kernel_1dto2d(kernel_1d) - kernel) ** 2)}")
print(kernel_1d)
plt.stem(list(range(support)), kernel_1d)
plt.ylim((-1, 1))
plt.show()
# Now you could fit some known kernel (say lanczos with some taps+blur) to kernel_1d if you wanted to
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment