Created
December 31, 2025 19:58
-
-
Save huangsam/4906ca328a4edcb0b88344f3323c5f63 to your computer and use it in GitHub Desktop.
Comparing an image with itself using PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import torch.nn.functional as F | |
# Preprocessing pipeline: ToTensor converts the PIL image to a float tensor,
# and the Lambda rescales it by 255 so pixel values match the 0-255 range
# that ssim() assumes through its default max_val.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 255),
    ]
)

# The same file is opened twice on purpose: the image is compared to itself.
image1 = Image.open("images/image1.jpg").convert("RGB")
image2 = Image.open("images/image1.jpg").convert("RGB")

# Prepend a batch axis so each tensor has shape [1, 3, H, W].
img1 = transform(image1).unsqueeze(0)
img2 = transform(image2).unsqueeze(0)
| # SSIM function (simplified implementation for learning) | |
def ssim(img1: torch.Tensor, img2: torch.Tensor, max_val: float = 255.0, window_size: int = 11) -> torch.Tensor:
    """Compute the mean Structural Similarity Index (SSIM) between two images.

    SSIM measures image similarity by comparing luminance, contrast, and
    structure. It returns a value between -1 and 1, where 1 indicates perfect
    similarity. The per-pixel formula is:

        SSIM(x, y) = (2*mu_x*mu_y + c1) * (2*sigma_xy + c2)
                     / ((mu_x^2 + mu_y^2 + c1) * (sigma_x^2 + sigma_y^2 + c2))

    Steps:
        1. Build a uniform (box) window per channel; this is a simplification
           of the Gaussian window used by the reference SSIM formulation.
        2. Compute local means (mu) for both images using a depthwise
           convolution.
        3. Compute local variances (sigma^2) and covariance (sigma_xy) as
           E[XY] - E[X]E[Y].
        4. Add stability constants (c1, c2) to prevent division by zero.
        5. Evaluate the SSIM map and return its mean.

    Notes:
        - Each channel is filtered independently (groups == channels), so the
          window must sum to 1 per channel, i.e. be divided by
          ``window_size**2`` only. Dividing by the channel count as well would
          shrink every local moment and distort the score for non-identical
          images.
        - The convolution zero-pads the borders, so local statistics within
          ``window_size // 2`` pixels of an edge are biased towards zero.
        - An odd ``window_size`` keeps the SSIM map the same size as the
          input; an even one yields a map one pixel larger in each dimension.

    Args:
        img1: First image tensor with shape [B, C, H, W].
        img2: Second image tensor with the same shape, dtype and device.
        max_val: Maximum pixel value (default 255.0 for 8-bit images).
        window_size: Side length of the sliding window for local statistics.

    Returns:
        Mean SSIM value as a scalar tensor.
    """
    # Read the channel count from the input instead of assuming RGB, so
    # grayscale or RGBA inputs work too. Indexing from the end also covers an
    # unbatched [C, H, W] tensor.
    channels: int = img1.shape[-3]
    pad: int = window_size // 2

    # One averaging window per channel, shape [C, 1, k, k]. It is created with
    # the input's dtype and device so float64 or GPU inputs do not hit a
    # mismatch inside conv2d.
    kernel: torch.Tensor = (
        torch.ones(channels, 1, window_size, window_size, dtype=img1.dtype, device=img1.device)
        / window_size**2
    )

    def local_mean(x: torch.Tensor) -> torch.Tensor:
        """Return the per-channel windowed average of *x*."""
        return F.conv2d(x, kernel, groups=channels, padding=pad)

    # First-order local statistics.
    mu1: torch.Tensor = local_mean(img1)
    mu2: torch.Tensor = local_mean(img2)
    mu1_sq: torch.Tensor = mu1**2
    mu2_sq: torch.Tensor = mu2**2
    mu1_mu2: torch.Tensor = mu1 * mu2

    # Second-order local statistics: Var[X] = E[X^2] - E[X]^2 and
    # Cov[X, Y] = E[XY] - E[X]E[Y].
    sigma1_sq: torch.Tensor = local_mean(img1**2) - mu1_sq
    sigma2_sq: torch.Tensor = local_mean(img2**2) - mu2_sq
    sigma12: torch.Tensor = local_mean(img1 * img2) - mu1_mu2

    # Stability constants, scaled by the dynamic range of the pixel values.
    c1: float = (0.01 * max_val) ** 2
    c2: float = (0.03 * max_val) ** 2

    numerator: torch.Tensor = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator: torch.Tensor = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)

    # Element-wise SSIM map, averaged over batch, channels and pixels.
    ssim_map: torch.Tensor = numerator / denominator
    return ssim_map.mean()
def main():
    """Compute the SSIM of the two module-level image tensors and print it."""
    # .item() unwraps the scalar tensor into a plain Python float for display.
    score = ssim(img1, img2).item()
    print(f"SSIM: {score}")


# Run the comparison only when the file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment