import torch
import warnings


class BiCGSTAB():
    """
    A PyTorch implementation of BiCGSTAB (also written BCGSTAB), a stabilized
    variant of the BiCG method, first published by Van der Vorst.
    For solving an ``Ax = b`` system.
    Example:
        solver = BiCGSTAB(Ax_gen)
        solver.solve(b, x=initial_x, tol=1e-10, atol=1e-16)
    """
    def __init__(self, Ax_gen, device='cuda'):
        """
        Ax_gen: A function that takes a 1-D tensor x and outputs Ax.
        Note: This structure is followed because it may not be computationally
        efficient to compute A explicitly.
        """
        self.Ax_gen = Ax_gen
        self.device = device

    def init_params(self, b, x=None, nsteps=None, tol=1e-10, atol=1e-16):
        """
        b: The R.H.S. of the system; a 1-D tensor.
        nsteps: Maximum number of iterations.
        tol: Relative tolerance: converged if ||r||^2 < tol * ||b||^2.
        atol: Absolute tolerance: converged if ||r||^2 < atol.
        """
        self.b = b.clone().detach()
        self.x = torch.zeros(b.shape[0], device=self.device) if x is None else x
        self.residual_tol = tol * torch.vdot(self.b, self.b).real
        self.atol = torch.tensor(atol, device=self.device)
        self.nsteps = b.shape[0] if nsteps is None else nsteps
        self.status, self.r = self.check_convergence(self.x)
        # Use floating-point scalars so the divisions in step() stay floating point
        self.rho = torch.tensor(1.0, device=self.device)
        self.alpha = torch.tensor(1.0, device=self.device)
        self.omega = torch.tensor(1.0, device=self.device)
        self.v = torch.zeros(b.shape[0], device=self.device)
        self.p = torch.zeros(b.shape[0], device=self.device)
        self.r_hat = self.r.clone().detach()

    def check_convergence(self, x):
        r = self.b - self.Ax_gen(x)
        rdotr = torch.vdot(r, r).real
        if rdotr < self.residual_tol or rdotr < self.atol:
            return True, r
        else:
            return False, r

    def step(self):
        rho = torch.dot(self.r, self.r_hat)  # rho_i <- <r_hat, r_{i-1}>
        beta = (rho / self.rho) * (self.alpha / self.omega)  # beta <- (rho_i/rho_{i-1}) x (alpha/omega_{i-1})
        self.rho = rho  # rho_{i-1} <- rho_i
        self.p = self.r + beta * (self.p - self.omega * self.v)  # p_i <- r_{i-1} + beta (p_{i-1} - omega_{i-1} v_{i-1})
        self.v = self.Ax_gen(self.p)  # v_i <- A p_i
        self.alpha = self.rho / torch.dot(self.r_hat, self.v)  # alpha <- rho_i / <r_hat, v_i>
        s = self.r - self.alpha * self.v  # s <- r_{i-1} - alpha v_i
        t = self.Ax_gen(s)  # t <- A s
        self.omega = torch.dot(t, s) / torch.dot(t, t)  # omega_i <- <t, s> / <t, t>
        self.x = self.x + self.alpha * self.p + self.omega * s  # x_i <- x_{i-1} + alpha p_i + omega_i s
        status, res = self.check_convergence(self.x)
        if status:
            return True
        else:
            self.r = s - self.omega * t  # r_i <- s - omega_i t
            return False

    def solve(self, *args, **kwargs):
        """
        Method to find the solution.
        Returns the final value of x.
        """
        self.init_params(*args, **kwargs)
        if self.status:
            return self.x
        while self.nsteps:
            s = self.step()
            if s:
                return self.x
            if self.rho == 0:
                break
            self.nsteps -= 1
        warnings.warn('Convergence has failed :(')
        return self.x
import torch
from BiCGSTAB import BiCGSTAB

A = torch.randn(3, 3, device='cuda')
x = torch.randn(3, device='cuda')
# Starting point close to the true solution: BiCGSTAB benefits from a good initial guess
x_init = x + 0.01 * torch.randn(3, device='cuda')
b = torch.matmul(A, x)
Ax_gen = lambda x: torch.matmul(A, x)
solver = BiCGSTAB(Ax_gen)
print("Original Solution:", x)
print("BiCGSTAB Solution:", solver.solve(b, nsteps=11, x=x_init, tol=1e-3))
Depending on the application, it might not be possible to compute A explicitly, yet still possible to compute the matrix-vector product with A. Hence, Ax_gen is a function that returns the matrix-vector product Ax for a given x.
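For instance, here is a minimal matrix-free sketch (the tridiagonal operator, size n, and device below are illustrative, not part of the original gist): the classic 1-D operator with 2 on the diagonal and -1 on the off-diagonals can be applied without ever materializing the n x n matrix A.

import torch
from BiCGSTAB import BiCGSTAB

n, device = 1000, 'cpu'

def Ax_gen(x):
    # Apply the tridiagonal matrix (2 on the diagonal, -1 off-diagonal)
    # directly to x; the n x n matrix A is never built.
    Ax = 2.0 * x
    Ax[:-1] -= x[1:]
    Ax[1:] -= x[:-1]
    return Ax

b = torch.ones(n, device=device)
solver = BiCGSTAB(Ax_gen, device=device)
x = solver.solve(b, tol=1e-8)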
@bridgesign
Thank you for your answer. When I run the example you gave, I get an error as follows. Any ideas what I am doing wrong?
line 94, in
Ax_gen = lambda x: torch.matmul(A, x)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument vec in method wrapper__mv)
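For reference, this error usually means that one tensor (often b or the initial x) lives on the CPU while A lives on the GPU. A minimal sketch of a fix (assuming you want to fall back to the CPU when CUDA is unavailable) is to create everything on a single device and pass that same device to the solver:

import torch
from BiCGSTAB import BiCGSTAB

device = 'cuda' if torch.cuda.is_available() else 'cpu'

A = torch.randn(3, 3, device=device)
x = torch.randn(3, device=device)
b = torch.matmul(A, x)  # b inherits A's device

Ax_gen = lambda v: torch.matmul(A, v)
# The solver also allocates internal tensors, so give it the same device:
solver = BiCGSTAB(Ax_gen, device=device)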
@Nailemre Updated the comment. BiCGSTAB can face convergence issues and hence requires a good initialization to work.
Thanks for sharing this. As per the description (step 1 of the algorithm), it looks like
r = self.Ax_gen(x) - self.b
(see https://gist.github.com/bridgesign/f421f69ad4a3858430e5e235bccde8c6#file-bicgstab-py-L50)
should be replaced by
r = self.b - self.Ax_gen(x)
@tvercaut Thanks! Updated the gist.
Thanks for your reply. I am trying to solve an Ax=b system where A has shape (4, 8) and b has shape (4, 1). I started with x of shape (8, 1), but it crashes in the solver,
here: self.p = self.r + beta * (self.p - self.omega * self.v),
because the size of x is 8 but the size of p is 4.
The error is:
Ax_gen = lambda x: torch.matmul(A, x)
RuntimeError: size mismatch, got 4, 4x8,4
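For reference: BiCGSTAB only applies to square systems, so a rectangular A (here 4x8) breaks the recurrences, whose vectors all have the dimension of b. One common workaround (a sketch, not part of the original gist) is to solve the square normal equations A^T A x = A^T b instead; note that the implementation above expects 1-D tensors, so b should be flattened, and that A^T A is rank-deficient here (A has more columns than rows), so a small regularization term helps.

import torch
from BiCGSTAB import BiCGSTAB

A = torch.randn(4, 8, device='cuda')
b = torch.randn(4, device='cuda')  # 1-D, not shape (4, 1)

# A^T A is 8x8 (square), so BiCGSTAB can be run on A^T A x = A^T b,
# whose solution is a least-squares solution of the original system.
# A small Tikhonov term keeps the operator well-posed since A^T A is singular here.
AtA_gen = lambda x: A.T @ (A @ x) + 1e-6 * x
Atb = A.T @ b

solver = BiCGSTAB(AtA_gen)
x_ls = solver.solve(Atb, tol=1e-8)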
If helpful: on my side, I eventually decided to go with a different implementation, as shown here:
https://github.com/cai4cai/torchsparsegradutils/blob/main/torchsparsegradutils/utils/bicgstab.py
How can I run this code? Can you give an example?