Created
September 9, 2024 03:04
-
-
Save minh-nguyenhoang/2e06aef032caa2d525822230cbef66cc to your computer and use it in GitHub Desktop.
Color transfer using reinhard's algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import numpy as np | |
RGB2LMS = np.array([[0.3811, 0.5783, 0.0402], | |
[0.1967, 0.7244, 0.0782], | |
[0.0241, 0.1288, 0.8444]]).astype(np.float32) | |
LMS2LAB = np.array([[1/np.sqrt(3), 1/np.sqrt(3), 1/np.sqrt(3)], | |
[1/np.sqrt(6), 1/np.sqrt(6), -2/np.sqrt(6)], | |
[1/np.sqrt(2), -1/np.sqrt(2), 0]]).astype(np.float32) | |
INV_LMS2LAB = np.linalg.inv(LMS2LAB) | |
INV_RGB2LMS = np.linalg.inv(RGB2LMS) | |
def cvt_RGB2LAB(img): | |
img = np.maximum(img, 1.) | |
img_lms = cv2.transform(img, RGB2LMS) | |
# img = np.maximum(img, 1./255) | |
img_lms = np.log(img_lms).astype(np.float32)/np.log(10) | |
img_lab = cv2.transform(img_lms, LMS2LAB) | |
return img_lab | |
def cvt_LAB2RGB(img): | |
img_lms = cv2.transform(img, INV_LMS2LAB) | |
# img_lms = np.power(10, img_lms) | |
img_lms = np.exp(img_lms).astype(np.float32) | |
img_lms = np.power(img_lms, np.log(10.0)) | |
img_rgb = cv2.transform(img_lms, INV_RGB2LMS) | |
return img_rgb | |
def channel_conditioning(t_channel, s_channel): | |
wval = float(0.25) | |
_, mask = cv2.threshold(s_channel, 0, 1, cv2.THRESH_BINARY) | |
mask = mask.astype(np.uint8) | |
s_mean_U = np.sum(s_channel*mask, axis= (0,1))/np.maximum(np.sum(mask, axis= (0,1)), 1) ## mean for element above 0 | |
w_U = np.exp(-s_channel * wval/s_mean_U) | |
w_U = (1 - w_U) * (1 - w_U) | |
w_mean = np.sum(w_U*mask, axis= (0,1))/np.maximum(np.sum(mask, axis= (0,1)), 1) ## mean for element above 0 | |
channel_U = np.power(s_channel, 4) | |
s_mean_U = np.sum(channel_U*w_U*mask, axis= (0,1))/np.maximum(np.sum(mask, axis= (0,1)), 1)/w_mean | |
inv_mask = 1 - mask | |
s_mean_L = np.sum(s_channel*inv_mask, axis= (0,1))/np.maximum(np.sum(inv_mask, axis= (0,1)), 1) ## mean for element above 0 | |
w_L = np.exp(-s_channel * wval/s_mean_L) | |
w_L = (1 - w_L) * (1 - w_L) | |
w_mean = np.sum(w_L*inv_mask, axis= (0,1))/np.maximum(np.sum(inv_mask, axis= (0,1)), 1) ## mean for element above 0 | |
channel_L = np.power(s_channel, 4) | |
s_mean_L = np.sum(channel_L*w_L*inv_mask, axis= (0,1))/np.maximum(np.sum(inv_mask, axis= (0,1)), 1)/w_mean | |
_, mask = cv2.threshold(t_channel, 0, 1, cv2.THRESH_BINARY) | |
mask = mask.astype(np.uint8) | |
t_mean_U = np.sum(t_channel*mask, axis= (0,1))/np.maximum(np.sum(mask, axis= (0,1)), 1) ## mean for element above 0 | |
w_U = np.exp(-t_channel * wval/t_mean_U) | |
w_U = (1 - w_U) * (1 - w_U) | |
w_mean = np.sum(w_U*mask, axis= (0,1))/np.maximum(np.sum(mask, axis= (0,1)), 1) + 1e-6 ## mean for element above 0 | |
channel_U = np.power(t_channel, 4) | |
t_mean_U = np.sum(channel_U*w_U*mask, axis= (0,1))/np.maximum(np.sum(mask, axis= (0,1)), 1)/w_mean | |
inv_mask = 1 - mask | |
t_mean_L = np.sum(t_channel*inv_mask, axis= (0,1))/np.maximum(np.sum(inv_mask, axis= (0,1)), 1) ## mean for element above 0 | |
w_L = np.exp(-t_channel * wval/t_mean_L) | |
w_L = (1 - w_L) * (1 - w_L) | |
w_mean = np.sum(w_L*inv_mask, axis= (0,1))/np.maximum(np.sum(inv_mask, axis= (0,1)), 1) ## mean for element above 0 | |
channel_L = np.power(t_channel, 4) | |
t_mean_L = np.sum(channel_L*w_L*inv_mask, axis= (0,1))/np.maximum(np.sum(inv_mask, axis= (0,1)), 1)/w_mean | |
k = np.sqrt(np.sqrt(s_mean_U/t_mean_U)) | |
t_channel_U = (1 + w_U*(k-1))*t_channel | |
k = np.sqrt(np.sqrt(s_mean_L/t_mean_L)) | |
t_channel_L = (1 + w_L*(k-1))*t_channel | |
t_chanel_normalized: np.ndarray = t_channel_U * mask + t_channel_L * inv_mask | |
# t_mean, t_std = cv2.meanStdDev(t_chanel_normalized) | |
# t_mean, t_std = t_mean.reshape(-1), t_std.reshape(-1) | |
t_mean, t_std = t_chanel_normalized.mean((0,1)), t_chanel_normalized.std((0,1)) | |
t_chanel_normalized = (t_chanel_normalized - t_mean)/t_std | |
return t_chanel_normalized | |
def adjust_covariance(t_lab: np.ndarray, s_lab: np.ndarray, cross_covariance_limit = 0.5): | |
t_lab = t_lab.copy() | |
if cross_covariance_limit != 0: | |
tcrosscorr = np.mean(t_lab[...,1] * t_lab[...,2], axis= (0,1)) | |
scrosscorr = np.mean(s_lab[...,1] * s_lab[...,2], axis= (0,1)) | |
W1 = 0.5 * np.sqrt((1 + scrosscorr) / (1 + tcrosscorr)) + \ | |
0.5 * np.sqrt((1 - scrosscorr) / (1 - tcrosscorr)) | |
W2 = 0.5 * np.sqrt((1 + scrosscorr) / (1 + tcrosscorr)) - \ | |
0.5 * np.sqrt((1 - scrosscorr) / (1 - tcrosscorr)) | |
if abs(W2) > cross_covariance_limit * abs(W1): | |
W2 = np.copysign(cross_covariance_limit * W1, W2) | |
norm = 1.0 / np.sqrt(W1**2 + W2**2 + 2 * W1 * W2 * tcrosscorr) | |
W1 *= norm | |
W2 *= norm | |
z1 = t_lab[...,1].copy() | |
t_lab[...,1] = W1 * z1 + W2 * t_lab[...,2] | |
t_lab[...,2] = W1 * t_lab[...,2] + W2 * z1 | |
return t_lab | |
def core_processing(tgt, src, cross_covariance_limit = 0.5, reshaping_iteration = 1, shader_val = 0.5): | |
max_val = np.max(tgt) | |
tgtf = cvt_RGB2LAB(tgt) | |
srcf = cvt_RGB2LAB(src) | |
tgt_mean, tgt_std = cv2.meanStdDev(tgtf) | |
src_mean, src_std = cv2.meanStdDev(srcf) | |
tgt_mean, tgt_std = tgt_mean.reshape(-1).astype(np.float32), tgt_std.reshape(-1).astype(np.float32) | |
src_mean, src_std = src_mean.reshape(-1).astype(np.float32), src_std.reshape(-1).astype(np.float32) | |
t_lab = (tgtf - tgt_mean)/tgt_std | |
s_lab = (srcf - src_mean)/src_std | |
for j in range(reshaping_iteration,(reshaping_iteration + 1)//2, -1): | |
t_lab[...,1:] = channel_conditioning(t_lab[...,1:], s_lab[...,1:]) | |
# print(t_lab.mean((0,1)), t_lab.std((0,1))) | |
t_lab = adjust_covariance(t_lab, s_lab, cross_covariance_limit) | |
for j in range((reshaping_iteration + 1)//2): | |
t_lab[...,1:] = channel_conditioning(t_lab[...,1:], s_lab[...,1:]) | |
# print(t_lab.mean((0,1)), t_lab.std((0,1))) | |
src_mean, src_std = src_mean.copy(), src_std.copy() | |
src_mean[0] = shader_val*src_mean[0] + (1-shader_val)*tgt_mean[0] | |
src_std[0] = shader_val*src_std[0] + (1-shader_val)*tgt_std[0] | |
t_lab = t_lab * src_std + src_mean | |
res_rgb = np.clip(cvt_LAB2RGB(t_lab), 0, max_val).astype(np.uint8) | |
return res_rgb | |
def adjust_saturation(img, origin_img, saturation_val = -1): | |
# Convert images from BGR to HSV color space | |
img: np.ndarray = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) | |
origin_img: np.ndarray = cv2.cvtColor(origin_img, cv2.COLOR_RGB2HSV) | |
if saturation_val < 0: | |
# Calculate saturation_val as the ratio of the max saturation value | |
# in the original image to the max value in the processed image | |
amax1 = np.max(img[...,1]) | |
amax2 = np.max(origin_img[...,1]) | |
saturation_val = amax2 / amax1 | |
if saturation_val != 1.: | |
# Compute weighted mix of the processed target and original saturation channels | |
# origin_img[...,1] = cv2.addWeighted(img[...,1], saturation_val, origin_img[...,1], 1 - saturation_val, 0.0) | |
origin_img[...,1] = img[...,1]*(saturation_val) + origin_img[...,1]*(1-saturation_val) | |
# Create a mask where the processed image's saturation exceeds the original target image's saturation | |
mask = cv2.threshold(img[...,1] - origin_img[...,1], 0, 1, cv2.THRESH_BINARY)[1] | |
mask = mask.astype(np.uint8) | |
# plt.imshow(mask * 255, cmap = 'gray') | |
# plt.show() | |
# Create the modified reference saturation channel | |
origin_img[...,1] = origin_img[...,1] * mask + img[...,1]*(1-mask) | |
# Match the mean and standard deviation of the processed image's saturation channel | |
# to the modified reference saturation channel | |
# tmean, tdev = cv2.meanStdDev(img[...,1]) | |
# tmpmean, tmpdev = cv2.meanStdDev(origin_img[...,1]) | |
tmean, tdev = img[...,1].mean((0,1)), img[...,1].std((0,1)) | |
tmpmean, tmpdev = origin_img[...,1].mean((0,1)), origin_img[...,1].std((0,1)) | |
tmp_ = img[...,1].copy().astype(np.float32) | |
img[...,1] = np.clip((tmp_ - tmean) / tdev * tmpdev + tmpmean, 0, 255).astype(np.uint8) | |
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) | |
return img | |
def full_shading(img, ori_img, src_img, extra_shading = True, shader_val = 0.5): | |
# Matches the grey shade distribution of the modified target image | |
# to that of a notional shader image which is a linear combination | |
# of the original target and source image as determined by 'shader_val'. | |
if extra_shading: | |
# Convert images to grayscale | |
greyt = np.dot(ori_img[..., :3], [0.2989, 0.5870, 0.1140]) | |
greyp = np.dot(img[..., :3], [0.2989, 0.5870, 0.1140]) | |
greys = np.dot(src_img[..., :3], [0.2989, 0.5870, 0.1140]) | |
# Standardize the grayscale images for the source and target | |
smean, sdev = np.mean(greys), np.std(greys) | |
tmean, tdev = np.mean(greyt), np.std(greyt) | |
greyt = (greyt - tmean) / tdev | |
# Rescale the standardized grayscale target image | |
greyt = greyt * (shader_val * sdev + (1.0 - shader_val) * tdev) \ | |
+ shader_val * smean + (1.0 - shader_val) * tmean | |
# Ensure there are no zero or negative values in the grayscale images | |
min_val = 1. #/ 255.0 | |
greyp = np.maximum(greyp, min_val) # Guard against zero divide | |
greyt = np.maximum(greyt, 0.0) # Guard against negative values | |
# Rescale each color channel of the processed image | |
img = np.clip(img / greyp[..., None] * greyt[..., None], 0, 255).astype(np.uint8) | |
return img | |
def final_adjustment(img, ori_img, tint_val = 1., modified_val = 1.): | |
# Implements a change to the tint of the final image and | |
# its degree of modification if a change is specified. | |
# If 100% tint is not specified, compute a weighted average | |
# of the processed image and its grayscale representation. | |
if tint_val != 1.0: | |
# Convert to grayscale | |
grey = np.dot(img[..., :3], [0.2989, 0.5870, 0.1140])[..., None] | |
# Apply the tint by adjusting each channel | |
img = tint_val * img + (1.0 - tint_val) * grey | |
# If 100% image modification is not specified, compute a weighted average | |
# of the processed image and the original target image. | |
if modified_val != 1.0: | |
img = modified_val * img + (1.0 - modified_val) * ori_img | |
img = np.clip(img, 0, 255).astype(np.uint8) | |
return img | |
if __name__ == "__main__": | |
import matplotlib.pyplot as plt | |
_src = cv2.imread(r'Images\Vase_source.jpg')[..., ::-1].copy() | |
_tgt = cv2.imread(r'Images\Vase_target.jpg')[..., ::-1].copy() | |
reshaping_iteration = 1 | |
cross_covariance_limit = .5 | |
shader_val = .5 | |
saturation_val = -1. | |
extra_shading = True | |
tint_val = 1. | |
modified_val = 1. | |
view_row = False | |
_core = core_processing(_tgt,_src, reshaping_iteration= reshaping_iteration, cross_covariance_limit= cross_covariance_limit, shader_val= shader_val) | |
_res = adjust_saturation(_core, _tgt, saturation_val= saturation_val) | |
_shaded = full_shading(_res, _tgt, _src, extra_shading= extra_shading, shader_val = shader_val) | |
_final = final_adjustment(_shaded, _tgt, tint_val= tint_val, modified_val= modified_val) | |
if view_row: | |
fig, ax = plt.subplots(1,6, figsize = (30,5)) | |
fig.add_artist(plt.Line2D([0.375, 0.375], [0.1,0.9], transform=fig.transFigure, color="black")) | |
ax[0].imshow(_tgt) | |
ax[0].set_title("Target Image") | |
ax[0].axis("off") | |
ax[1].imshow(_src) | |
ax[1].set_title("Palette Image") | |
ax[1].axis("off") | |
ax[2].imshow(_core) | |
ax[2].set_title("After core processing") | |
ax[2].axis("off") | |
ax[3].imshow(_res) | |
ax[3].set_title("After saturation processing") | |
ax[3].axis("off") | |
ax[4].imshow(_shaded) | |
ax[4].set_title("After full shading") | |
ax[4].axis("off") | |
ax[5].imshow(_final) | |
ax[5].set_title("Final Image") | |
ax[5].axis("off") | |
# fig.show() | |
else: | |
fig, ax = plt.subplots(2,3, figsize = (20,10)) | |
fig.add_artist(plt.Line2D([0.375, 0.375], [0.1,0.9], transform=fig.transFigure, color="black")) | |
ax[0,0].imshow(_tgt) | |
ax[0,0].set_title("Target Image") | |
ax[0,0].axis("off") | |
ax[1,0].imshow(_src) | |
ax[1,0].set_title("Palette Image") | |
ax[1,0].axis("off") | |
ax[0,1].imshow(_core) | |
ax[0,1].set_title("After core processing") | |
ax[0,1].axis("off") | |
ax[0,2].imshow(_res) | |
ax[0,2].set_title("After saturation processing") | |
ax[0,2].axis("off") | |
ax[1,1].imshow(_shaded) | |
ax[1,1].set_title("After full shading") | |
ax[1,1].axis("off") | |
ax[1,2].imshow(_final) | |
ax[1,2].set_title("Final Image") | |
ax[1,2].axis("off") | |
# fig.show() | |
plt.show() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment