Created: November 29, 2016 15:30
-
-
Save lotz84/93be072cd673a06f0fdd2876fe98a552 to your computer and use it in GitHub Desktop.
ISTA (Iterative Shrinkage-Thresholding Algorithm): Haskell implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name:                ista
version:             0.1.0.0
build-type:          Simple
cabal-version:       >=1.10

-- Single executable built from ista.hs.
executable app
  main-is:             ista.hs
  -- -threaded with -N lets the RTS use every available core by default.
  ghc-options:         -Wall -threaded -rtsopts -with-rtsopts=-N
  build-depends:       base
                     , mwc-random
                     , hmatrix
                     , plots
                     , diagrams-lib
                     , diagrams-core
                     , diagrams-rasterific
  default-language:    Haskell2010
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Main where | |
import Control.Monad | |
import Numeric.LinearAlgebra | |
import System.Random.MWC | |
import Plots | |
import Diagrams.Prelude hiding (scale, (<>), Vector) | |
import Diagrams.Backend.Rasterific.CmdLine | |
-- | Ground-truth coefficient vector.  'main' uses it to synthesise the
-- observations @y = A x@ that the solvers work from.  It has length 5 and is
-- nonzero only at index 2 (value 2) and index 4 (value 3).
x :: Vector Double
x = vector entries
  where
    entries = [0, 0, 2, 0, 3]
-- | Draw a random 3x5 matrix.  Each entry is sampled with
-- @uniformR (-3, 3)@ from the given generator, and the matrix is filled in
-- row-major order by '(><)'.
genMat :: GenIO -> IO (Matrix Double)
genMat gen = do
  entries <- replicateM (rows * cols) (uniformR (-3, 3) gen)
  pure ((rows >< cols) entries)
  where
    rows, cols :: Int
    rows = 3
    cols = 5
-- | Soft-thresholding (shrinkage) operator.
--
-- Moves @v@ toward zero by @threshold@, and returns 0 when
-- @-threshold <= v <= threshold@.  The guards are tested in this order, so
-- a negative threshold hits the first branch whenever it applies.
softTh :: Double -> Double -> Double
softTh threshold v
  | v < negate threshold = v + threshold
  | v > threshold        = v - threshold
  | otherwise            = 0
-- | One ISTA (proximal-gradient) iteration.
--
-- Let @L = norm_2 (A^T A) / lambda@.  The step computes
-- @softTh (1/L) (x_t + (1/(L*lambda)) A^T (y - A x_t))@: a gradient step on
-- the residual followed by soft-thresholding.  This appears to correspond to
-- minimising @(1/2)||y - A x||^2 + lambda ||x||_1@, with step size
-- @1 / norm_2 (A^T A)@.
ista :: Double         -- ^ lambda
     -> Vector Double  -- ^ y (observations)
     -> Matrix Double  -- ^ A
     -> Vector Double  -- ^ current iterate x_t
     -> Vector Double  -- ^ next iterate x_{t+1}
ista lambda y mat xt = cmap (softTh threshold) gradStep
  where
    -- Step constant derived from hmatrix's norm_2 of the matrix A^T A.
    lipschitz = norm_2 (tr mat <> mat) / lambda
    threshold = 1 / lipschitz
    residual  = y - mat #> xt
    -- Scale A^T first, then multiply, exactly as in the original expression.
    gradStep  = xt + ((1 / (lipschitz * lambda)) `scale` tr mat) #> residual
-- | One step of the accelerated iteration, as originally written.
--
-- The state is @(z, t)@: the point the proximal-gradient step is taken from,
-- and the momentum scalar.  The step does three things:
--
-- * computes @x1 = ista lambda y A z@;
-- * computes @t1 = (1 + sqrt (1 + 4 t^2)) / 2@;
-- * returns @(x1 + ((t - 1) / t1) * (x1 - z), t1)@.
--
-- NOTE(review): in Beck & Teboulle's FISTA, the extrapolation uses the
-- difference of two consecutive proximal iterates, @x_k - x_{k-1}@.  This
-- step instead uses @x1 - z@, where @z@ is the previous /extrapolated/
-- point.  The state type has no slot for @x_{k-1}@, so the original
-- behaviour is kept here for compatibility; 'fistaBT' implements the
-- textbook update.
--
-- Starting from @t = 0@ makes the first momentum coefficient -1, so the
-- first step returns @z@ unchanged (up to rounding).  The textbook start is
-- @t = 1@.
fista :: Double                    -- ^ lambda
      -> Vector Double             -- ^ y (observations)
      -> Matrix Double             -- ^ A
      -> (Vector Double, Double)   -- ^ (z_t, t_t)
      -> (Vector Double, Double)   -- ^ (z_{t+1}, t_{t+1})
fista lambda y mat (z, t) =
  -- The proximal-gradient part is exactly one ISTA step from z.
  let x1 = ista lambda y mat z
      t1 = 0.5 * (1 + sqrt (1 + 4 * t ^ (2 :: Int)))
  in (x1 + ((t - 1) / t1) `scale` (x1 - z), t1)

-- | One step of FISTA as given by Beck & Teboulle (2009).
--
-- The state is @(z_k, x_{k-1}, t_k)@; start from @(x0, x0, 1)@.  The step
-- computes:
--
-- * @x_k = ista lambda y A z_k@;
-- * @t_{k+1} = (1 + sqrt (1 + 4 t_k^2)) / 2@;
-- * @z_{k+1} = x_k + ((t_k - 1) / t_{k+1}) * (x_k - x_{k-1})@.
--
-- It returns @(z_{k+1}, x_k, t_{k+1})@.  The second component is the
-- current estimate.
fistaBT :: Double                                   -- ^ lambda
        -> Vector Double                            -- ^ y (observations)
        -> Matrix Double                            -- ^ A
        -> (Vector Double, Vector Double, Double)   -- ^ (z_k, x_{k-1}, t_k)
        -> (Vector Double, Vector Double, Double)   -- ^ (z_{k+1}, x_k, t_{k+1})
fistaBT lambda y mat (z, xPrev, t) =
  let xk = ista lambda y mat z
      t1 = 0.5 * (1 + sqrt (1 + 4 * t ^ (2 :: Int)))
  in (xk + ((t - 1) / t1) `scale` (xk - xPrev), xk, t1)
-- | Run ISTA to recover 'x' from @y = A x@, then plot the per-iteration cost.
--
-- The program works in four steps:
--
-- 1. Draw a random matrix A with 'genMat' and form the observations
--    @y = A x@.
-- 2. Start ISTA from @A^T y@ and run 50 iterations.  After each iteration,
--    record @||y - A x_t||_2 + ||x_t||_1@.
-- 3. Print y, A and the final iterate.
-- 4. Plot the recorded values with 'r2AxisMain'.
main :: IO ()
main = do
  gen <- createSystemRandom
  mat <- genMat gen
  let y = mat #> x
      -- Initial guess for the solvers: A^T y.
      x0 = tr mat #> y
      -- NOTE(review): this is not (1/2)||y - A x||^2 + lambda ||x||_1, the
      -- cost the ista step appears to descend on.  The plotted curve is
      -- therefore not guaranteed to be monotone.  Kept as in the original.
      cost v = norm_2 (y - mat #> v) + norm_1 v
      (result, errors) = runIterations 50 (ista lambda y mat) cost x0
      -- FISTA alternative.  The momentum scalar starts at 1; starting at 0
      -- makes the first step a no-op.
      -- ((result, _), errors) = runIterations 30 (fista lambda y mat) (cost . fst) (x0, 1)
  putStrLn $ "y: " ++ dispf 3 (asRow y)
  putStrLn $ "A: " ++ dispf 3 mat
  putStrLn $ "Result: " ++ dispf 3 (asRow result)
  r2AxisMain (myaxis errors)
  where
    -- Regularisation weight passed to the solvers.
    lambda :: Double
    lambda = 3.0

    -- Apply `step` n times starting from s0.  Return the final state and,
    -- in order, `measure` of every state after the initial one.
    -- A non-positive n performs no steps.
    runIterations :: Int -> (s -> s) -> (s -> Double) -> s -> (s, [Double])
    runIterations n step measure s0 = go n [] s0
      where
        go k acc s
          | k <= 0 = (s, reverse acc)
          | otherwise =
              let s' = step s
              in s' `seq` go (k - 1) (measure s' : acc) s'

    myaxis :: [Double] -> Axis B V2 Double
    myaxis es = r2Axis &~ do
      linePlot' $ zip [1..] es
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Stackage snapshot used to build the package.
resolver: lts-7.11
packages:
- '.'
# Packages needed outside the snapshot (stack's `extra-deps` semantics).
extra-deps:
- plots-0.1.0.2
flags: {}
extra-package-dbs: []
lotz84 (author) commented on Nov 29, 2016.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.