Gradient of the DFT
# First, second, and n-th derivatives of the FFT: https://math.mit.edu/~stevenj/fft-deriv.pdf
# Derivative of the DFT: https://math.stackexchange.com/a/1658364/684858
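# Added note (not in the original gist): the DFT is linear,
# X_k = sum_n x_n * exp(-2j*pi*k*n/N), so its Jacobian is just the DFT
# matrix, and backpropagating an upstream gradient g through the transform
# applies the conjugate transpose: N * ifft(g) for the unnormalized DFT.
# Both autodiff checks below rely on exactly this rule.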
import tensorflow as tf
import numpy as np
import torch
from torch.autograd import gradcheck, Variable

# Magnitude (power-spectrum) loss: |FFT(x)|**2
def mag(x):
    print("mag in:", x.shape)
    # tf.fft expects complex input, so pair the real signal with zero imaginary parts
    x = tf.complex(x, tf.zeros(x.shape, dtype=tf.float64))
    cl = tf.fft(x)
    mag = tf.abs(cl) ** 2
    print("mag out:", mag.shape)
    return mag
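
# Added sketch (not in the original gist; vjp_rule_check is a hypothetical
# helper): verify the backprop rule from the note above. For y = fft(x) and
# upstream gradient g = dL/dy, the vector-Jacobian product is
# W^H g = N * ifft(g), with W the (unnormalized) DFT matrix.
def vjp_rule_check(n=8):
    g = np.random.random(n) + 1j * np.random.random(n)  # upstream gradient dL/dy
    k = np.arange(n)
    W = np.exp(-2j * np.pi * np.outer(k, k) / n)  # DFT matrix W[k, m] = exp(-2j*pi*k*m/n)
    vjp_explicit = W.conj().T @ g  # exact vector-Jacobian product W^H g
    vjp_ifft = n * np.fft.ifft(g)  # the same product computed via the inverse FFT
    print("VJP rule holds:", np.allclose(vjp_explicit, vjp_ifft))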

def tensorflow_check():
    # Eager execution for the TensorFlow version of the mag loss (TF 1.x API)
    tf.enable_eager_execution()
    # tf.fft transforms the innermost axis, so the signal length (256) goes last
    signal = np.random.random((4, 1, 256))
    # Check that the gradient can be computed through tf.fft
    tfe = tf.contrib.eager
    grad = tfe.gradients_function(mag)
    k = grad(signal)[0].numpy()
    print(k)

def pytorch_check():
    # Old (pre-1.8) torch.fft API: the last dim of size 2 holds (real, imag)
    inx = Variable(torch.randn(1, 128, 2).double(), requires_grad=True)
    # gradcheck compares the analytic gradient with finite differences (signal_ndim=2)
    test = gradcheck(torch.fft, (inx, 2))
    print(test)

if __name__ == "__main__":
    tensorflow_check()
    pytorch_check()
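As a sanity check on what the autodiff checks above should return, here is a minimal NumPy sketch (an addition, not part of the original gist; numpy_analytic_check is an illustrative name). By Parseval's theorem, sum_k |X_k|^2 = N * sum_n |x_n|^2 for the unnormalized DFT, so the gradient of the summed power spectrum of a real signal x is simply 2*N*x; the sketch confirms this against central finite differences.

import numpy as np

def numpy_analytic_check(n=256, eps=1e-6):
    x = np.random.random(n)
    # Numerical gradient of L(x) = sum(|fft(x)|**2) via central differences
    num_grad = np.zeros(n)
    for i in range(n):
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        num_grad[i] = (np.sum(np.abs(np.fft.fft(xp)) ** 2)
                       - np.sum(np.abs(np.fft.fft(xm)) ** 2)) / (2 * eps)
    # Analytic gradient from Parseval: d/dx sum|X|^2 = 2 * N * x for real x
    ana_grad = 2 * n * x
    print("analytic == numerical:", np.allclose(num_grad, ana_grad, atol=1e-3))

numpy_analytic_check()

Because L is quadratic in x, the central difference is exact up to floating-point error, so a tight tolerance holds.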