Created on March 9, 2022 at 14:40
-
-
Save idontgetoutmuch/03796396237e35dd364e111ff8ef931d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveTraversable #-} | |
{-# LANGUAGE DeriveFoldable #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE QuasiQuotes #-} | |
{-# LANGUAGE TypeFamilies #-} | |
import H.Prelude as H | |
import Language.R.QQ | |
import Numeric.AD | |
import Data.Foldable ( Foldable ) | |
import Data.Traversable ( Traversable ) | |
-- | Observation abscissae used by both the R @nls@ fit and the Haskell
-- gradient-descent fit.
xdata :: [Double]
xdata =
  [ -2, -1.64, -1.33, -0.7, 0
  , 0.45, 1.2, 1.64, 2.32, 2.9
  ]

-- | Observed responses, positionally paired with 'xdata'
-- (element @i@ of 'ydata' is the response at element @i@ of 'xdata').
ydata :: [Double]
ydata =
  [ 0.699369, 0.700462, 0.695354, 1.03905, 1.97389
  , 2.41143, 1.91091, 0.919576, -0.730975, -1.42001
  ]
-- | Starting value for the first model parameter. It is handed to R's
-- @nls@ (as @p1_hs@) and used as the gradient-descent starting point.
p1 :: Double
p1 = 1

-- | Starting value for the second model parameter (see 'p1').
p2 :: Double
p2 = 0.2
-- | The two parameters of the model
-- @y = p1 * cos (p2 * x) + p2 * sin (p1 * x)@.
--
-- It is a 'Traversable' container so that it can be optimised directly by
-- "Numeric.AD" (e.g. @gradientDescent@).
data Params a
  = Params
      a -- ^ first parameter (@p1@)
      a -- ^ second parameter (@p2@)
  deriving (Prelude.Show, Functor, Foldable, Traversable)
-- | Half the mean squared residual of the model
--
-- > y = p1 * cos (p2 * x) + p2 * sin (p1 * x)
--
-- over the paired observations @zip xs ys@.
--
-- The mean is taken over the number of pairs actually compared (the
-- shorter of the two lists). Previously it was taken over @length xs@,
-- which mis-scaled the cost whenever the lists differed in length.
-- With no pairs the cost is @0@ rather than the NaN from @0 / 0@.
cost :: Floating a => [a] -> [a] -> Params a -> a
cost xs ys (Params q1 q2) =
  case residuals of
    [] -> 0
    _  -> sum (map (\r -> r * r) residuals)
            / (2 * fromIntegral (length residuals))
  where
    -- q1/q2 play the roles of p1/p2 in the model; renamed so they do not
    -- shadow the top-level initial guesses p1/p2.
    model x = q1 * cos (q2 * x) + q2 * sin (q1 * x)
    residuals = zipWith (\x y -> y - model x) xs ys
-- | Fit 'Params' to the observations @xs@, @ys@ by gradient descent on
-- 'cost', starting from the initial guess ('p1', 'p2') that is also given
-- to R's @nls@.
--
-- Returns the iterate reached after 'steps' descent steps. The list
-- produced by @gradientDescent@ is not guaranteed to be that long (it can
-- stop early, e.g. once the gradient vanishes). The original
-- @head . drop 2000@ would then crash. Instead, this returns the last
-- iterate produced, or the starting point if no step was taken.
fitHask :: (Mode a, Ord a, Floating a, Scalar a ~ Double) => [a] -> [a] -> Params a
fitHask xs ys = last (start : take (steps + 1) iterates)
  where
    -- Index of the iterate to return (the original used element 2000).
    steps = 2000 :: Int
    start = Params (auto p1) (auto p2)
    iterates = gradientDescent (cost (map auto xs) (map auto ys)) start
-- | Fit the model to 'xdata' / 'ydata' with R's @nls@ (through inline-r
-- quasi-quotes), starting from 'p1' / 'p2'. Then print the two estimated
-- coefficients from R.
main :: IO ()
main = runRegion $ do
  fitR <- [r| nls(ydata_hs ~ p1*cos(p2*xdata_hs) + p2*sin(p1*xdata_hs)
             , start=list(p1=p1_hs,p2=p2_hs)) |]
  -- Run the two R print calls in order and discard their results.
  sequence_
    [ [r| print(coef(fitR_hs)[["p1"]]) |]
    , [r| print(coef(fitR_hs)[["p2"]]) |]
    ]
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.