Skip to content

Instantly share code, notes, and snippets.

@mratsim
Created May 3, 2018 11:39
Show Gist options
  • Select an option

  • Save mratsim/f4b7cb8f352ea3c817b06f2794be83f3 to your computer and use it in GitHub Desktop.

Select an option

Save mratsim/f4b7cb8f352ea3c817b06f2794be83f3 to your computer and use it in GitHub Desktop.
Nim symmetric eigendecomposition (LAPACK syev, QR/QL algorithm)
import ../tensor/tensor, nimlapack
proc syev*(jobz: cstring; uplo: cstring; n: ptr cint; a: ptr cfloat;
           lda: ptr cint; w: ptr cfloat; work: ptr cfloat; lwork: ptr cint;
           info: ptr cint) {.inline.} =
  ## Single-precision (``cfloat``) overload of the symmetric eigensolver.
  ##
  ## Passes every argument through untouched to LAPACK ``ssyev``. It exists
  ## so generic code can call ``syev`` and have overload resolution pick the
  ## precision from the pointer types.
  jobz.ssyev(uplo, n, a, lda, w, work, lwork, info)
proc syev*(jobz: cstring; uplo: cstring; n: ptr cint; a: ptr cdouble;
           lda: ptr cint; w: ptr cdouble; work: ptr cdouble; lwork: ptr cint;
           info: ptr cint) {.inline.} =
  ## Double-precision (``cdouble``) overload of the symmetric eigensolver.
  ##
  ## Passes every argument through untouched to LAPACK ``dsyev``. It exists
  ## so generic code can call ``syev`` and have overload resolution pick the
  ## precision from the pointer types.
  jobz.dsyev(uplo, n, a, lda, w, work, lwork, info)
proc symeig*[T: SomeReal](a: Tensor[T]): tuple[eigenval, eigenvec: Tensor[T]] =
  ## Compute the eigenvalues and eigenvectors of a real symmetric matrix.
  ##
  ## Input:
  ##   - ``a``: a square symmetric matrix (rank-2 tensor). ``a`` itself is
  ##     not modified; LAPACK works on a column-major copy of it.
  ## Returns:
  ##   - A tuple with:
  ##     - ``eigenval``: the eigenvalues, sorted from lowest to highest
  ##     - ``eigenvec``: a column-major matrix whose i-th column is the
  ##       eigenvector corresponding to ``eigenval[i]``
  ## Raises:
  ##   - an assertion failure if ``a`` is not a square rank-2 matrix
  ##   - ``ValueError`` if LAPACK rejects an argument or if the eigenvalue
  ##     algorithm fails to converge
  ##
  ## Implementation based on LAPACK ``?syev`` (reduction to tridiagonal form
  ## followed by the implicit QL/QR algorithm).

  # doAssert rather than assert: with a non-square input LAPACK would read
  # past the end of the buffer, so this check must survive -d:danger.
  doAssert a.rank == 2 and a.shape[0] == a.shape[1],
    "Input should be a symmetric matrix"
  # TODO, support "symmetric matrices" with only the upper or lower part filled.
  # (Obviously, upper in Fortran is lower in C ...)

  # The input is destroyed by LAPACK (overwritten with the eigenvectors), so
  # hand it a Fortran-order copy; that copy becomes the returned eigenvectors.
  result.eigenvec = a.clone(layout = colMajor)

  # Locals (LAPACK takes every scalar by pointer)
  var
    n, lda: cint = a.shape[0].cint
    info: cint
    wkopt: T
    lwork: cint = -1 # -1 requests a workspace-size query
  let
    jobz: cstring = "V" # N or V (eigenval only or eigenval + eigen vec)
    uplo: cstring = "U" # U or L (upper or lower, in ColMajor layout)

  result.eigenval = newTensorUninit[T](a.shape[0])
  let w = result.eigenval.get_data_ptr
  let vec = result.eigenvec.get_data_ptr

  # Query the optimal workspace size (written into wkopt), then allocate it.
  syev(jobz, uplo, n.addr, vec, lda.addr, w,
       wkopt.addr, lwork.addr, info.addr)
  if info < 0:
    # Without this check wkopt stays 0 and work[0] below fails with an
    # unrelated IndexError.
    raise newException(ValueError,
      "Illegal parameter in symmetric eigensolver syev: " & $(-info))
  lwork = max(wkopt.cint, 1) # LAPACK requires lwork >= 1
  var work = newSeq[T](lwork)

  # Solve eigenproblem
  syev(jobz, uplo, n.addr, vec, lda.addr, w,
       work[0].addr, lwork.addr, info.addr)
  if info > 0:
    # TODO, this should not be an exception, not converging is something that can happen and should
    # not interrupt the program. Alternative. Fill the result with Inf?
    raise newException(ValueError,
      "the algorithm for computing the eigenvalues failed to converge")
  if info < 0:
    raise newException(ValueError,
      "Illegal parameter in symmetric eigensolver syev: " & $(-info))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment