Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save j20232/a4688ab8a68c08cee2f08f3223ea672c to your computer and use it in GitHub Desktop.
Save j20232/a4688ab8a68c08cee2f08f3223ea672c to your computer and use it in GitHub Desktop.
convergence_test_optimized_with_torch_and_np.py
import numpy as np
import matplotlib.pyplot as plt
import torch
def func(x):
    """Evaluate the test function with NumPy.

    The value is the sum of a cubic term centred at 6, a quartic term
    centred at 5 and a squared, shifted sine term.

    Args:
        x: Input supporting arithmetic operators and ``np.sin``
            (the callers in this file pass ``ndarray`` values).

    Returns:
        The function value, element-wise for array input.
    """
    cubic_term = 1e-4 * ((x - 6) ** 3)
    quartic_term = 1e-4 * ((x - 5) ** 4)
    sine_term = 1e-2 * ((np.sin(x * 0.1) - 3) ** 2)
    # Summed left to right, matching a single chained expression.
    return cubic_term + quartic_term + sine_term
def torch_func(x):
    """Evaluate the test function with PyTorch ops so autograd can differentiate it.

    Same formula as the NumPy version: a cubic term centred at 6, a quartic
    term centred at 5 and a squared, shifted sine term.

    Args:
        x: Tensor input.

    Returns:
        Tensor of function values, element-wise.
    """
    cubic_term = 1e-4 * ((x - 6) ** 3)
    quartic_term = 1e-4 * ((x - 5) ** 4)
    sine_term = 1e-2 * ((torch.sin(x * 0.1) - 3) ** 2)
    # Summed left to right, matching a single chained expression.
    return cubic_term + quartic_term + sine_term
class FiniteDiff(torch.autograd.Function):
    """Autograd function that evaluates the NumPy ``func`` in the forward pass
    and estimates its derivative with a central finite difference in the
    backward pass.

    Use it through ``FiniteDiff.apply(tensor)``.
    """

    @staticmethod
    def forward(ctx, input):
        """Compute ``func(input)`` with NumPy.

        Args:
            ctx: Autograd context; the input is saved on it for ``backward``.
            input: Tensor at which to evaluate the function.

        Returns:
            Tensor of function values with the dtype and device of ``input``.
        """
        ctx.save_for_backward(input)
        np_x = input.detach().cpu().numpy()
        y = func(np_x)  # numpy implementation
        # np.asarray: a 0-d input makes func return a NumPy scalar, which
        # torch.from_numpy does not accept.
        # from_numpy always produces a CPU tensor, so move the result back to
        # wherever the input lives.
        return torch.from_numpy(np.asarray(y)).to(device=input.device, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Chain ``grad_output`` with a central-difference estimate of ``func'``.

        Args:
            ctx: Autograd context holding the tensor saved in ``forward``.
            grad_output: Gradient of the loss with respect to the forward output.

        Returns:
            Gradient of the loss with respect to the forward input.
        """
        h = 1e-5
        saved_input, = ctx.saved_tensors
        # Perturb in float64: for a float32 input, h = 1e-5 is close to the
        # spacing of representable values of x, and rounding inside func would
        # otherwise dominate the difference quotient.
        np_x = saved_input.detach().cpu().numpy().astype(np.float64)
        # finite difference (central differences): two evaluations suffice,
        # (f(x + h) - f(x - h)) / (2h).
        grad = (func(np_x + h) - func(np_x - h)) / (2 * h)  # numpy implementation
        grad = torch.from_numpy(np.asarray(grad)).to(
            device=grad_output.device, dtype=grad_output.dtype
        )
        return grad_output * grad
def _objective(f, x):
    """Return ``f(x) ** 4 + f(x) ** 3``, evaluating ``f`` only once."""
    fx = f(x)
    return fx ** 4 + fx ** 3


def _minimize(f, lr=1e-1, epochs=10000, max_esr=10):
    """Drive ``_objective(f, x)`` towards 0 with Adam, starting from x = 0.

    Args:
        f: Differentiable callable taking and returning a tensor.
        lr: Adam learning rate.
        epochs: Maximum number of optimisation steps.
        max_esr: Early-stopping patience; the loop stops once the loss has
            failed to improve on its best value more than this many steps
            in a row.

    Returns:
        The value of x after the last optimiser step, as a NumPy array.
    """
    x = torch.tensor([0.0], requires_grad=True)
    tgt = torch.tensor([0.0])
    optimizer = torch.optim.Adam([x], lr=lr)
    loss_fn = torch.nn.MSELoss()
    best_loss = float("inf")
    esr = 0  # consecutive steps without improvement
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(_objective(f, x), tgt)
        loss.backward()
        optimizer.step()
        loss_val = loss.item()
        if loss_val < best_loss:
            best_loss = loss_val
            esr = 0
        else:
            esr += 1
            if esr > max_esr:
                break
    return x.detach().cpu().numpy()


if __name__ == "__main__":
    # Reference curve of the objective over the plotted range.
    np_x = np.linspace(4, 10, 100)
    np_y = _objective(func, np_x)

    # auto diff: red
    ad_v = _minimize(torch_func)

    # finite diff: yellow
    fd_v = _minimize(FiniteDiff.apply)

    plt.title("Convergence test")
    plt.plot(np_x, np_y)
    plt.plot(ad_v, _objective(func, ad_v), marker="o", color="red", label="Automatic differentiation")
    plt.plot(fd_v, _objective(func, fd_v), marker="*", color="yellow", label="Finite difference")
    plt.legend()
    plt.xlabel("x")
    plt.ylabel("f(x)")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment