Skip to content

Instantly share code, notes, and snippets.

@pvskand
Last active May 2, 2019 05:18
Show Gist options
  • Save pvskand/11d42165e215ef1150644ee057ae97bc to your computer and use it in GitHub Desktop.
Save pvskand/11d42165e215ef1150644ee057ae97bc to your computer and use it in GitHub Desktop.
Zero Cross Correlation Layer
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import numpy
import cv2
import torch.optim as optim
from theano.tensor import *
import numpy
import theano
from theano import Op, Apply
import theano.tensor as T
from theano.gradient import grad_not_implemented
from theano.gradient import grad_undefined
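'''
Patch extraction settings and helper.
im2col (below) returns, for every pixel position of the zero-padded images,
the flattened ws x ws patch around that position. Positions whose patch would
reach outside the original image are covered by the constant zero padding
that im2col applies before extracting patches.
'''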
ws = 3  # side length of the square correlation window (patch), in pixels
ns = 10  # side length of the shift neighbourhood; ns * ns offsets are evaluated
# Zero padding added on each spatial border by im2col: half the neighbourhood
# plus half the window, so the padded extent is size + ns + (ws - 1).
pad = ns // 2 + (ws - 1) // 2  # == 6 for ws = 3, ns = 10
gpu_id = 0  # CUDA device index passed to .cuda() in zncc_compute
def _sliding_patches(plane_stack, dim):
    """Flatten every ws x ws window of a (dim, H, W) array.

    Returns an array of shape (1, dim * ws * ws, H - ws + 1, W - ws + 1).
    The feature axis is ordered channel first, then row offset, then column
    offset inside the window (row-major patch flattening).
    """
    if plane_stack.shape[0] != dim:
        raise ValueError("expected %d channels, got %d"
                         % (dim, plane_stack.shape[0]))
    out_h = plane_stack.shape[1] - ws + 1
    out_w = plane_stack.shape[2] - ws + 1
    patches = np.empty((dim, ws * ws, out_h, out_w), dtype=plane_stack.dtype)
    for dy in range(ws):
        for dx in range(ws):
            # One shifted view per window offset fills one feature plane.
            patches[:, dy * ws + dx] = plane_stack[:, dy:dy + out_h,
                                                   dx:dx + out_w]
    return patches.reshape(1, dim * ws * ws, out_h, out_w)


def im2col(img1, img2, b, dim):
    """Extract all ws x ws patches of batch element `b` from two image batches.

    Each image is zero-padded by `pad` pixels on every spatial border and
    every ws x ws window of the padded image is flattened into the feature
    axis of the result.

    Args:
        img1: torch tensor of shape (batch, dim, H, W).
        img2: torch tensor with the same channel and spatial sizes as img1.
        b: index of the batch element to process.
        dim: number of channels of the images.

    Returns:
        Tuple (out1, out2) of float32 numpy arrays, each of shape
        (1, dim * ws * ws, H + 2 * pad - ws + 1, W + 2 * pad - ws + 1),
        i.e. (1, 9 * dim, H + ns, W + ns) for ws = 3, ns = 10, pad = 6.

    Raises:
        ValueError: if the two inputs differ in channel or spatial size, or
            if `dim` does not match the channel count.
    """
    if tuple(img1.size())[1:] != tuple(img2.size())[1:]:
        raise ValueError("img1 and img2 must have matching channel and "
                         "spatial sizes")

    border = ((0, 0), (pad, pad), (pad, pad))
    padded = []
    for img in (img1, img2):
        # Only the requested batch element is copied and padded; float32
        # matches the buffer type the patches were previously staged in.
        # NOTE(review): the previous Theano-based extraction returned
        # theano's floatX dtype; confirm float32 is what callers expect.
        planes = img.cpu().numpy()[b].astype(np.float32)
        padded.append(np.pad(planes, pad_width=border, mode='constant',
                             constant_values=0))

    # Patches are gathered per channel so that multi-channel inputs keep each
    # pixel's features together (a flat reshape of the channel-first patch
    # list would mix neighbouring pixels into one feature vector).
    out1 = _sliding_patches(padded[0], dim)
    out2 = _sliding_patches(padded[1], dim)

    # Debug trace: patch shape, padded batch shape, padded single-item shape.
    # Single-argument print keeps the output identical on Python 2 and 3.
    batch_shape = (img1.size()[0],) + padded[0].shape
    item_shape = (1,) + padded[0].shape
    print("%s shape %s %s" % (out1.shape, batch_shape, item_shape))
    return out1, out2
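''' Zero-mean Normalized Cross-Correlation (ZNCC) between two patches P and Q
of N values each:
    ZNCC(P, Q) = sum_i (P_i - mean(P)) * (Q_i - mean(Q)) / (N * std(P) * std(Q))
zncc_compute_pixel (below) normalises by (ws*ws - 1) instead of N and adds a
small epsilon (10e-13) to the numerator terms and to the denominator.
'''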
def zncc_compute_pixel(img1, img2, b, dim):
    """Compute the ns * ns ZNCC maps for batch element `b`.

    The per-pixel correlation between the ws x ws patches of the two padded
    images is computed once; output channel ns * i + j is that correlation
    map sampled at row offset i and column offset j.

    Args:
        img1: torch tensor of shape (batch, dim, H, W).
        img2: torch tensor of the same shape as img1.
        b: index of the batch element to process.
        dim: number of channels of the images.

    Returns:
        CPU float tensor of shape (ns * ns, H, W).
    """
    # Each blob is 1 x (ws*ws*dim) x (H + ns) x (W + ns).
    blob1, blob2 = im2col(img1, img2, b, dim)
    height, width = img1.size()[2], img1.size()[3]
    eps = 10e-13

    # Per-pixel patch statistics over the feature axis.
    centred1 = blob1 - np.average(blob1, axis=1)[:, np.newaxis]
    centred2 = blob2 - np.average(blob2, axis=1)[:, np.newaxis]
    std1 = np.std(blob1, axis=1)
    std2 = np.std(blob2, axis=1)

    # NOTE(review): eps is added to every product before the sum, so windows
    # with zero variance evaluate to roughly ws*ws*dim rather than 0, and the
    # normaliser is ws*ws - 1 while np.std is the population deviation over
    # ws*ws*dim values, so results are not bounded by 1. Both conventions are
    # kept unchanged to preserve the existing numeric output.
    numerator = (centred1 * centred2 + eps).sum(axis=1)
    corr = numerator / ((ws * ws - 1) * std1 * std2 + eps)
    corr = corr[0].astype(np.float32)  # (H + ns) x (W + ns)

    # NOTE(review): both images are shifted by the same (i, j), so every
    # channel is a shifted view of one correlation map and is computed once
    # here instead of once per shift. If img2 was meant to move relative to a
    # fixed img1 patch, the correlation itself has to change - confirm intent.
    stacked = np.empty((ns * ns, height, width), dtype=np.float32)
    for i in range(ns):
        for j in range(ns):
            stacked[ns * i + j] = corr[i:height + i, j:width + j]
    return torch.from_numpy(stacked)
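''' Compute the ZNCC volume for a batch of image pairs: for every batch
element, zncc_compute_pixel is evaluated with ws x ws patches and the results
are collected into a (batch, ns*ns, H, W) tensor. '''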
def zncc_compute(img1, img2):
    """Compute the ZNCC volume between two image batches.

    Args:
        img1: torch tensor of shape (batch, dim, H, W).
        img2: torch tensor of the same shape as img1.

    Returns:
        Float tensor of shape (batch, ns * ns, H, W). It lives on CUDA device
        `gpu_id` when CUDA is available, otherwise on the CPU.

    Raises:
        ValueError: if img1 and img2 differ in shape.
    """
    if tuple(img1.size()) != tuple(img2.size()):
        raise ValueError("img1 and img2 must have the same shape")

    batch, dim, Imx, Imy = img1.size()
    zncc_matrix = torch.zeros((batch, ns * ns, Imx, Imy))
    # Use the GPU when there is one instead of failing on CPU-only hosts.
    if torch.cuda.is_available():
        zncc_matrix = zncc_matrix.cuda(gpu_id)

    for b in range(batch):
        zncc_matrix[b, :] = zncc_compute_pixel(img1, img2, b, dim)

    # Debug trace; the single-argument form prints the same text on
    # Python 2 and Python 3.
    print("%s final output" % (zncc_matrix.size(),))
    return zncc_matrix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment