@bridgesign
Last active February 29, 2024 14:37
BiCGSTAB (BCGSTAB) PyTorch implementation with GPU support
import torch
import warnings


class BiCGSTAB():
    """
    A PyTorch implementation of BiCGSTAB (BCGSTAB), a stabilized variant of
    the biconjugate gradient (BiCG) method, first published by H. A. van der
    Vorst, for solving the system ``Ax = b``.

    Example:
        solver = BiCGSTAB(Ax_gen)
        solver.solve(b, x=initial_x, tol=1e-10, atol=1e-16)
    """

    def __init__(self, Ax_gen, device='cuda'):
        """
        Ax_gen: A function that takes a 1-D tensor x and outputs Ax.
        Note: This structure is followed because it may not be computationally
        efficient to compute A explicitly.
        """
        self.Ax_gen = Ax_gen
        self.device = device

    def init_params(self, b, x=None, nsteps=None, tol=1e-10, atol=1e-16):
        """
        b: The R.H.S. of the system, a 1-D tensor.
        nsteps: Maximum number of iterations.
        tol: Relative tolerance: converged if ||r||^2 < tol * ||b||^2.
        atol: Absolute tolerance: converged if ||r||^2 < atol.
        """
        self.b = b.clone().detach()
        self.x = torch.zeros(b.shape[0], device=self.device) if x is None else x
        self.residual_tol = tol * torch.vdot(self.b, self.b).real
        self.atol = torch.tensor(atol, device=self.device)
        self.nsteps = b.shape[0] if nsteps is None else nsteps
        self.status, self.r = self.check_convergence(self.x)
        self.rho = torch.tensor(1, device=self.device)
        self.alpha = torch.tensor(1, device=self.device)
        self.omega = torch.tensor(1, device=self.device)
        self.v = torch.zeros(b.shape[0], device=self.device)
        self.p = torch.zeros(b.shape[0], device=self.device)
        self.r_hat = self.r.clone().detach()

    def check_convergence(self, x):
        r = self.b - self.Ax_gen(x)
        rdotr = torch.vdot(r, r).real
        if rdotr < self.residual_tol or rdotr < self.atol:
            return True, r
        else:
            return False, r

    def step(self):
        rho = torch.dot(self.r, self.r_hat)  # rho_i <- <r_{i-1}, r_hat>
        beta = (rho / self.rho) * (self.alpha / self.omega)  # beta <- (rho_i / rho_{i-1}) * (alpha / omega_{i-1})
        self.rho = rho  # rho_{i-1} <- rho_i
        self.p = self.r + beta * (self.p - self.omega * self.v)  # p_i <- r_{i-1} + beta * (p_{i-1} - omega_{i-1} * v_{i-1})
        self.v = self.Ax_gen(self.p)  # v_i <- A p_i
        self.alpha = self.rho / torch.dot(self.r_hat, self.v)  # alpha <- rho_i / <r_hat, v_i>
        s = self.r - self.alpha * self.v  # s <- r_{i-1} - alpha * v_i
        t = self.Ax_gen(s)  # t <- A s
        self.omega = torch.dot(t, s) / torch.dot(t, t)  # omega_i <- <t, s> / <t, t>
        self.x = self.x + self.alpha * self.p + self.omega * s  # x_i <- x_{i-1} + alpha * p_i + omega_i * s
        status, res = self.check_convergence(self.x)
        if status:
            return True
        else:
            self.r = s - self.omega * t  # r_i <- s - omega_i * t
            return False

    def solve(self, *args, **kwargs):
        """
        Method to find the solution.
        Returns the final value of x.
        """
        self.init_params(*args, **kwargs)
        if self.status:
            return self.x
        while self.nsteps:
            s = self.step()
            if s:
                return self.x
            if self.rho == 0:
                break
            self.nsteps -= 1
        warnings.warn('Convergence has failed :(')
        return self.x
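
As a quick sanity check, the solver can be compared against a direct solve on a small, well-conditioned CPU system. A minimal sketch, assuming double precision and an SPD matrix built for easy convergence (the construction below is illustrative, not part of the gist's API):

import torch

torch.manual_seed(0)
M = torch.randn(5, 5, dtype=torch.float64)
A = M @ M.T + 5 * torch.eye(5, dtype=torch.float64)  # SPD and well conditioned
b = A @ torch.randn(5, dtype=torch.float64)

solver = BiCGSTAB(lambda v: A @ v, device='cpu')
x = solver.solve(b, nsteps=20)
print(torch.allclose(x, torch.linalg.solve(A, b), atol=1e-6))  # expected: True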
@Nailemre

How can I run this code? Can you give an example?

@bridgesign
Author

bridgesign commented Nov 30, 2021

@Nailemre,

import torch
from BiCGSTAB import BiCGSTAB

A = torch.randn(3,3, device='cuda')
x = torch.randn(3, device='cuda')
# Starting point
x_int = x + 0.01*torch.randn(3, device='cuda')
b = torch.matmul(A, x)
Ax_gen = lambda x: torch.matmul(A, x)
solver = BiCGSTAB(Ax_gen)
print("Original Solution:",x)
print("BiCGSTAB Solution:", solver.solve(b, nsteps=11, x=x_int, tol=1e-3))

Depending on requirements, it might not be possible to compute A explicitly, but still possible to compute the matrix-vector product with A. Hence, Ax_gen is a function that returns the matrix-vector product Ax for a given x.
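
For example, a matrix-free operator only needs to implement the product itself. A minimal sketch, assuming a diagonally dominant tridiagonal operator applied with shifted slices (so A is never formed; the 2.1/-1 stencil is illustrative):

import torch

def Ax_gen(x):
    # Tridiagonal operator applied matrix-free:
    # 2.1 on the diagonal, -1 on the off-diagonals (diagonally dominant,
    # so the iteration converges without preconditioning)
    Ax = 2.1 * x
    Ax[1:] -= x[:-1]   # subdiagonal contribution
    Ax[:-1] -= x[1:]   # superdiagonal contribution
    return Ax

b = torch.randn(1000, dtype=torch.float64)
x = BiCGSTAB(Ax_gen, device='cpu').solve(b)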

@Nailemre

@bridgesign
Thank you for your answer. When I run the example you gave, I get the following error. Any idea what I am doing wrong?

line 94, in
Ax_gen = lambda x: torch.matmul(A, x)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument vec in method wrapper__mv)
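
That error means A and b live on the CPU while the solver allocates x on the GPU by default. A minimal sketch of the usual fix, using the constructor's device argument:

# Either move the problem to the GPU to match the solver's default...
A, b = A.to('cuda'), b.to('cuda')
# ...or construct the solver on the CPU to match CPU tensors
solver = BiCGSTAB(Ax_gen, device='cpu')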

@bridgesign
Author

@Nailemre Updated the comment. BiCGSTAB faces convergence issues and hence requires a good initialization to work.

@tvercaut

Thanks for sharing this. As per the algorithm description (step 1 here), it looks like

r = self.Ax_gen(x) - self.b

(see https://gist.github.com/bridgesign/f421f69ad4a3858430e5e235bccde8c6#file-bicgstab-py-L50)
should be replaced by

r = self.b - self.Ax_gen(x)

@bridgesign
Author

@tvercaut Thanks! Updated the gist.

@bridgesign
Author

bridgesign commented Mar 7, 2023 via email

@nazaninj

nazaninj commented Mar 7, 2023

Thanks for your reply. I am trying to solve an Ax = b system where A has shape (4, 8) and b has shape (4, 1). I started with x of shape (8, 1), but it crashes in the solver at
self.p = self.r + beta * (self.p - self.omega * self.v)
because the size of x is 8 while the size of p is 4.

The error is:
Ax_gen = lambda x: torch.matmul(A, x)
RuntimeError: size mismatch, got 4, 4x8,4
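
BiCGSTAB targets square systems, so a (4, 8) A cannot be passed directly (and b should be a 1-D tensor, per the docstring). A hedged sketch of one common workaround, assuming the minimum-norm solution of the underdetermined system is acceptable: solve the square normal equations (A A^T) y = b with the solver, then recover x = A^T y:

import torch

A = torch.randn(4, 8, dtype=torch.float64)
b = torch.randn(4, dtype=torch.float64)  # 1-D, as the docstring expects

# (A A^T) is a square 4x4 SPD system; x = A^T y is the minimum-norm solution
Ax_gen = lambda y: A @ (A.T @ y)
y = BiCGSTAB(Ax_gen, device='cpu').solve(b, nsteps=20)
x = A.T @ y
print(torch.allclose(A @ x, b, atol=1e-6))  # expected: True when converged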

@nazaninj

nazaninj commented Mar 7, 2023 via email

@bridgesign
Author

bridgesign commented Mar 7, 2023 via email

@tvercaut

tvercaut commented Mar 8, 2023

If helpful, on my side, I eventually decided to go with a different implementation, as shown here:
https://github.com/cai4cai/torchsparsegradutils/blob/main/torchsparsegradutils/utils/bicgstab.py

@nazaninj

nazaninj commented Mar 13, 2023 via email

@bridgesign
Author

bridgesign commented Mar 16, 2023 via email
