Last active
January 31, 2025 18:57
-
-
Save jwscook/a24f72e00a50f131c534ecd1de05cf70 to your computer and use it in GitHub Desktop.
Right-Look LU decomposition
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard-library dependencies.
using Base.Threads
using LinearAlgebra
using SparseArrays
using Random

# Seed the global RNG so that repeated runs draw the same random matrices.
Random.seed!(0)

# Third-party dependencies.
using ChunkSplitters
using BlockArrays
using ThreadPinning

# Pin the Julia threads using ThreadPinning's `:cores` strategy.
pinthreads(:cores)

using Polyester
"""
    lsolve!(A, L)
    lsolve!(A, L, work)

Overwrite `A` with the solution `X` of `L * X = A` and return `A`.

`L` itself is never modified.

# Arguments
- `A`: right-hand side; overwritten with the solution.
- `L`: coefficient matrix applied from the left.
- `work`: optional scratch matrix, at least as large as `L` in both dimensions.
  When given, `L` is copied into its leading corner and factorised there, so no
  new matrix is allocated for the factorisation.
"""
function lsolve!(A, L::AbstractSparseArray)
    # A sparse `L` cannot be factorised in place, so go through `\` instead.
    A .= L \ A
    return A
end

# No scratch space available: factorise a fresh copy of `L`.
lsolve!(A, L) = ldiv!(lu(L), A)

function lsolve!(A, L, work)
    W = view(work, 1:size(L, 1), 1:size(L, 2))
    copyto!(W, L)
    # `RowMaximum()` is the current spelling of partial pivoting (formerly `Val(true)`).
    # Per the original note, `lu!(W, NoPivot(); check=false)` would call the generic lu!.
    F = lu!(W, RowMaximum(); check=false)
    return ldiv!(F, A)
end
"""
    rsolve!(A, U)
    rsolve!(A, U, work)

Overwrite `A` with the solution `X` of `X * U = A` and return `A`.

`U` itself is never modified.

# Arguments
- `A`: right-hand side; overwritten with the solution.
- `U`: coefficient matrix applied from the right.
- `work`: optional scratch matrix, at least as large as `U` in both dimensions.
  When given, `U` is copied into its leading corner and factorised there, so no
  new matrix is allocated for the factorisation.
"""
function rsolve!(A, U::AbstractSparseArray)
    # A sparse `U` cannot be factorised in place, so go through `/` instead.
    A .= A / U
    return A
end

# No scratch space available: factorise a fresh copy of `U`.
rsolve!(A, U) = rdiv!(A, lu(U))

function rsolve!(A, U, work)
    W = view(work, 1:size(U, 1), 1:size(U, 2))
    copyto!(W, U)
    # `RowMaximum()` is the current spelling of partial pivoting (formerly `Val(true)`).
    # Per the original note, `lu!(W, NoPivot(); check=false)` would call the generic lu!.
    F = lu!(W, RowMaximum(); check=false)
    return rdiv!(A, F)
end
"""
    RightLookLU{T, M<:AbstractMatrix{T}}

State of a tiled, right-looking LU factorisation.

The matrix `A` is split into an `ntiles`-by-`ntiles` grid of tiles and is
overwritten tile by tile when `lu!` is called on this object.
"""
struct RightLookLU{T, M<:AbstractMatrix{T}}
    # Matrix that is factorised in place.
    A::M
    # Number of tile levels that `lu!` iterates over.
    ntiles::Int
    # Row range covered by each tile row.
    rowindices::Vector{UnitRange{Int64}}
    # Column range covered by each tile column.
    colindices::Vector{UnitRange{Int64}}
    # isempties[i, j] is true once tile (i, j) has been found to be all zeros.
    isempties::Matrix{Bool}
    # Scratch matrices used by the triangular-solve steps.
    works::Vector{Matrix{T}}
end
"""
    RightLookLU(A::AbstractMatrix, ntiles::Int)

Set up a tiled right-looking LU factorisation of `A` on a grid of `ntiles` tiles
per dimension.

`A` is wrapped in a `BlockArray` whose block sizes follow the row and column
chunks, and one scratch matrix per thread (`nthreads()`), as large as the
biggest tile, is allocated. Nothing is factorised yet; call `lu!` on the result.

If a dimension of `A` yields fewer chunks than requested, the stored `ntiles` is
reduced to the number of tiles that actually exist, so `lu!` never indexes a
tile that is not there.

# Throws
- `ArgumentError` if `ntiles` is not positive.
"""
function RightLookLU(A::AbstractMatrix, ntiles::Int)
    ntiles > 0 || throw(ArgumentError("ntiles must be positive, got $ntiles"))
    rowindices = collect(chunks(1:size(A, 1); n=ntiles))
    colindices = collect(chunks(1:size(A, 2); n=ntiles))
    # Use the number of chunks actually produced, not the number requested.
    nlevels = min(length(rowindices), length(colindices))
    isempties = zeros(Bool, length(rowindices), length(colindices))
    rowlengths = length.(rowindices)
    collengths = length.(colindices)
    blocked = BlockArray(A, rowlengths, collengths)
    T = eltype(blocked)
    # Dense scratch matrices, each big enough to hold the largest tile.
    works = [Matrix{T}(undef, maximum(rowlengths), maximum(collengths))
             for _ in 1:nthreads()]
    return RightLookLU(blocked, nlevels, rowindices, colindices, isempties, works)
end
"""
    tile(A::RightLookLU, i, j)

Return tile `(i, j)` of the matrix held by `A`.

When the matrix is a `BlockArray` the stored block is returned directly;
otherwise a view over the corresponding row and column ranges is returned.
"""
function tile(A::RightLookLU{T, M}, i, j) where {T, M<:BlockArray{T}}
    return blocks(A.A)[i, j]
end

tile(A::RightLookLU{T, M}, i, j) where {T, M} =
    view(A.A, A.rowindices[i], A.colindices[j])
"""
    lu!(RL::RightLookLU, A::AbstractMatrix)

Copy the entries of `A` into the tiles of `RL` and factorise again, reusing the
storage of `RL`. Returns `RL`.

Tiles that `RL` currently flags as empty are not copied, so `A` must have the
same sparsity pattern as the matrix `RL` was last used with (not checked here).

# Throws
- `DimensionMismatch` if `A` does not have the same size as the matrix in `RL`.
"""
function LinearAlgebra.lu!(RL::RightLookLU, A::AbstractMatrix)
    # Without this check a larger `A` would silently have only its leading
    # sub-matrix copied and factorised.
    size(A) == size(RL.A) || throw(DimensionMismatch(
        "A has size $(size(A)) but the factorisation expects $(size(RL.A))"))
    tasks = Task[]
    for (i, is) in enumerate(RL.rowindices), (j, js) in enumerate(RL.colindices)
        RL.isempties[i, j] && continue # must have same sparsity pattern TODO check
        # Each tile is a separate destination, so the copies can run concurrently.
        push!(tasks, @spawn copyto!(tile(RL, i, j), view(A, is, js)))
    end
    foreach(wait, tasks)
    return lu!(RL)
end
"""
    lu!(RL::RightLookLU)

Factorise the matrix held by `RL` in place by running `factorise!` once per tile
level, from level 1 up to `RL.ntiles`. Returns `RL`.
"""
function LinearAlgebra.lu!(RL::RightLookLU)
    foreach(level -> factorise!(RL, level), 1:RL.ntiles)
    return RL
end
"""
    factorise!(A::RightLookLU, level)

Run one level of the tiled factorisation and return `A`.

The diagonal tile `(level, level)` is updated with the contributions of the
earlier levels and factorised without pivoting. Every tile to its right in the
same tile row and below it in the same tile column is then processed by
`lookright!`, in parallel.
"""
function factorise!(A::RightLookLU, level)
    All = subtractleft!(A, level, level)
    # Destructuring the LU object yields its L and U factors.
    Lll, Ull = lu!(All, NoPivot(); check=false)
    lks = Tuple{Int,Int}[]
    for k in (level + 1):A.ntiles
        push!(lks, (level, k))
        push!(lks, (k, level))
    end
    isempty(lks) && return A
    isempty(A.works) && throw(ArgumentError("RightLookLU has no work buffers"))
    # One task per scratch buffer. Each task owns its buffer for its whole
    # lifetime, so correctness does not depend on `threadid()` staying fixed for
    # a task or staying within `1:nthreads()`.
    ntasks = min(length(A.works), length(lks))
    @sync for w in 1:ntasks
        work = A.works[w]
        @spawn for idx in w:ntasks:length(lks)
            i, j = lks[idx]
            lookright!(A, i, j, Lll, Ull, work)
        end
    end
    return A
end

"""
    lookright!(A::RightLookLU, i, j, L, U, work=A.works[threadid()])

Finish the off-diagonal tile `(i, j)` for the current level.

The tile is first updated by `subtractleft!`. If it is then all zeros it is left
alone and `nothing` is returned. Otherwise a tile below the diagonal (`i > j`)
is solved against `U` from the right, and a tile above the diagonal (`i < j`) is
solved against `L` from the left. The emptiness flag of the tile is refreshed
afterwards and returned.

`work` is the scratch matrix used for the solve. The default indexes the buffers
by `threadid()`; concurrent callers should pass a buffer they own instead.
"""
function lookright!(A::RightLookLU, i, j, L, U, work=A.works[threadid()])
    @assert i != j
    Aij = subtractleft!(A, i, j)
    A.isempties[i, j] && return nothing
    if i > j
        rsolve!(Aij, U, work)
    elseif i < j
        lsolve!(Aij, L, work)
    end
    nowempty = iszero(Aij)
    A.isempties[i, j] = nowempty
    return nowempty
end
"""
    _mul!(A, L, U)

Subtract the product `L * U` from `A` in place and return `A`.

Three methods exist:
- generic `A`: forms `L * U` and subtracts it by broadcasting;
- `A::Matrix`: uses the five-argument `mul!`, which accumulates into `A` without
  a temporary and also accepts element types or operand types that
  `BLAS.gemm!` has no method for;
- `A::SparseMatrixCSC`: uses the five-argument `mul!` as well.
"""
function _mul!(A, L, U)
    A .-= L * U
    return A
end

function _mul!(A::Matrix{T}, L, U) where {T}
    # mul!(C, A, B, α, β) computes C = A * B * α + C * β.
    return mul!(A, L, U, -one(T), one(T))
end

function _mul!(A::SparseMatrixCSC, L, U)
    # mul!(C, A, B, α, β) computes C = A * B * α + C * β.
    return mul!(A, L, U, -1, true)
end
"""
    subtractleft!(A::RightLookLU, i, j)

Update tile `(i, j)` in place with the contributions of the earlier levels and
return it.

For every `k` below `min(i, j)` the product of tiles `(i, k)` and `(k, j)` is
subtracted from tile `(i, j)`; products involving a tile flagged as empty are
skipped. The emptiness flag of tile `(i, j)` is refreshed before returning.
"""
function subtractleft!(A::RightLookLU, i, j)
    target = tile(A, i, j)
    empties = A.isempties
    for k in 1:(min(i, j) - 1)
        if !(empties[i, k] || empties[k, j])
            _mul!(target, tile(A, i, k), tile(A, k, j))
        end
    end
    empties[i, j] = iszero(target)
    return target
end
"""
    test_matrix(T=Float64; ntiles=6, tilesize=5, overlap=1)

Build a sparse test matrix made of `ntiles` dense random `tilesize`-by-`tilesize`
blocks placed along the diagonal.

Consecutive blocks start `overlap + 1` rows and columns apart; where two blocks
overlap, the later one overwrites the earlier one. Entries are drawn with
`rand(T, ...)`.
"""
function test_matrix(T=Float64; ntiles=6, tilesize=5, overlap=1)
    shrink = tilesize - overlap - 1
    n = ntiles * tilesize - shrink * (ntiles - 1)
    dense = zeros(T, n, n)
    # (i - 1) * tilesize - shrink * (i - 1) + 1 simplifies to (i - 1) * (overlap + 1) + 1.
    for start in range(1; step=overlap + 1, length=ntiles)
        span = start:(start + tilesize - 1)
        dense[span, span] .= rand(T, tilesize, tilesize)
    end
    return sparse(dense)
end
using Test | |
"""
    mybenchmark(f, args...; n=3)

Return the smallest wall-clock time, in seconds, over `n` calls of `f`.

Every call receives fresh deep copies of `args`, so a mutating `f` starts from
the same input each time and the caller's arguments are left untouched.
"""
function mybenchmark(f::F, args...; n=3) where F
    timings = map(1:n) do _
        fresh = deepcopy.(args)
        @elapsed f(fresh...)
    end
    return minimum(timings)
end
using StatProfilerHTML | |
"""
    foo()

Driver: profile the tiled LU on one large banded matrix, then run the
"rightlooklu!" test set, which checks the factorisation on dense random and
sparse banded matrices and prints timing ratios against the reference `lu!`
calls.
"""
function foo()
    # Profile repeated refactorisation of one large, diagonally shifted matrix.
    banded = test_matrix(ComplexF64; ntiles=32, tilesize=1024, overlap=16)
    banded .+= 10 * I(size(banded, 1))
    dense = deepcopy(Matrix(banded))
    profiled = RightLookLU(dense, 16)
    lu!(profiled)
    lu!(profiled, dense)
    @profilehtml [lu!(profiled, dense) for _ in 1:10]
    # (A `return` was placed here, commented out, to stop after profiling.)

    # Timed operations shared by both parts of the test set.
    build_and_factorise(x...) = lu!(RightLookLU(x...))
    nopivot_lu!(x) = lu!(x, NoPivot())

    @testset "rightlooklu!" begin
        # Dense random matrices.
        for n in 2 .^ (4, 6, 8, 10), lutiles in (2, 4, 8, 16)
            A = rand(ComplexF64, n, n)
            Lref, Uref = lu(A, NoPivot())
            tiled = RightLookLU(deepcopy(A), lutiles)
            lu!(tiled)
            # Unit lower and upper factors as stored in the overwritten matrix.
            L1 = tril(tiled.A, -1) .+ I(n)
            U1 = triu(tiled.A)
            @test L1 * U1 ≈ A
            t1 = mybenchmark(build_and_factorise, A, lutiles; n=10)
            t2 = mybenchmark(nopivot_lu!, A; n=10)
            @show n, lutiles, t1 / t2
        end

        # Sparse banded matrices.
        for ntiles in (8, 16), tilesize in (64, 128, 256, 512), overlap in (4,)
            s = test_matrix(ComplexF64; ntiles=ntiles, tilesize=tilesize, overlap=overlap)
            s .+= 10 * I(size(s, 1))
            sparse_lu = lu(deepcopy(s))
            tb1 = mybenchmark(nopivot_lu!, s; n=20)
            tb2 = mybenchmark(x -> lu!(sparse_lu, x), s; n=20)
            for lutiles in min.(size(s, 1), (8, 10, 12, 14, 16, 20))
                tiled = RightLookLU(deepcopy(s), lutiles)
                lu!(tiled)
                L2 = tril(tiled.A, -1) .+ I(size(tiled.A, 1))
                U2 = triu(tiled.A)
                @test L2 * U2 ≈ s
                ta1 = mybenchmark(build_and_factorise, s, lutiles; n=20)
                ta2 = mybenchmark(x -> lu!(tiled, x), s; n=20)
                #lutilesize = size(s, 1) ÷ lutiles
                @show ntiles, tilesize, overlap, lutiles, ta1 / tb1, ta2 / tb2
            end
        end
    end
end
# Run the profiling pass and the test set when the file is executed.
foo()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The takeaway here is that UMFPACK is amazing.