Created
May 17, 2022 10:25
-
-
Save louity/f51ec9c5ffea7dbf13e6f182e73d29d4 to your computer and use it in GitHub Desktop.
Type-II DCT and DST with PyTorch. Note that the inverse DCT-II is the DCT-III up to a normalizing constant, and the inverse DST-II is likewise the DST-III up to a normalizing constant.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import scipy.fftpack | |
import numpy as np | |
# Demo configuration: print formatting, transform length, a random test
# signal, and the twiddle-factor vectors passed to the forward transforms
# (exp_vec_1) and the inverse transforms (exp_vec_2) below.
np.set_printoptions(precision=4, linewidth=200)

N = 8
# Random float64 test signal; its length follows N (previously hard-coded to 8).
x = torch.empty(N, dtype=torch.float64).normal_()

# The phase ramp is built in float64 so the twiddles come out as complex128.
# Multiplying an integer arange by a Python complex promotes to the default
# complex dtype (complex64), which limits the float64 transforms to
# single-precision accuracy.
exp_vec_1 = 2 * torch.exp(-1j * torch.pi * torch.arange(N, dtype=torch.float64) / (2 * N))  # 2*exp(-i*pi*k/(2N))
exp_vec_2 = torch.exp(1j * torch.pi * torch.arange(N, dtype=torch.float64) / (2 * N))       # exp(+i*pi*k/(2N))
def dctII_pt(x, exp_vec=None):
    """Unnormalized type-II DCT over the last dimension (matches scipy.fft.dct(x, type=2)).

    Uses Makhoul's algorithm: reorder the samples, take one N-point FFT,
    multiply by twiddle factors and keep the real part.

    Args:
        x: real tensor of shape (..., N); the transform runs over the last dim.
        exp_vec: twiddle factors 2*exp(-i*pi*k/(2N)), k = 0..N-1. When None
            they are computed from x's length, dtype and device.

    Returns:
        Real tensor of shape (..., N) with the DCT-II coefficients.
    """
    N = x.shape[-1]
    if exp_vec is None:
        real_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        k = torch.arange(N, dtype=real_dtype, device=x.device)
        exp_vec = 2 * torch.exp(-1j * torch.pi * k / (2 * N))
    # Even-indexed samples in order, then odd-indexed samples reversed.
    # Taking x[..., 1::2] (instead of flip(x)[::2]) keeps the length at N for
    # odd N too, and `...` makes the reorder act on the last dimension only.
    v = torch.cat([x[..., ::2], torch.flip(x[..., 1::2], dims=(-1,))], dim=-1)
    V = torch.fft.fft(v, dim=-1)
    return (V * exp_vec).real
def dstII_pt(x, exp_vec=None):
    """Unnormalized type-II DST over the last dimension (matches scipy.fft.dst(x, type=2)).

    Uses DST-II(x)[k] == DCT-II(x * (-1)**n)[N-1-k]: the odd-indexed samples
    are negated during the Makhoul reorder, and the DCT-II output is reversed.

    Args:
        x: real tensor of shape (..., N); the transform runs over the last dim.
        exp_vec: the forward twiddle factors 2*exp(-i*pi*k/(2N)), k = 0..N-1
            (the same vector dctII_pt uses). When None they are computed from
            x's length, dtype and device.

    Returns:
        Real tensor of shape (..., N) with the DST-II coefficients.
    """
    N = x.shape[-1]
    if exp_vec is None:
        real_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        k = torch.arange(N, dtype=real_dtype, device=x.device)
        exp_vec = 2 * torch.exp(-1j * torch.pi * k / (2 * N))
    # Makhoul reorder of x * (-1)**n: even samples keep their sign, and the
    # reversed odd samples are negated. x[..., 1::2] keeps the length N for odd N.
    v = torch.cat([x[..., ::2], -torch.flip(x[..., 1::2], dims=(-1,))], dim=-1)
    V = torch.fft.fft(v, dim=-1)
    return torch.flip((V * exp_vec).real, dims=(-1,))
def idctII_pt(x, exp_vec=None):
    """Invert the unnormalized DCT-II over the last dimension.

    Equivalent to scipy.fft.idct(x, type=2), i.e. DCT-III / (2N). This is the
    inverse of Makhoul's algorithm: rebuild the FFT of the reordered signal
    from X[k] and X[N-k], take one inverse FFT, then undo the reorder.

    Args:
        x: real tensor of shape (..., N) holding DCT-II coefficients.
        exp_vec: inverse twiddle factors exp(+i*pi*k/(2N)), k = 0..N-1
            (entry 0 is not used). When None they are computed from x's
            length, dtype and device.

    Returns:
        Tensor with the same shape and dtype as x holding the reconstructed signal.
    """
    N = x.shape[-1]
    if exp_vec is None:
        real_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        k = torch.arange(N, dtype=real_dtype, device=x.device)
        exp_vec = torch.exp(1j * torch.pi * k / (2 * N))
    # x_rev[k-1] == X[N-k] for k = 1..N-1.
    x_rev = torch.flip(x[..., 1:], dims=(-1,))
    # V[k] = exp(i*pi*k/(2N)) * (X[k] - i*X[N-k]) / 2, and V[0] = X[0] / 2.
    tail = exp_vec[1:] * (x[..., 1:] - 1j * x_rev)
    # Cast the real DC term explicitly rather than relying on torch.cat type promotion.
    V = torch.cat([x[..., :1].to(tail.dtype), tail], dim=-1) / 2
    v = torch.fft.ifft(V, dim=-1).real
    # Undo the reorder: even samples come from the first ceil(N/2) entries and
    # odd samples from the last floor(N/2) entries, in reverse order.
    y = torch.zeros_like(x)
    y[..., ::2] = v[..., :(N + 1) // 2]
    y[..., 1::2] = torch.flip(v, dims=(-1,))[..., :N // 2]
    return y
def idstII_pt(x, exp_vec):
    """Invert dstII_pt over the last dimension.

    Uses DST-II(x)[k] == DCT-II(x * (-1)**n)[N-1-k]: reverse the coefficients,
    apply the inverse DCT-II, then undo the alternating sign.

    Args:
        x: DST-II coefficients; the transform runs over the last dimension.
        exp_vec: inverse twiddle factors, forwarded unchanged to idctII_pt
            (exp_vec_2, i.e. exp(+i*pi*k/(2N)), in this file's demo).

    Returns:
        The reconstructed signal, shaped like idctII_pt's output for x.
    """
    N = x.shape[-1]
    x_ = torch.flip(x, dims=(-1,))
    idct_x_ = idctII_pt(x_, exp_vec)
    # Build the (-1)**n pattern on x's device. A CPU tensor of shape (N,)
    # multiplied with a CUDA tensor raises a device-mismatch error.
    signs = (-1) ** torch.arange(N, device=x.device)
    return idct_x_ * signs
# Demo: compare the PyTorch DCT-II / DST-II against SciPy's reference
# implementations, then check that the inverse transforms reconstruct x.
# scipy.fft is imported explicitly because only scipy.fftpack is imported at
# the top of the file, and that does not by itself guarantee scipy.fft is bound.
import scipy.fft

print(f'pytorch dct-II: {dctII_pt(x, exp_vec_1).cpu().numpy()}')
print(f'scipy dct-II: {scipy.fft.dct(x.cpu().numpy(), type=2)}\n')

print(f'pytorch dst-II: {dstII_pt(x, exp_vec_1).cpu().numpy()}')
print(f'scipy dst-II: {scipy.fft.dst(x.cpu().numpy(), type=2)}\n')

# Round trips: both lines should reproduce the printed input x.
print('x :', x.cpu().numpy())
print('pytorch idctII(dctII(x)):', idctII_pt(dctII_pt(x, exp_vec_1), exp_vec_2).cpu().numpy())
print('pytorch idstII(dstII(x)):', idstII_pt(dstII_pt(x, exp_vec_1), exp_vec_2).cpu().numpy())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment