@elumixor, last active May 25, 2020
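A PyTorch implementation of the conjugate gradient method for solving the linear system A x = b, where A is supplied as a function computing the matrix-vector product rather than as an explicit matrix.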
import torch


def conjugate_gradient(A, b, delta=0., max_iterations=float('inf')):
    # Solves A x = b with the conjugate gradient method. A is a callable
    # returning the matrix-vector product A(p) for a direction p.
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A(x); x starts at zero, so r = b
    p = b.clone()  # first search direction
    i = 0
    while i < max_iterations:
        AVP = A(p)  # matrix-vector product A @ p
        dot_old = r @ r
        alpha = dot_old / (p @ AVP)  # step size along the search direction
        x_new = x + alpha * p
        if (x - x_new).norm() <= delta:  # stop once the update is small enough
            return x_new
        i += 1
        r = r - alpha * AVP  # update the residual
        beta = (r @ r) / dot_old  # Fletcher-Reeves coefficient
        p = r + beta * p  # next direction, conjugate to the previous ones
        x = x_new
    return x
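A minimal usage sketch, not part of the original gist: the matrix M, the right-hand side b, and the tolerance passed as delta are illustrative choices. A is given as a closure computing the matrix-vector product, so the full matrix never has to be materialized.

M = torch.tensor([[4., 1.],
                  [1., 3.]])  # symmetric positive-definite, as CG requires
b = torch.tensor([1., 2.])

x = conjugate_gradient(lambda p: M @ p, b, delta=1e-6, max_iterations=10)
print(x)          # approximately tensor([0.0909, 0.6364])
print(M @ x - b)  # residual, close to zero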