Last active
May 3, 2026 19:42
-
-
Save GamingLiamStudios/c8645aaf99935e28a13b60428d566a78 to your computer and use it in GitHub Desktop.
Lowpass Detection
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This script is not meant to find a lowpass for you. It is meant to assist in finding the frame ranges of different lowpasses.
# Please don't blindly trust the script; always check your delowpasses.
# Also, this is somewhat vibecoded, so buyer beware.
| from vstools import core, vs | |
| from vsmuxtools import src_file, SourceFilter | |
| import itertools | |
| import tqdm | |
| import numpy as np | |
| from numpy.lib.stride_tricks import sliding_window_view | |
| import torch # For GPU acceleration, but can be replaced with just numpy/cupy if you'd like | |
# NOTE(review): presumably caps the VapourSynth frame cache (value likely in MB) — confirm against vstools.
core.set_affinity(max_cache=12000)

# Clip that carries the lowpass filtering under investigation.
# NOTE(review): init_cut(32) presumably trims/offsets the source so both clips line up — confirm.
lowpassed = src_file("00000.m2ts", sourcefilter=SourceFilter.BS).init_cut(32)

# Reference clip of the same content without the lowpass.
clean = src_file("00001.m2ts", sourcefilter=SourceFilter.BS).init_cut(32)
class FilterModel:
    """Running per-element mean/variance estimate of a filter's profile vector."""

    def __init__(self, vector):
        """Seed the model from a single observation.

        The mean starts as a copy of ``vector``; the variance starts at a
        small constant (1e-3) for every element.
        """
        self.mean = vector.clone()
        self.var = torch.ones_like(self.mean) * 1e-3
        self.count = 1

    def update(self, vector, alpha=0.05):
        """Fold ``vector`` into the model with an exponential moving average.

        ``alpha`` is the weight given to the new observation. The variance
        update uses the deviation from the mean *before* this update.
        """
        deviation = vector - self.mean
        # In-place so that existing references to ``self.mean`` stay current.
        self.mean += deviation * alpha
        self.var = self.var * (1 - alpha) + (deviation ** 2) * alpha
        self.count += 1

    def distance(self, vector):
        """Return the mean variance-normalised squared deviation of ``vector``.

        This is a Mahalanobis-like distance with a diagonal covariance; the
        1e-6 term keeps the division finite when the variance collapses.
        """
        residual = vector - self.mean
        normalised = residual ** 2 / (self.var + 1e-6)
        return torch.mean(normalised)
class FilterTracker:
    """Track which filter model the incoming per-frame profiles belong to.

    Frames are fed one at a time through :meth:`process`. A frame whose
    distance to the active model is below ``threshold`` updates that model.
    Once ``patience`` consecutive frames fail that check, they are either
    matched to a previously seen model or used to create a new one.

    Attributes:
        model: index into ``models`` of the active model, or ``None`` before
            the first frame.
        models: every ``FilterModel`` created so far.
        threshold: maximum distance for a vector to count as matching a model.
        patience: number of consecutive non-matching frames needed to switch.
        buffer: the current run of consecutive non-matching frames.
        switches: ``(frame_index, model_index)`` for every change of the
            active model, where ``frame_index`` is the 0-based index of the
            first frame of the run that triggered the change.
        total: number of frames processed so far.
    """

    def __init__(self, threshold=50.0, patience=10):
        self.model = None
        self.models = []
        self.threshold = threshold
        self.patience = patience
        self.buffer = []  # candidate new frames
        self.switches = []
        self.total = 0

    def get_model(self):
        """Return the active ``FilterModel``."""
        return self.models[self.model]

    def process(self, vector):
        """Consume the profile vector of the next frame."""
        self.total += 1

        if self.model is None:
            # The very first frame seeds model 0.
            self.models = [FilterModel(vector)]
            self.model = 0
            return

        active = self.get_model()
        if active.distance(vector) < self.threshold:
            # Fits the current model: accept it and drop any pending outliers.
            active.update(vector)
            self.buffer.clear()
            return

        # Possible new filter; wait for `patience` consecutive outliers.
        self.buffer.append(vector)
        if len(self.buffer) < self.patience:
            return

        # Summarise the buffered run as a candidate model.
        candidate = FilterModel(self.buffer[0])
        for buffered in self.buffer[1:]:
            candidate.update(buffered)

        # 0-based index of the first buffered frame (the buffer includes the
        # current frame, and `total` has already been incremented for it).
        run_start = self.total - len(self.buffer)

        # Check against known models; the first one within threshold wins.
        lowest = float("inf")
        for i, model in enumerate(self.models):
            dist = float(model.distance(candidate.mean))
            lowest = min(lowest, dist)
            if dist < self.threshold:
                if i != self.model:
                    # Returning to an earlier filter is a switch as well.
                    self.switches.append((run_start, i))
                self.model = i
                # Every buffered frame belongs to this model, including the first.
                for buffered in self.buffer:
                    model.update(buffered)
                print(f"Frame {self.total}: Assigned to existing model {i}")
                self.buffer.clear()
                return

        # Nothing known matches: register the candidate as a new model.
        self.models.append(candidate)
        self.model = len(self.models) - 1
        self.switches.append((run_start, self.model))
        print(f"Frame {self.total}: Created new model {self.model}, closest distance: {lowest:.2f}")
        self.buffer.clear()
def torch_indices(shape, device=None):
    """Return the index grids for an array of the given ``shape``.

    The result stacks one grid per dimension along a new leading axis, so
    ``result[d]`` holds the index along dimension ``d`` at every position
    ('ij' indexing). Grids are created on ``device``.
    """
    axes = (torch.arange(extent, device=device) for extent in shape)
    return torch.stack(torch.meshgrid(*axes, indexing='ij'))
def radial_average(spec):
    """Average the power of a 2-D spectrum over rings of equal integer radius.

    The radius is measured from ``(h // 2, w // 2)`` and truncated to an
    integer, so the caller is expected to pass a spectrum whose origin sits
    at that position. Returns a 1-D tensor indexed by radius.
    """
    power = torch.abs(spec) ** 2
    height, width = power.shape
    center_y, center_x = height // 2, width // 2

    y, x = torch_indices((height, width), device=power.device)
    dy = y - center_y
    dx = x - center_x
    radius = torch.sqrt(dx**2 + dy**2).to(torch.int).ravel()

    # Sum of power per ring, then number of pixels per ring.
    power_sum = torch.bincount(radius, power.ravel())
    pixel_count = torch.bincount(radius)
    # The epsilon guards against rings that contain no pixels.
    return power_sum / (pixel_count + 1e-8)
# Change this to the range you'd like to check
clean = clean[:1000]
lowpassed = lowpassed[:1000]

# Tuneables
batch_size = 2  # frames transformed per FFT call
patch_size = 64  # side length of the square analysis patches
step = patch_size // 2  # patch stride; half a patch gives 50% overlap
tracker = FilterTracker()  # Look at definition to see parameters for this

# Separable 2-D Hamming window applied to every patch before the FFT.
# Falls back to the CPU when no CUDA device is available so the script
# still runs (just slower) instead of crashing.
w = torch.windows.hamming(patch_size)
window = torch.outer(w, w).to('cuda' if torch.cuda.is_available() else 'cpu')
def extract_plane(clip, plane=0):  # Set this to the plane you'd like to analyze
    """Yield one plane of every frame of ``clip`` as a 2-D float32 array.

    The plane dimensions are derived from each frame and its format's
    subsampling, so chroma planes of subsampled formats (e.g. 4:2:0) are
    reshaped correctly instead of failing against the luma dimensions.
    """
    for frame in clip.frames():
        fmt = frame.format
        # Plane 0 is always full size; other planes may be subsampled.
        height = frame.height >> (fmt.subsampling_h if plane else 0)
        width = frame.width >> (fmt.subsampling_w if plane else 0)
        yield np.asarray(frame[plane], dtype=np.float32).reshape((height, width))
# Pair up equally sized batches of frames from both clips.
frame_batches = zip(
    itertools.batched(extract_plane(clean), batch_size),
    itertools.batched(extract_plane(lowpassed), batch_size),
)

# The bar is advanced manually by the number of frames in each batch. It must
# not also wrap the iterable, or every iteration would be counted twice and
# the bar would overshoot `total`.
with tqdm.tqdm(total=len(clean)) as progress:
    for batch_clean, batch_lowpassed in frame_batches:
        # Split each frame into overlapping patches for a better FFT of the whole frame
        lowpassed_patches = sliding_window_view(batch_lowpassed, (patch_size, patch_size), axis=(-2, -1))[:, ::step, ::step]
        clean_patches = sliding_window_view(batch_clean, (patch_size, patch_size), axis=(-2, -1))[:, ::step, ::step]

        # Compute the FFT for all frames of both clips at once
        patches = np.stack([lowpassed_patches, clean_patches])
        fft_tensor = torch.from_numpy(patches).to(window.device) * window
        lowpassed_fft, clean_fft = torch.fft.fft2(fft_tensor)

        # For each frame, use the cross-spectrum method to estimate the
        # frequency response, averaging over all patches of that frame.
        S_xx = torch.mean(torch.conj(clean_fft) * clean_fft, dim=(1, 2))
        S_xy = torch.mean(torch.conj(clean_fft) * lowpassed_fft, dim=(1, 2))

        # radial_average measures radius from (h // 2, w // 2), which is where
        # fftshift places the zero-frequency bin; fft2 itself leaves it at (0, 0).
        H = torch.fft.fftshift(S_xy / (S_xx + 1e-8), dim=(-2, -1))

        for profile in H:
            profile = radial_average(profile)
            profile = torch.log1p(profile)
            profile = profile / (torch.linalg.norm(profile) + 1e-8)
            tracker.process(profile)

        progress.update(len(batch_clean))

print(f"{tracker.total} Frames: {len(tracker.models)} Models, {len(tracker.switches)} Switches")
for frame, model in tracker.switches:
    print(f"\t- Frame {frame}: Model {model}")

# Plot the mean profile of every model in the same graph
import matplotlib.pyplot as plt

plt.figure()
for i, model in enumerate(tracker.models):
    plt.plot(model.mean.cpu().numpy(), label=f"Model {i}")
plt.legend()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment