Last active
February 29, 2024 14:37
-
-
Save bridgesign/f421f69ad4a3858430e5e235bccde8c6 to your computer and use it in GitHub Desktop.
BiCGSTAB (BCGSTAB) PyTorch implementation with GPU support
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import warnings | |
class BiCGSTAB:
    """
    PyTorch implementation of BiCGSTAB (also written BCGSTAB), the stabilized
    biconjugate gradient method published by H. A. van der Vorst, for solving
    the linear system ``Ax = b``.

    ``A`` is never formed explicitly: the solver only needs a function that
    applies ``A`` to a 1-D tensor. Real and complex systems are supported; all
    inner products conjugate their first argument (``torch.vdot``), which is
    the ordinary dot product for real tensors.

    Example:
        solver = BiCGSTAB(Ax_gen)
        solver.solve(b, x=initial_x, tol=1e-10, atol=1e-16)
    """

    def __init__(self, Ax_gen, device=None):
        """
        Ax_gen: A function that takes a 1-D tensor x and returns Ax.
            Note: this structure is followed as it may not be computationally
            efficient to compute A explicitly.
        device: Device on which the solver's tensors are allocated. ``None``
            (the default) uses the device of the right-hand side ``b`` passed
            to ``solve``/``init_params``.
        """
        self.Ax_gen = Ax_gen
        self.device = device

    def init_params(self, b, x=None, nsteps=None, tol=1e-10, atol=1e-16):
        """
        Reset the solver state for a new right-hand side.

        b: The R.H.S of the system. 1-D tensor.
        x: Initial guess (1-D tensor). Defaults to zeros with b's dtype, so
            complex systems work without an explicit initial guess.
        nsteps: Maximum number of iterations (default: len(b)).
        tol: Tolerance such that if ||r||^2 < tol * ||b||^2 then converged.
        atol: Tolerance such that if ||r||^2 < atol then converged.
        """
        device = b.device if self.device is None else torch.device(self.device)
        self.b = b.clone().detach().to(device)
        if x is None:
            # Match b's dtype: a float32 guess cannot be fed to a complex or
            # float64 operator.
            self.x = torch.zeros_like(self.b)
        else:
            self.x = x.to(device)
        self.residual_tol = tol * torch.vdot(self.b, self.b).real
        self.atol = torch.tensor(atol, dtype=self.residual_tol.dtype, device=device)
        self.nsteps = self.b.shape[0] if nsteps is None else nsteps
        # Set by step() when a denominator vanishes and iterating must stop.
        self.breakdown = False
        self.status, self.r = self.check_convergence(self.x)
        # rho_{-1} = alpha = omega_{-1} = 1 together with p = v = 0 make the
        # first step produce p_1 = r_0.
        self.rho = torch.ones((), dtype=self.r.dtype, device=device)
        self.alpha = torch.ones((), dtype=self.r.dtype, device=device)
        self.omega = torch.ones((), dtype=self.r.dtype, device=device)
        self.v = torch.zeros_like(self.r)
        self.p = torch.zeros_like(self.r)
        # Shadow residual r^ = r_0 (kept fixed for the whole solve).
        self.r_hat = self.r.clone().detach()

    def _is_small(self, r):
        """Return True if ||r||^2 is below the relative or absolute tolerance."""
        rdotr = torch.vdot(r, r).real
        return bool(rdotr < self.residual_tol or rdotr < self.atol)

    def check_convergence(self, x):
        """
        Compute the true residual r = b - Ax.

        Returns (converged, r) where converged is a Python bool.
        """
        r = self.b - self.Ax_gen(x)
        return self._is_small(r), r

    def step(self):
        """
        Perform one BiCGSTAB iteration, updating x, r, p, v, rho, alpha, omega.

        Returns True when the iterate has converged. When a denominator
        vanishes the iteration cannot continue: ``self.breakdown`` is set to
        True and False is returned.
        """
        rho = torch.vdot(self.r_hat, self.r)  # rho_i <- <r^, r_{i-1}>
        if rho == 0:
            self.breakdown = True
            return False
        # beta <- (rho_i / rho_{i-1}) * (alpha / omega_{i-1})
        beta = (rho / self.rho) * (self.alpha / self.omega)
        self.rho = rho
        # p_i <- r_{i-1} + beta * (p_{i-1} - omega_{i-1} v_{i-1})
        self.p = self.r + beta * (self.p - self.omega * self.v)
        self.v = self.Ax_gen(self.p)  # v_i <- A p_i
        r_hat_v = torch.vdot(self.r_hat, self.v)
        if r_hat_v == 0:
            self.breakdown = True
            return False
        self.alpha = rho / r_hat_v  # alpha <- rho_i / <r^, v_i>
        h = self.x + self.alpha * self.p  # half-step iterate
        s = self.r - self.alpha * self.v  # s <- r_{i-1} - alpha v_i (residual of h)
        # If h already solves the system stop here; otherwise s ~ 0 turns
        # omega = <t, s> / <t, t> into 0/0 and poisons x with NaN.
        if self._is_small(s):
            status, r = self.check_convergence(h)
            if status:
                self.x, self.r = h, r
                return True
        t = self.Ax_gen(s)  # t <- A s
        t_dot_t = torch.vdot(t, t).real
        if t_dot_t == 0:
            # A s = 0 while h has not converged: omega is undefined.
            self.x, self.r = h, s
            self.breakdown = True
            return False
        self.omega = torch.vdot(t, s) / t_dot_t  # omega_i <- <t, s> / <t, t>
        self.x = h + self.omega * s  # x_i <- x_{i-1} + alpha p_i + omega_i s
        status, r = self.check_convergence(self.x)
        if status:
            self.r = r
            return True
        self.r = s - self.omega * t  # r_i <- s - omega_i t
        if self.omega == 0:
            # The next beta would divide by omega_i.
            self.breakdown = True
        return False

    def solve(self, *args, **kwargs):
        """
        Method to find the solution.

        Accepts the same arguments as ``init_params`` (b, x=None, nsteps=None,
        tol=1e-10, atol=1e-16) and returns the final iterate x. A warning is
        emitted if the iteration did not converge within nsteps iterations or
        broke down.
        """
        self.init_params(*args, **kwargs)
        if self.status:
            return self.x
        # A bounded range also terminates for negative/non-integer nsteps,
        # on which a `while nsteps:` countdown would loop forever.
        for _ in range(int(self.nsteps)):
            if self.step():
                self.status = True
                return self.x
            if self.breakdown:
                break
        warnings.warn('Convergence has failed :(')
        return self.x
Author
bridgesign
commented
Mar 16, 2023
via email
•
Some changes were required. I have updated the code so that it works with complex systems. Below is a simple example:
```python
A = torch.complex(torch.rand(5,5), torch.rand(5,5)) # Random A for testing
b = A.matmul(torch.complex(torch.rand(5),torch.rand(5))) # Creating a random b for testing
x = torch.complex(torch.rand(5), torch.rand(5)) # Random init point. Required for complex case!
Ax_gen = lambda x: torch.matmul(A, x)
solver = BiCGSTAB(Ax_gen)
sol = solver.solve(b,x, nsteps=100)
```
Hope this solves your problem.
…On Mon, Mar 13, 2023 at 9:16 AM Nazanin ***@***.***> wrote:
***@***.**** commented on this gist.
------------------------------
Thanks, I have tested it. How does it work if the elements of matrix A and
the rhs (b) are complex numbers?
On Wed, 8 Mar 2023 at 14:19, Tom Vercauteren ***@***.***>
wrote:
> ***@***.**** commented on this gist.
> ------------------------------
>
> If helpful, on my side, I eventually decided to go for a different
> implementation as shown here:
>
>
https://github.com/cai4cai/torchsparsegradutils/blob/main/torchsparsegradutils/utils/bicgstab.py
>
> —
> Reply to this email directly, view it on GitHub
> <
https://gist.github.com/f421f69ad4a3858430e5e235bccde8c6#gistcomment-4495448
>
> or unsubscribe
> <
https://github.com/notifications/unsubscribe-auth/AFZEWEEDSM7SCYTINQA3OVDW3CBO3BFKMF2HI4TJMJ2XIZLTSKBKK5TBNR2WLJDHNFZXJJDOMFWWLK3UNBZGKYLEL52HS4DFQKSXMYLMOVS2I5DSOVS2I3TBNVS3W5DIOJSWCZC7OBQXE5DJMNUXAYLOORPWCY3UNF3GS5DZVRZXKYTKMVRXIX3UPFYGLK2HNFZXIQ3PNVWWK3TUUZ2G64DJMNZZDAVEOR4XAZNEM5UXG5FFOZQWY5LFVEYTAOJXHAYTQNJXU52HE2LHM5SXFJTDOJSWC5DF
>
> .
> You are receiving this email because you commented on the thread.
>
> Triage notifications on the go with GitHub Mobile for iOS
> <
https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675
>
> or Android
> <
https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub
>
> .
>
>
—
Reply to this email directly, view it on GitHub
<https://gist.github.com/f421f69ad4a3858430e5e235bccde8c6#gistcomment-4500982>
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AHQ3RRTQIW3U2N36BGZXKKDW35B5JBFKMF2HI4TJMJ2XIZLTSKBKK5TBNR2WLJDHNFZXJJDOMFWWLK3UNBZGKYLEL52HS4DFQKSXMYLMOVS2I5DSOVS2I3TBNVS3W5DIOJSWCZC7OBQXE5DJMNUXAYLOORPWCY3UNF3GS5DZVRZXKYTKMVRXIX3UPFYGLK2HNFZXIQ3PNVWWK3TUUZ2G64DJMNZZDAVEOR4XAZNEM5UXG5FFOZQWY5LFVEYTAOJXHAYTQNJXU52HE2LHM5SXFJTDOJSWC5DF>
.
You are receiving this email because you authored the thread.
Triage notifications on the go with GitHub Mobile for iOS
<https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675>
or Android
<https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub>
.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment