Last active
October 23, 2024 04:00
-
-
Save ProGamerGov/d032aa6780d8ef234f3ce67b177f3c14 to your computer and use it in GitHub Desktop.
A PyTorch function that matches the color histogram of one image to that of another image. It should be helpful for use cases such as astronomy and neural style transfer.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple | |
import torch | |
def color_transfer(
    input: torch.Tensor,
    source: torch.Tensor,
    mode: str = "pca",
    eps: float = 1e-5,
) -> torch.Tensor:
    """
    Transfer the colors of a source image onto an input (target) image.

    The input image is re-colored with a linear transform of its channel
    vectors so that its per-channel mean and channel covariance match those of
    the source image. Applications for this kind of image histogram matching
    include neural style transfer and astronomy.

    The source image is not required to have the same height and width as the
    input image. Batch and channel dimensions are required to be the same for
    both inputs.

    Gatys, et al., "Controlling Perceptual Factors in Neural Style Transfer",
    arXiv, 2017. https://arxiv.org/abs/1611.07865

    Args:
        input (torch.Tensor): The NCHW or CHW image that receives the colors
            of the source image.
        source (torch.Tensor): The NCHW or CHW image whose colors are
            transferred to the input image.
        mode (str): The color transfer mode to use. One of 'pca', 'cholesky',
            or 'sym'.
            Default: "pca"
        eps (float): Value added to the diagonal of both covariance matrices
            before they are decomposed / inverted.
            Default: 1e-5

    Returns:
        matched_image (torch.Tensor): The NCHW input image with the colors of
            the source image (a CHW input gains a leading batch dimension of
            size 1). Outputs should ideally be clamped to the desired value
            range to avoid artifacts.

    Raises:
        AssertionError: If either tensor is not 3D or 4D, or if the batch and
            channel dimensions of the two tensors differ.
        ValueError: If ``mode`` is not one of 'pca', 'cholesky', or 'sym'.
    """
    assert input.dim() in (3, 4), "input must be a CHW or NCHW tensor"
    assert source.dim() in (3, 4), "source must be a CHW or NCHW tensor"
    input = input.unsqueeze(0) if input.dim() == 3 else input
    source = source.unsqueeze(0) if source.dim() == 3 else source
    assert (
        input.shape[:2] == source.shape[:2]
    ), "input and source must share batch and channel dimensions"

    # Handle older versions of PyTorch. Feature detection is used instead of
    # comparing version strings, because plain string comparison is
    # lexicographic (e.g. "1.10.0" < "1.9.0").
    linalg = getattr(torch, "linalg", None)
    has_linalg_cholesky = linalg is not None and hasattr(linalg, "cholesky")
    has_linalg_eigh = linalg is not None and hasattr(linalg, "eigh")
    torch_cholesky = linalg.cholesky if has_linalg_cholesky else torch.cholesky

    def torch_symeig_eigh(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Eigendecomposition of a batch of symmetric matrices.

        torch.symeig() was deprecated in favor of torch.linalg.eigh(), so the
        latter is used whenever it is available.
        """
        if has_linalg_eigh:
            return linalg.eigh(x, UPLO="U")
        return torch.symeig(x, eigenvectors=True, upper=True)

    def get_mean_vec_and_cov(
        x_input: torch.Tensor, eps: float
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convert input images into a vector, subtract the mean, and calculate
        the covariance matrix of colors.

        Returns the (B, C, 1, 1) color mean, the (B, C, H*W) mean-subtracted
        pixel vectors and the (B, C, C) regularized covariance matrix.
        """
        x_mean = x_input.mean(dim=(2, 3), keepdim=True)

        # Subtract the color mean and convert to a vector
        B, C = x_input.shape[:2]
        x_vec = (x_input - x_mean).reshape(B, C, -1)

        # Calculate covariance matrix
        x_cov = torch.bmm(x_vec, x_vec.transpose(1, 2)) / x_vec.shape[2]

        # This line is only important if you get artifacts in the output image.
        # The identity is created with the covariance's dtype & device so that
        # the addition never changes the dtype of the result.
        identity = torch.eye(C, dtype=x_cov.dtype, device=x_cov.device)
        x_cov = x_cov + eps * identity[None, :]
        return x_mean, x_vec, x_cov

    def pca(x: torch.Tensor) -> torch.Tensor:
        """
        Matrix square root of a batch of symmetric matrices, computed via
        eigendecomposition: V * sqrt(diag(L)) * V^T.
        """
        eigenvalues, eigenvectors = torch_symeig_eigh(x)
        roots = torch.sqrt(eigenvalues.reshape(eigenvalues.size(0), -1))
        # Remove any NaN values (square roots of negative eigenvalues) if they
        # occur. Done unconditionally to avoid a data-dependent Python branch.
        roots = torch.where(torch.isnan(roots), torch.zeros_like(roots), roots)
        # Scaling the columns of V by the roots is equivalent to V @ diag(roots)
        scaled = eigenvectors * roots.unsqueeze(1)
        return torch.bmm(scaled, eigenvectors.transpose(1, 2))

    # Collect & calculate required values
    _, input_vec, input_cov = get_mean_vec_and_cov(input, eps)
    source_mean, _, source_cov = get_mean_vec_and_cov(source, eps)

    # Calculate new cov matrix for input
    if mode == "pca":
        new_cov = torch.bmm(pca(source_cov), torch.inverse(pca(input_cov)))
    elif mode == "cholesky":
        new_cov = torch.bmm(
            torch_cholesky(source_cov), torch.inverse(torch_cholesky(input_cov))
        )
    elif mode == "sym":
        p = pca(input_cov)
        # The inverse is needed on both sides, so only compute it once
        p_inv = torch.inverse(p)
        pca_out = pca(torch.bmm(torch.bmm(p, source_cov), p))
        new_cov = torch.bmm(torch.bmm(p_inv, pca_out), p_inv)
    else:
        raise ValueError(
            "mode has to be one of 'pca', 'cholesky', or 'sym'."
            + " Received '{}'.".format(mode)
        )

    # Multiply input vector by new cov matrix
    new_vec = torch.bmm(new_cov, input_vec)

    # Reshape output vector back to input's shape &
    # add the source mean to our output vector
    return new_vec.reshape(input.shape) + source_mean
# Example usage for standard PyTorch images whose values lie in the range [0, 1].
# `target_image` and `source_image` are NCHW or CHW tensors supplied by the caller;
# the result is clamped back into [0, 1] to avoid out-of-range artifacts.
matched_image = torch.clamp(color_transfer(target_image, source_image), min=0, max=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I shared this function here: pytorch/vision#598 a while back, but I figured I'd post it here as well. It's similar to the histogram matching from my neural-tools project, but it works in PyTorch with all the bells and whistles like autograd.
The inner functions can be eliminated easily for TorchScript / JIT compatibility, and it's fully autograd compatible.