Last active
February 29, 2024 14:37
BiCGSTAB or BCGSTAB Pytorch implementation GPU support
import torch
import warnings


class BiCGSTAB():
    """
    This is a PyTorch implementation of BiCGSTAB (also written BCGSTAB), a
    stabilized variant of the biconjugate gradient (BiCG) method, first
    published by H. A. van der Vorst.
    For solving an ``Ax = b`` system.

    Example:

    solver = BiCGSTAB(Ax_gen)
    solver.solve(b, x=initial_x, tol=1e-10, atol=1e-16)
    """
    def __init__(self, Ax_gen, device='cuda'):
        """
        Ax_gen: A function that takes a 1-D tensor x and outputs Ax

        Note: This structure is followed as it may not be computationally
        efficient to compute A explicitly.
        """
        self.Ax_gen = Ax_gen
        self.device = device

    def init_params(self, b, x=None, nsteps=None, tol=1e-10, atol=1e-16):
        """
        b: The R.H.S. of the system. 1-D tensor
        x: Initial guess for the solution. 1-D tensor (zeros if None)
        nsteps: Maximum number of iterations (defaults to the size of b)
        tol: Tolerance such that if ||r||^2 < tol * ||b||^2 then converged
        atol: Tolerance such that if ||r||^2 < atol then converged
        """
        self.b = b.clone().detach()
        self.x = torch.zeros(b.shape[0], device=self.device) if x is None else x
        self.residual_tol = tol * torch.vdot(self.b, self.b).real
        self.atol = torch.tensor(atol, device=self.device)
        self.nsteps = b.shape[0] if nsteps is None else nsteps
        self.status, self.r = self.check_convergence(self.x)
        self.rho = torch.tensor(1, device=self.device)
        self.alpha = torch.tensor(1, device=self.device)
        self.omega = torch.tensor(1, device=self.device)
        self.v = torch.zeros(b.shape[0], device=self.device)
        self.p = torch.zeros(b.shape[0], device=self.device)
        self.r_hat = self.r.clone().detach()

    def check_convergence(self, x):
        # Returns (converged, r) where r = b - Ax is the true residual
        r = self.b - self.Ax_gen(x)
        rdotr = torch.vdot(r, r).real
        if rdotr < self.residual_tol or rdotr < self.atol:
            return True, r
        else:
            return False, r

    def step(self):
        # torch.vdot conjugates its first argument, so the inner products
        # below are valid for complex tensors and unchanged for real ones.
        rho = torch.vdot(self.r_hat, self.r)                   # rho_i <- <r^, r_{i-1}>
        beta = (rho/self.rho)*(self.alpha/self.omega)          # beta <- (rho_i/rho_{i-1}) x (alpha/omega_{i-1})
        self.rho = rho                                         # rho_{i-1} <- rho_i
        self.p = self.r + beta*(self.p - self.omega*self.v)    # p_i <- r_{i-1} + beta x (p_{i-1} - w_{i-1} v_{i-1})
        self.v = self.Ax_gen(self.p)                           # v_i <- A p_i
        self.alpha = self.rho/torch.vdot(self.r_hat, self.v)   # alpha <- rho_i/<r^, v_i>
        s = self.r - self.alpha*self.v                         # s <- r_{i-1} - alpha v_i
        t = self.Ax_gen(s)                                     # t <- A s
        self.omega = torch.vdot(t, s)/torch.vdot(t, t)         # w_i <- <t, s>/<t, t>
        self.x = self.x + self.alpha*self.p + self.omega*s     # x_i <- x_{i-1} + alpha p_i + w_i s
        status, res = self.check_convergence(self.x)
        if status:
            return True
        else:
            self.r = s - self.omega*t                          # r_i <- s - w_i t
            return False

    def solve(self, *args, **kwargs):
        """
        Method to find the solution.
        Takes the same arguments as init_params.

        Returns the final answer of x
        """
        self.init_params(*args, **kwargs)
        if self.status:
            return self.x
        while self.nsteps:
            s = self.step()
            if s:
                return self.x
            if self.rho == 0:
                # Breakdown: <r^, r> vanished, the recurrence cannot continue
                break
            self.nsteps -= 1
        warnings.warn('Convergence has failed :(')
        return self.x
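The stopping rule described in the `init_params` docstring can be tried in isolation. The snippet below is a minimal sketch and not part of the gist: the helper name `has_converged` is hypothetical, and it assumes `r` and `b` are 1-D tensors of the same dtype.

```python
import torch


def has_converged(r, b, tol=1e-10, atol=1e-16):
    """Hypothetical helper: True if ||r||^2 < tol * ||b||^2 or ||r||^2 < atol."""
    # torch.vdot conjugates its first argument, so vdot(r, r).real is the
    # squared 2-norm of r for both real and complex tensors.
    rdotr = torch.vdot(r, r).real
    bdotb = torch.vdot(b, b).real
    return bool(rdotr < tol * bdotb or rdotr < atol)


b = torch.tensor([3.0, 4.0])          # ||b||^2 = 25
small_r = torch.tensor([1e-6, 0.0])   # ||r||^2 is about 1e-12, below 1e-10 * 25
large_r = torch.tensor([1e-3, 0.0])   # ||r||^2 is about 1e-6, above both thresholds

print(has_converged(small_r, b), has_converged(large_r, b))
```

Note that both tolerances are compared against the squared residual norm, so `tol=1e-10` corresponds to a relative residual of roughly `1e-5` in the usual (unsquared) sense.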
@nazaninj, if the size of x is not the same as b, then create an
initial tensor x of the correct size and pass it to the solver.
```
x = torch.tensor([...])  # replace [...] with some initial x estimate of the correct size
solver.init_params(b, x=x)
```
On Tue, Mar 7, 2023 at 8:07 AM Nazanin wrote:
Hi, I tried to run this code for an asymmetric matrix but it seems b and x
must have the same shape. Is it possible to fix it?
Thanks for your reply. I am trying to solve an Ax = b system. A has shape (4, 8) and b has shape (4, 1). I started with an x of shape (8, 1), but it crashes in the solver,
here: self.p = self.r + beta * (self.p - self.omega * self.v)
because the size of x is 8 but the size of p is 4.
The error is:
Ax_gen = lambda x: torch.matmul(A, x)
RuntimeError: size mismatch, got 4, 4x8,4
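The mismatch can be reproduced without the solver. This is a small sketch with random placeholder data (not the poster's actual matrices): for a (4, 8) matrix, the residual b - Ax has 4 entries while x has 8, so a search direction built from residuals can neither be multiplied by A nor added to x.

```python
import torch

A = torch.rand(4, 8)   # non-square system matrix (placeholder data)
b = torch.rand(4)      # right-hand side, one entry per row of A
x = torch.zeros(8)     # unknown, one entry per column of A

r = b - A @ x          # the residual has the size of b, not of x
print(r.shape, x.shape)

# The search direction p is a combination of residuals, so it also has
# 4 entries. Applying A to it fails because A expects 8 entries.
try:
    A @ r
except RuntimeError as err:
    print("RuntimeError:", err)
```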
Sorry, I forgot that it is for square matrices only. One possible workaround is to
make A an (8,8) matrix and b an (8,1) vector, with the extra entries set to 0. I
don't know if there is an algorithm for non-square matrices. If you know a
reference for one, please let me know. I will add the feature if it's there.
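The zero-padding idea can be sketched as follows. This is only an illustration with placeholder data: the padded matrix has four zero rows and is therefore singular, so it is an assumption, not a guarantee, that an iterative solver converges on it. What the sketch does show is that a solution of the original system also satisfies the padded one, since the extra rows only state 0 = 0.

```python
import torch

torch.manual_seed(0)
A = torch.rand(4, 8, dtype=torch.float64)   # placeholder non-square matrix
b = torch.rand(4, dtype=torch.float64)      # placeholder right-hand side

m, n = A.shape
A_pad = torch.zeros(n, n, dtype=A.dtype)
A_pad[:m] = A                               # original rows on top, zero rows below
b_pad = torch.zeros(n, dtype=b.dtype)
b_pad[:m] = b                               # extra entries of b stay 0

# One particular solution of the original system, obtained here with the
# pseudoinverse purely to check the padding (not with BiCGSTAB).
x = torch.linalg.pinv(A) @ b

print(torch.allclose(A @ x, b), torch.allclose(A_pad @ x, b_pad))
```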
In case it helps: on my side, I eventually decided to go for a different implementation, shown here:
https://github.com/cai4cai/torchsparsegradutils/blob/main/torchsparsegradutils/utils/bicgstab.py
Thanks, I have tested it. How does it work if the elements of matrix A and
the rhs (b) are complex numbers?
There were some changes that were required. I have updated the code so that it can work with complex systems. Below is a simple example code:
```python
import torch

# Assumes the BiCGSTAB class from this gist is already defined.
A = torch.complex(torch.rand(5, 5), torch.rand(5, 5))       # Random A for testing
b = A.matmul(torch.complex(torch.rand(5), torch.rand(5)))   # Creating a random b for testing
x = torch.complex(torch.rand(5), torch.rand(5))             # Random init point. Required for complex case!
Ax_gen = lambda x: torch.matmul(A, x)
# Pass the device of the data explicitly; the class defaults to 'cuda'.
solver = BiCGSTAB(Ax_gen, device=A.device)
sol = solver.solve(b, x, nsteps=100)
```
Hope this solves your problem.
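One detail that matters for complex systems is which inner product is used. As a small illustration, independent of the gist code: `torch.dot` conjugates neither argument, while `torch.vdot` conjugates the first one, so only the latter yields a squared norm for a complex vector.

```python
import torch

t = torch.tensor([1 + 1j, 1 - 1j])   # a non-zero complex vector

plain = torch.dot(t, t)    # sum of t_k * t_k, no conjugation
herm = torch.vdot(t, t)    # sum of conj(t_k) * t_k, i.e. ||t||^2

print(plain)   # zero, although t is not the zero vector
print(herm)    # real part 4, imaginary part 0
```

A denominator such as <t, t> computed without conjugation can therefore vanish for a non-zero t, which is worth keeping in mind when adapting a real-valued solver to complex data.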
@tvercaut Thanks! Updated the gist.