Last active
May 23, 2021 12:16
-
-
Save oxinabox/ad054535b8f84cf060d3ac35af77c64f to your computer and use it in GitHub Desktop.
Why are so many implementations of the Thomas algorithm wrong?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra | |
# Wikipedia non-preserving version (transcribed from VB) | |
# https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm | |
# this one is wrong (found this in use in the wild π’) | |
function thomas_algorithm!(a, b, c, r, ::Val{1}) | |
n = length(b) | |
for i in 2:(n-1) | |
m = a[i]/b[i-1]; | |
b[i] = b[i] - m * c[i - 1]; | |
r[i] = r[i] - m*r[i-1]; | |
end | |
x = similar(b) | |
x[end] = r[end]/b[end]; | |
for i in (n-1):-1:1 | |
x[i] = (r[i] - c[i] * x[i+1]) / b[i] | |
end | |
return x | |
end | |
# wikipdia out of place version | |
# this one is right | |
function thomas_algorithm!(a, b, c, d, ::Val{2}) | |
n = length(b) | |
dp = similar(d) | |
cp = similar(c) | |
dp[1] = d[1]/b[1] | |
cp[1] = c[1]/b[1] | |
for i in 2:n | |
r = 1/(b[i] - a[i]*cp[i-1]) | |
dp[i] = r*(d[i] - a[i]*dp[i-1]) | |
cp[i] = r * c[i] | |
end | |
x = similar(d) | |
x[end] = dp[end] | |
for i in (n-1):-1:1 | |
x[i] = dp[i] - cp[i]*x[i+1] | |
end | |
return x | |
end | |
# Algorithm 1 from https://people.maths.ox.ac.uk/gilesm/files/toms_16b.pdf | |
# this one is also wrong | |
function thomas_algorithm!(a, b, c, d, ::Val{3}) | |
n = length(b) | |
dp = similar(d) | |
cp = similar(c) | |
dp[1] = d[1]/b[1] | |
cp[1] = c[1]/b[1] | |
for i in 2:n | |
r = 1/(b[i] - a[i]*c[i-1]) | |
dp[i] = r*(d[i] - a[i]*d[i-1]) | |
cp[i] = r * c[i] | |
end | |
for i in (n-1):-1:1 | |
d[i] = dp[i] - cp[i]*d[i+1] | |
end | |
return d | |
end | |
# from torchcde | |
# https://github.com/patrick-kidger/torchcde/blob/d3ebdd554f138a07832e31cacca7bc0944d2004e/torchcde/misc.py#L13 | |
# Correct as long as not padded | |
function thomas_algorithm!(A_lower, A_diagonal, A_upper, b, ::Val{4}) | |
channels = length(A_diagonal) | |
new_b = similar(b) | |
new_A_diagonal = similar(A_diagonal) | |
outs = similar(A_diagonal) | |
new_b[1] = b[1] | |
new_A_diagonal[1] = A_diagonal[1] | |
for i in 2:channels | |
w = A_lower[i-1]/new_A_diagonal[i-1]; | |
new_A_diagonal[i] = A_diagonal[i] - w * A_upper[i - 1]; | |
new_b[i] = b[i] - w*new_b[i-1]; | |
end | |
outs[end] = new_b[end]/new_A_diagonal[end]; | |
for i in (channels-1):-1:1 | |
outs[i] = (new_b[i] - A_upper[i] * outs[i+1]) / new_A_diagonal[i] | |
end | |
return outs | |
end | |
################# | |
# No padding | |
function thomas_algorithm(lhs::Tridiagonal, r, ver::Val{4}) | |
a = diag(lhs, -1) | |
b = diag(lhs) | |
c = diag(lhs, 1) | |
d = copy(r) | |
return thomas_algorithm!(a, b, c, d, ver) | |
end | |
# padding | |
function thomas_algorithm(lhs::Tridiagonal, r, ver::Union{Val{1}, Val{2}, Val{3}}) | |
a = [0; diag(lhs, -1)] | |
b = diag(lhs) | |
c = [diag(lhs, 1); 0] | |
d = copy(r) | |
return thomas_algorithm!(a, b, c, d, ver) | |
end | |
###################### | |
# Experiment 2x2 | |
lhs = Tridiagonal([2.0 1.0; 2.0 7.0]) | |
rhs = [1.800000007527517, -7.400000108059436] | |
@show lhs\rhs | |
@show thomas_algorithm(lhs, rhs, Val(1)) | |
@show thomas_algorithm(lhs, rhs, Val(2)) | |
@show thomas_algorithm(lhs, rhs, Val(3)) | |
@show thomas_algorithm(lhs, rhs, Val(4)) | |
@show lhs*(lhs\rhs) β rhs # true | |
@show lhs*(thomas_algorithm(lhs, rhs, Val(1))) β rhs # false | |
#@show lhs*(thomas_algorithm(lhs, rhs, Val(2))) β rhs # true | |
#@show lhs*(thomas_algorithm(lhs, rhs, Val(3))) β rhs # false | |
@show lhs*(thomas_algorithm(lhs, rhs, Val(4))) β rhs | |
##################### | |
# Experiment 3x3 | |
lhs = Tridiagonal([1., 2], [10., 20, 30], [1., 2]) | |
rhs = [11., 12, 13] | |
lhs\rhs | |
lhs*(lhs\rhs) | |
@show lhs\rhs | |
@show thomas_algorithm(lhs, rhs, Val(1)) | |
@show thomas_algorithm(lhs, rhs, Val(2)) | |
@show thomas_algorithm(lhs, rhs, Val(3)) | |
@show thomas_algorithm(lhs, rhs, Val(4)) | |
@show lhs*(lhs\rhs) β rhs # true | |
@show lhs*(thomas_algorithm(lhs, rhs, Val(1))) β rhs # false | |
@show lhs*(thomas_algorithm(lhs, rhs, Val(2))) β rhs # true | |
@show lhs*(thomas_algorithm(lhs, rhs, Val(3))) β rhs # false | |
@show lhs*(thomas_algorithm(lhs, rhs, Val(4))) β rhs # true |
Not yet. I am waiting to be told that I implemented them wrong.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Did you fix the Wikipedia ones?