Skip to content

Instantly share code, notes, and snippets.

@lotz84
Created November 29, 2016 15:30
Show Gist options
  • Save lotz84/93be072cd673a06f0fdd2876fe98a552 to your computer and use it in GitHub Desktop.
Save lotz84/93be072cd673a06f0fdd2876fe98a552 to your computer and use it in GitHub Desktop.
ISTA and FISTA for sparse (L1-regularised) least-squares recovery: a Haskell implementation
name:                ista
version:             0.1.0.0
build-type:          Simple
cabal-version:       >=1.10

-- Fields of a section are indented under its header, as the .cabal
-- format requires; continuation lines of build-depends are indented too.
executable app
  main-is:             ista.hs
  ghc-options:         -threaded -rtsopts -with-rtsopts=-N
  build-depends:       base
                     , mwc-random
                     , hmatrix
                     , plots
                     , diagrams-lib
                     , diagrams-core
                     , diagrams-rasterific
  default-language:    Haskell2010
module Main where
import Control.Monad
import Numeric.LinearAlgebra
import System.Random.MWC
import Plots
import Diagrams.Prelude hiding (scale, (<>), Vector)
import Diagrams.Backend.Rasterific.CmdLine
-- | Sparse ground-truth signal: only entries 2 and 4 (0-based) are non-zero.
-- 'main' measures @y = A x@ and tries to recover this vector from @y@.
x :: Vector Double
x = vector [0.0, 0.0, 2.0, 0.0, 3.0]
-- | Draw a random 3x5 measurement matrix whose entries are sampled with
-- @uniformR (-3, 3)@ from the supplied generator.
genMat :: GenIO -> IO (Matrix Double)
genMat gen = do
  entries <- replicateM (nRows * nCols) (uniformR (-3, 3) gen)
  pure ((nRows >< nCols) entries)
  where
    nRows, nCols :: Int
    nRows = 3
    nCols = 5
-- | Soft-thresholding (shrinkage) of a single value: for a threshold
-- @l >= 0@, values with @|y| <= l@ become 0 and all others move @l@
-- closer to zero.
softTh :: Double -> Double -> Double
softTh threshold v =
  if v < negate threshold
    then v + threshold
    else
      if v > threshold
        then v - threshold
        else 0
-- | One ISTA (iterative shrinkage-thresholding) step.
--
-- The step is a gradient step on @(1 / (2 * lambda)) * ||y - A v||^2@ with
-- step size @1 / (l * lambda)@, followed by soft-thresholding with
-- threshold @1 / l@, where @l = norm_2 (A^T A) / lambda@. This is the
-- proximal-gradient step for
--
-- > (1 / (2 * lambda)) * ||y - A v||_2^2 + ||v||_1
--
-- Assumes @lambda > 0@ and a non-zero @A@; otherwise the step size is not
-- finite.
ista :: Double          -- ^ lambda
     -> Vector Double   -- ^ y, the observations
     -> Matrix Double   -- ^ A, the measurement matrix
     -> Vector Double   -- ^ x_t
     -> Vector Double   -- ^ x_t+1
ista lambda y mat = step
  where
    matT = tr mat
    -- Lipschitz-type constant of the scaled quadratic term. It needs a
    -- matrix norm, so it is bound outside 'step' and shared by every step
    -- of a partial application such as @ista lambda y mat@.
    l = norm_2 (matT <> mat) / lambda
    -- (1 / (l * lambda)) * A^T, also independent of the current iterate.
    gradScale = (1 / (l * lambda)) `scale` matT
    step xt = cmap (softTh (1 / l)) (xt + gradScale #> (y - mat #> xt))
-- | One accelerated (FISTA-style) step built on 'ista'.
--
-- The state is @(v, t)@:
--
-- * @v@ is the point the proximal-gradient step is taken from.
-- * @t@ is the momentum parameter.
--
-- Start from @(x0, 1)@. With @t = 0@ the momentum coefficient of the first
-- step is -1, so that step just returns (up to rounding) its input.
--
-- NOTE(review): Beck & Teboulle's FISTA extrapolates with
-- @x_k + ((t_k - 1) / t_{k+1}) * (x_k - x_{k-1})@, whereas this step uses
-- @x_k - v@ (the previous extrapolated point, not the previous iterate).
-- Matching the reference requires carrying the previous iterate in the
-- state, which changes this function's type, so the behaviour is kept here.
fista :: Double                    -- ^ lambda
      -> Vector Double             -- ^ y
      -> Matrix Double             -- ^ A
      -> (Vector Double, Double)   -- ^ (x_t, b_t)
      -> (Vector Double, Double)   -- ^ (x_t+1, b_t+1)
fista lambda y mat (v, t) = (x' + ((t - 1) / t') `scale` (x' - v), t')
  where
    -- Proximal-gradient step: exactly one ISTA step taken from v.
    x' = ista lambda y mat v
    -- Momentum update t' = (1 + sqrt (1 + 4 t^2)) / 2.
    t' = 0.5 * (1 + sqrt (1 + 4 * t ^ (2 :: Int)))
-- | Recover the sparse signal 'x' from the measurements @y = A x@ of a random
-- matrix @A@. Prints the data and the estimate, then plots the objective
-- value after every iteration.
--
-- 'r2AxisMain' takes its output options from the command line, e.g.
-- @stack exec app -- -o output.png@.
main :: IO ()
main = do
  gen <- createSystemRandom
  mat <- genMat gen
  let y = mat #> x
      lambda = 3.0
      x0 = tr mat #> y
      -- Objective minimised by 'ista' / 'fista':
      -- (1 / (2 * lambda)) * ||y - A v||_2^2 + ||v||_1.
      -- ISTA decreases this value monotonically, so it is what the plot
      -- should track (not ||y - A v|| + ||v||_1).
      objective v = norm_2 (y - mat #> v) ^ (2 :: Int) / (2 * lambda) + norm_1 v
      (result, errors) = runSolver 50 (ista lambda y mat) id objective x0
      -- FISTA alternative (its momentum parameter must start at 1, not 0):
      -- (result, errors) = runSolver 30 (fista lambda y mat) fst objective (x0, 1)
  putStrLn $ "y: " ++ dispf 3 (asRow y)
  putStrLn $ "A: " ++ dispf 3 mat
  putStrLn $ "Result: " ++ dispf 3 (asRow result)
  r2AxisMain (myaxis errors)
  where
    myaxis :: [Double] -> Axis B V2 Double
    myaxis es = r2Axis &~ do
      linePlot' $ zip [1..] es

    -- Apply stepFn n times starting from s0. Returns the point of the final
    -- state and the score of the point of each new state, in iteration
    -- order. n <= 0 performs no steps and returns s0's point.
    runSolver :: Int -> (s -> s) -> (s -> Vector Double)
              -> (Vector Double -> Double) -> s -> (Vector Double, [Double])
    runSolver n stepFn pointOf score s0 = go n s0 []
      where
        go k s acc
          | k <= 0 = (pointOf s, reverse acc)
          | otherwise =
              let s' = stepFn s
              in go (k - 1) s' (score (pointOf s') : acc)
# Stack build configuration: the LTS 7.11 snapshot, this package, and
# plots pinned through extra-deps.
resolver: lts-7.11
packages: ['.']
extra-deps: [plots-0.1.0.2]
flags: {}
extra-package-dbs: []
@lotz84
Copy link
Author

lotz84 commented Nov 29, 2016

$ stack build && stack exec app -- -o output.png
y: 1x3
10.281  -7.006  6.194

A: 3x5
-2.372  -1.480  1.717  -1.749   2.283
-1.872  -0.261  0.697   0.187  -2.800
-0.572   2.219  2.902   1.403   0.130

Result: 1x5
0.000  0.000  1.785  0.000  2.809

output

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment