Created
May 24, 2018 04:47
-
-
Save tonyday567/0a72d9769e7fc5450784b6539dd95fa3 to your computer and use it in GitHub Desktop.
numhask-backprop
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env stack | |
-- stack --install-ghc runghc --resolver lts-11.9 --package backprop-0.2.2.0 --package numhask-prelude-0.0.4.1 --package numhask-0.2.1.0 -- -Wall -O2 | |
{-# LANGUAGE NoImplicitPrelude #-} | |
{-# LANGUAGE OverloadedStrings #-} | |
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE DeriveGeneric #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE PatternSynonyms #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE ViewPatterns #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
{-# LANGUAGE ConstraintKinds #-} | |
{-# OPTIONS_GHC -Wall #-} | |
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
{-# LANGUAGE TypeFamilies #-} | |
module NumHask.Backprop where | |
import NumHask.Prelude as NH | |
import qualified Numeric.Backprop as IBP | |
import Numeric.Backprop.Explicit as BP | |
-- | A transparent wrapper around a numhask number.
--
-- All of the numeric structure listed below is lifted straight from @a@;
-- the wrapper exists so that a 'Backprop' instance can be given in terms
-- of numhask's own 'zero', 'one' and '+'.
newtype NH a = NH {unnh :: a}
  deriving
    ( Eq
    , Ord
    , AdditiveMagma
    , AdditiveAssociative
    , AdditiveCommutative
    , AdditiveUnital
    , AdditiveIdempotent
    , Additive
    , AdditiveInvertible
    , AdditiveGroup
    , MultiplicativeMagma
    , MultiplicativeUnital
    , MultiplicativeAssociative
    , MultiplicativeCommutative
    , MultiplicativeIdempotent
    , Multiplicative
    , MultiplicativeInvertible
    , MultiplicativeGroup
    , Distribution
    , Semiring
    , Ring
    , CRing
    , StarSemiring
    , KleeneAlgebra
    , InvolutiveRing
    , Semifield
    , Field
    , ExpField
    , QuotientField
    , UpperBoundedField
    , LowerBoundedField
    , TrigField
    , Signed
    , Integral
    , ToInteger
    , FromInteger
    )
-- Not in the deriving list above: Normed, Metric, Epsilon
-- | A backprop 'BVar' over an 'NH'-wrapped number.
--
-- The field is qualified: any code that unwraps it must be able to supply
-- 'Additive' @a@, 'MultiplicativeUnital' @a@ and 'Reifies' @s W@. Code
-- that builds a 'BVarNH' gets those constraints for free inside the
-- constructor argument. The classes derived here are the ones that need
-- no hand-written methods.
newtype BVarNH s a = BVarNH
  { unNH :: (Additive a, MultiplicativeUnital a, Reifies s W) => BVar s (NH a)
  }
  deriving
    ( AdditiveAssociative
    , AdditiveCommutative
    , AdditiveIdempotent
    , MultiplicativeAssociative
    , MultiplicativeCommutative
    , MultiplicativeIdempotent
    , Distribution
    , Semiring
    , Ring
    , CRing
    , Semifield
    , Field
    , KleeneAlgebra
    )
-- | Equality delegates to the 'Eq' instance of the wrapped 'BVar'.
-- The extra context is what is needed to unwrap the qualified field.
instance (Eq a, Additive a, MultiplicativeUnital a, Reifies s W) =>
  Eq (BVarNH s a) where
  BVarNH lhs == BVarNH rhs = lhs == rhs
-- | Ordering delegates to the 'Ord' instance of the wrapped 'BVar'.
--
-- Every comparison operator is forwarded, not just '<=' and '>='. If '<'
-- and '>' were left to their defaults, they would be rebuilt from '==' and
-- '<='. For types with incomparable values (e.g. IEEE NaN) that makes
-- @x > y@ come out True when both @x <= y@ and @x == y@ are False, which
-- disagrees with the forwarded '>='. 'max' and 'min' keep their defaults,
-- which only use '<='.
instance (Ord a, Additive a, MultiplicativeUnital a, Reifies s W) =>
  Ord (BVarNH s a) where
  compare (BVarNH x) (BVarNH y) = compare x y
  BVarNH x < BVarNH y = x < y
  BVarNH x <= BVarNH y = x <= y
  BVarNH x > BVarNH y = x > y
  BVarNH x >= BVarNH y = x >= y
-- Not in BVarNH's deriving list (NH derives them): QuotientField, UpperBoundedField, LowerBoundedField, Integral, ToInteger, FromInteger
-- * Backprop instance for a NH wrapped number

-- | Gradients of 'NH' values use numhask arithmetic: numhask's zero and one
-- serve as the constant zero/one, and numhask's '+' accumulates gradients.
instance (Additive a, MultiplicativeUnital a) => Backprop (NH a) where
  zero = const NH.zero
  one = const NH.one
  add x y = x NH.+ y
-- * operators | |
-- | Differentiable 'plus': the incoming gradient is passed unchanged to
-- both arguments.
plusOp :: AdditiveMagma a => Op '[a, a] a
plusOp = op2 forward
  where
    forward x y = (x `plus` y, \g -> (g, g))
-- | Differentiable 'negate': the gradient is negated on the way back.
negateOp :: (AdditiveInvertible a) => Op '[a] a
negateOp = op1 (\x -> (negate x, \g -> negate g))
-- | Differentiable 'times'. Each argument receives the other argument
-- multiplied (on the left) into the incoming gradient.
timesOp :: MultiplicativeMagma a => Op '[a, a] a
timesOp = op2 forward
  where
    forward x y = (x `times` y, backward)
      where
        backward g = (y `times` g, x `times` g)
-- | Differentiable 'recip'. Since d(1/x)/dx = -1/x², the gradient is
-- @negate g / (x * x)@.
recipOp :: (AdditiveInvertible a, MultiplicativeGroup a) => Op '[a] a
recipOp = op1 $ \x ->
  let back g = negate g / (x * x)
  in (recip x, back)
-- | Differentiable 'sign'. The gradient is numhask's zero regardless of the
-- incoming gradient (sign is piecewise constant).
signOp :: (Signed a, AdditiveUnital a) => Op '[a] a
signOp = op1 (\x -> (sign x, \_ -> NH.zero))
-- | Differentiable 'abs'. The gradient is the incoming gradient times
-- @sign x@.
absOp :: (Signed a) => Op '[a] a
absOp = op1 $ \x -> (abs x, \g -> g `times` sign x)
-- | Differentiable 'star'.
--
-- Gradient: differentiating the star identity @star x = one + x * star x@
-- (assuming a commutative, field-like @a@, where @one - x@ is invertible)
-- gives @d(star x)/dx = star x * star x@. For reals this is 1/(1-x)².
-- The backward pass therefore scales the incoming gradient by the square
-- of the forward result. It must depend on @x@; it cannot just transform
-- the gradient on its own.
starOp :: (StarSemiring a) => Op '[a] a
starOp = op1 $ \x ->
  let s = star x
  in (s, \g -> g `times` (s `times` s))
-- | Differentiable 'plus''.
--
-- Gradient: under the star identity @star x = one + x * star x@ (assuming
-- a commutative, field-like @a@), @plus' x = star x - one@. It therefore
-- has the same derivative as 'star', namely @star x * star x@. For reals
-- this is d(x/(1-x))/dx = 1/(1-x)².
plus'Op :: (StarSemiring a) => Op '[a] a
plus'Op = op1 $ \x ->
  let s = star x
  in (plus' x, \g -> g `times` (s `times` s))
-- | Differentiable 'adj'. The incoming gradient is passed through 'adj'.
adjOp :: (InvolutiveRing a) => Op '[a] a
adjOp = op1 (\x -> (adj x, \g -> adj g))
-- | Differentiable 'exp'. The derivative is @exp x@ itself, which
-- multiplies the incoming gradient on the left.
expOp :: ExpField a => Op '[a] a
expOp = op1 $ \x ->
  let e = exp x
  in (e, \g -> e * g)
-- | Differentiable 'log'. The gradient is @g / x@.
logOp :: ExpField a => Op '[a] a
logOp = op1 (\x -> (log x, \g -> g / x))
-- | Differentiable 'sin'. The gradient is @g * cos x@.
sinOp :: TrigField a => Op '[a] a
sinOp = op1 forward
  where
    forward x = (sin x, \g -> g * cos x)
-- | Differentiable 'cos'. The gradient is @g * negate (sin x)@.
cosOp :: TrigField a => Op '[a] a
cosOp = op1 forward
  where
    forward x = (cos x, \g -> g * negate (sin x))
-- | Differentiable 'asin'. The gradient is @g / sqrt (1 - x²)@.
asinOp :: (ExpField a, TrigField a) => Op '[a] a
asinOp = op1 $ \x -> (asin x, \g -> g / sqrt (NH.one - x * x))
-- | Differentiable 'acos'. The gradient is @negate g / sqrt (1 - x²)@.
acosOp :: (ExpField a, TrigField a) => Op '[a] a
acosOp = op1 $ \x -> (acos x, \g -> negate g / sqrt (NH.one - x * x))
-- | Differentiable 'atan'. The gradient is @g / (x² + 1)@.
atanOp :: TrigField a => Op '[a] a
atanOp = op1 $ \x -> (atan x, \g -> g / (x * x + NH.one))
-- | Differentiable 'sinh'. The gradient is @g * cosh x@.
sinhOp :: TrigField a => Op '[a] a
sinhOp = op1 (\x -> (sinh x, \g -> g * cosh x))
-- | Differentiable 'cosh'. The gradient is @g * sinh x@.
coshOp :: TrigField a => Op '[a] a
coshOp = op1 (\x -> (cosh x, \g -> g * sinh x))
-- | Differentiable 'tanh'. The derivative is sech² x = 1 / cosh² x.
--
-- @cosh x@ is computed once and squared with a plain multiplication, as the
-- sibling ops do with @x * x@. The general power @**@ is avoided because,
-- for an 'ExpField' without a specialised @**@, it goes through an exp/log
-- round-trip that adds rounding error. Computing @1 - tanh² x@ is also
-- avoided, since it cancels to zero for large @|x|@. The 'ExpField'
-- constraint is no longer used, but it is kept so the signature is
-- unchanged.
tanhOp :: (TrigField a, ExpField a) => Op '[a] a
tanhOp = op1 $ \x ->
  let c = cosh x
  in (tanh x, \g -> g / (c * c))
-- | Differentiable 'asinh'. The gradient is @g / sqrt (x² + 1)@.
asinhOp :: (TrigField a, ExpField a) => Op '[a] a
asinhOp = op1 forward
  where
    forward x = (asinh x, \g -> g / sqrt (x * x + NH.one))
-- | Differentiable 'acosh'. The gradient is @g / sqrt (x² - 1)@.
acoshOp :: (TrigField a, ExpField a) => Op '[a] a
acoshOp = op1 forward
  where
    forward x = (acosh x, \g -> g / sqrt (x * x - NH.one))
-- | Differentiable 'atanh'. The gradient is @g / (1 - x²)@.
atanhOp :: (TrigField a) => Op '[a] a
atanhOp = op1 forward
  where
    forward x = (atanh x, \g -> g / (NH.one - x * x))
-- | Higher-kinded-data helper. At 'Identity' a field is just the bare type.
-- At any other functor @f@ (for example @BVar s@) it is wrapped in @f@.
-- The equations are tried in order, so 'Identity' must come first.
type family HKD f x where
  HKD Identity x = x
  HKD f x = f x
-- | The two results of 'properFraction', stored as higher-kinded data:
-- an integer part and a fractional part of type @a@.
data FracType' a f = FT
  { ftI :: HKD f Integer
  , ftR :: HKD f a
  }
  deriving (Generic)

-- | The plain-value form, whose fields are @Integer@ and @a@.
type FracType a = FracType' a Identity

-- | Gradient arithmetic for the plain form, taken from the
-- 'Generic'-based defaults of 'Backprop'.
instance Backprop a => Backprop (FracType' a Identity)
-- | Differentiable 'properFraction', returning both parts as a 'FracType'.
--
-- Gradient: @x = i + r@, where the integer part @i@ is piecewise constant
-- (derivative 0). The fractional part @r@ therefore moves one-for-one with
-- @x@ (derivative 1). The gradient with respect to @x@ is simply the
-- gradient arriving on the fractional field, 'ftR'. The integer field's
-- gradient contributes nothing. As is usual for step functions, the points
-- where the integer part jumps (where no derivative exists) are ignored.
--
-- 'UpperBoundedField' and 'FromInteger' are no longer used by the body,
-- but are kept so the signature is unchanged.
properFractionOp :: (UpperBoundedField a, FromInteger a, QuotientField a) => Op '[a] (FracType a)
properFractionOp = op1 $ \x ->
  let (i, r) = properFraction x
  in (FT i r, ftR)
-- | Split a backpropagated 'FracType' into its integer part and a 'BVar'
-- for its fractional part.
--
-- FIXME: not implemented. 'IBP.splitBV' yields a @FracType' a (BVar s)@,
-- whose integer field is a @BVar s Integer@ rather than the plain
-- 'Integer' this signature promises. Until that conversion is worked out,
-- the function fails with an explicit message instead of a bare
-- @undefined@.
extractInteger :: (Reifies s W, Backprop a) => BVar s (FracType a) -> (Integer, BVar s a)
extractInteger (IBP.splitBV -> FT _ _) =
  error "NumHask.Backprop.extractInteger: not implemented"
-- numhask classes | |
-- | Addition of 'BVarNH's, recorded on the tape with 'plusOp'. No instance
-- context is needed; the field's own constraints are available inside the
-- 'BVarNH' constructor argument.
instance AdditiveMagma (BVarNH s a) where
  plus (BVarNH lhs) (BVarNH rhs) =
    BVarNH (liftOp2 addFunc addFunc zeroFunc plusOp lhs rhs)
-- | The additive identity: a constant (non-differentiated) 'BVar' holding
-- numhask's 'zero' at type @NH a@.
--
-- The value must be built explicitly. Writing @zero = NH.zero@ here would
-- resolve to this very instance's 'zero' and loop forever.
instance AdditiveUnital (BVarNH s a) where
  zero = BVarNH (constVar NH.zero)
-- | 'Additive' has no methods of its own; it only collects the additive
-- superclasses.
instance Additive (BVarNH s a)
-- | Negation, recorded on the tape with 'negateOp'.
instance AdditiveInvertible a => AdditiveInvertible (BVarNH s a) where
  negate (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc negateOp v)
-- | No methods of its own; it relies on the class defaults.
instance AdditiveInvertible a => AdditiveGroup (BVarNH s a)
-- | Multiplication of 'BVarNH's, recorded on the tape with 'timesOp'.
instance MultiplicativeMagma (BVarNH s a) where
  times (BVarNH lhs) (BVarNH rhs) =
    BVarNH (liftOp2 addFunc addFunc zeroFunc timesOp lhs rhs)
-- | The multiplicative identity: a constant (non-differentiated) 'BVar'
-- holding numhask's 'one' at type @NH a@.
--
-- The value must be built explicitly. Writing @one = NH.one@ here would
-- resolve to this very instance's 'one' and loop forever.
instance MultiplicativeUnital (BVarNH s a) where
  one = BVarNH (constVar NH.one)
-- | 'Multiplicative' has no methods of its own; it only collects the
-- multiplicative superclasses.
instance Multiplicative (BVarNH s a)
-- | Reciprocal, recorded on the tape with 'recipOp'.
instance (AdditiveInvertible a, MultiplicativeGroup a) => MultiplicativeInvertible (BVarNH s a) where
  recip (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc recipOp v)
-- | No methods of its own; it relies on the class defaults.
instance (AdditiveInvertible a, MultiplicativeGroup a) => MultiplicativeGroup (BVarNH s a)
-- | 'sign' and 'abs', recorded on the tape with 'signOp' and 'absOp'.
instance Signed a => Signed (BVarNH s a) where
  sign (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc signOp v)
  abs (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc absOp v)
-- | 'star' and 'plus'', recorded on the tape with 'starOp' and 'plus'Op'.
instance StarSemiring a => StarSemiring (BVarNH s a) where
  star (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc starOp v)
  plus' (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc plus'Op v)
-- | 'adj', recorded on the tape with 'adjOp'.
instance InvolutiveRing a => InvolutiveRing (BVarNH s a) where
  adj (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc adjOp v)
-- | 'exp' and 'log', recorded on the tape with 'expOp' and 'logOp'. The
-- remaining methods use the class defaults.
instance (ExpField a, AdditiveInvertible a, MultiplicativeGroup a) =>
  ExpField (BVarNH s a) where
  log (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc logOp v)
  exp (BVarNH v) = BVarNH (liftOp1 addFunc zeroFunc expOp v)
-- | Trigonometric and hyperbolic functions on 'BVarNH'. Each one is
-- recorded on the tape with the matching @*Op@ defined in this module.
instance
  ( Reifies s W
    -- NOTE(review): the other instances do not list Reifies s W. Why it is
    -- required here is not visible from this code. Possibly a superclass
    -- or default method uses the Eq/Ord (BVarNH s a) instances, whose
    -- contexts demand it. Confirm.
  , ExpField a
  , TrigField a
  , AdditiveInvertible a
  , MultiplicativeGroup a
  ) =>
  TrigField (BVarNH s a) where
  -- 'pi' is a constant, so it enters the tape via 'constVar'. Writing
  -- @pi = NH.pi@ would resolve to this very instance and loop forever.
  pi = BVarNH (constVar NH.pi)
  sin (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc sinOp x)
  cos (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc cosOp x)
  asin (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc asinOp x)
  acos (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc acosOp x)
  atan (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc atanOp x)
  sinh (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc sinhOp x)
  cosh (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc coshOp x)
  asinh (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc asinhOp x)
  acosh (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc acoshOp x)
  atanh (BVarNH x) = BVarNH (liftOp1 addFunc zeroFunc atanhOp x)
-- | 'properFraction' on a 'BVarNH'. It records 'properFractionOp' on the
-- tape, passes the result to 'extractInteger', and re-wraps the fractional
-- part. 'Reifies' @s W@ is listed here because the unwrapped 'BVar' is used
-- outside a 'BVarNH' constructor.
--
-- NOTE(review): 'extractInteger' is a fixme placeholder in this module, so
-- this method currently fails at runtime.
instance (UpperBoundedField a, FromInteger a, Reifies s W, QuotientField a, AdditiveInvertible a, MultiplicativeGroup a) =>
  QuotientField (BVarNH s a) where
  properFraction (BVarNH v) =
    case extractInteger (liftOp1 addFunc zeroFunc properFractionOp v) of
      (whole, frac) -> (whole, BVarNH frac)
-- | Entry point for running the file as a stack script (see the shebang).
-- It does nothing.
main :: IO ()
main = return ()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment