Skip to content

Instantly share code, notes, and snippets.

@huangsam
Created December 31, 2025 19:58
Show Gist options
  • Select an option

  • Save huangsam/4906ca328a4edcb0b88344f3323c5f63 to your computer and use it in GitHub Desktop.

Select an option

Save huangsam/4906ca328a4edcb0b88344f3323c5f63 to your computer and use it in GitHub Desktop.
Comparing an image with itself using PyTorch
import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
# Load the same JPEG twice; comparing the image with itself should give SSIM == 1.
image1 = Image.open("images/image1.jpg").convert("RGB")
image2 = Image.open("images/image1.jpg").convert("RGB")  # Comparing to itself

# ToTensor() turns each RGB image into a [3, H, W] float tensor in [0, 1];
# multiplying by 255 restores the 8-bit range matching ssim()'s default max_val.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ]
)

# Prepend a batch axis so each tensor has the [1, 3, H, W] layout ssim() takes.
img1 = transform(image1)[None]
img2 = transform(image2)[None]
# SSIM function (simplified implementation for learning)
def ssim(
    img1: torch.Tensor,
    img2: torch.Tensor,
    max_val: float = 255.0,
    window_size: int = 11,
    k1: float = 0.01,
    k2: float = 0.03,
) -> torch.Tensor:
    """
    Compute the mean Structural Similarity Index (SSIM) between two image batches.

    SSIM compares local luminance, contrast, and structure. For each window:

        SSIM(x, y) = ((2*mu_x*mu_y + c1) * (2*sigma_xy + c2))
                     / ((mu_x^2 + mu_y^2 + c1) * (sigma_x^2 + sigma_y^2 + c2))

    with c1 = (k1 * max_val)^2 and c2 = (k2 * max_val)^2. The result lies in
    [-1, 1], where 1 means the images are identical.

    This simplified version uses a uniform (box) window rather than the
    11x11 Gaussian window (sigma = 1.5) of Wang et al. (2004). It only
    evaluates windows lying fully inside the image ("valid" convolution), so
    no zero padding distorts the statistics near the borders.

    Steps:
    1. Build one box kernel per channel, normalized so each kernel sums to 1.
    2. Compute local means (mu) with a grouped convolution.
    3. Compute local variances and covariance as E[XY] - E[X]E[Y].
    4. Add stability constants c1, c2 so flat regions do not divide by zero.
    5. Form the SSIM map and average it over batch, channels and positions.

    Args:
        img1 (torch.Tensor): First image tensor with shape [B, C, H, W]
            (any number of channels).
        img2 (torch.Tensor): Second image tensor, same shape as img1.
        max_val (float): Dynamic range of pixel values (255.0 for 8-bit
            scaled images, 1.0 for images in [0, 1]).
        window_size (int): Side length of the square sliding window. If the
            image is smaller than the window, the window is reduced to
            min(H, W) so at least one full window fits.
        k1 (float): Luminance stability constant (0.01 in the SSIM paper).
        k2 (float): Contrast stability constant (0.03 in the SSIM paper).

    Returns:
        torch.Tensor: Mean SSIM value (scalar tensor).

    Raises:
        ValueError: If the inputs are not 4-D tensors of identical shape, or
            if window_size, H or W is not positive.
    """
    # Explicit check: otherwise mismatched batch sizes would silently broadcast.
    if img1.dim() != 4 or img1.shape != img2.shape:
        raise ValueError(
            "img1 and img2 must be 4-D [B, C, H, W] tensors of the same shape, "
            f"got {tuple(img1.shape)} and {tuple(img2.shape)}"
        )

    channels, height, width = img1.shape[1:]
    win = min(window_size, height, width)
    if win < 1:
        raise ValueError(
            f"window_size, height and width must be positive, "
            f"got window_size={window_size}, H={height}, W={width}"
        )

    # conv2d requires a floating dtype shared by both inputs and the kernel.
    # Integer inputs are promoted to float32 (squaring them in an integer
    # dtype could also overflow).
    dtype = torch.promote_types(img1.dtype, img2.dtype)
    if not dtype.is_floating_point:
        dtype = torch.float32
    img1 = img1.to(dtype)
    img2 = img2.to(dtype)

    # One kernel per channel, each summing to 1, so the grouped convolution
    # returns the plain average of every win x win window within a channel.
    kernel = torch.full(
        (channels, 1, win, win), 1.0 / (win * win), dtype=dtype, device=img1.device
    )

    def local_mean(x: torch.Tensor) -> torch.Tensor:
        # groups=channels keeps channels independent; no padding => valid windows only.
        return F.conv2d(x, kernel, groups=channels)

    # Local means and their products, reused by the variance/covariance terms.
    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Variance: E[X^2] - (E[X])^2; covariance: E[XY] - E[X]E[Y].
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    # Stability constants keep the ratio defined when means/variances are ~0.
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)

    # SSIM map: element-wise ratio, averaged over all windows/channels/batch.
    ssim_map = numerator / denominator
    return ssim_map.mean()
def main():
    """Score the module-level tensors img1 and img2 with ssim() and print the result."""
    score = ssim(img1, img2).item()
    print(f"SSIM: {score}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment