Last active
November 11, 2023 22:45
-
-
Save sjoerdvisscher/5fe3c3cba928c4b0c112c29860894ed8 to your computer and use it in GitHub Desktop.
Deriving differentiation with linear generics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- https://twitter.com/paf31/status/1362207106703630338 | |
{-# LANGUAGE BlockArguments #-} | |
{-# LANGUAGE DefaultSignatures #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE LinearTypes #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE EmptyCase #-} | |
{-# LANGUAGE TemplateHaskell #-} | |
{-# LANGUAGE DataKinds #-} | |
import Generics.Linear hiding (D) | |
import Generics.Linear.TH | |
import Data.Functor ((<&>)) | |
-- | A one-hole context for an element of type @a@ inside an @f@ container.
--
-- The wrapped function takes the value to put into the hole (consumed
-- exactly once, hence the linear arrow) and a function used for the
-- remaining elements, and rebuilds the container at the new element type.
newtype D f a = D
  { unD :: forall r. r %1 -> (a -> r) -> f r
  }
  deriving Functor
-- | A value paired with its one-hole context. The value is the element that
-- was punched out of the container; 'plugIn' puts it back into the context.
data InContext f a = InContext a (D f a)
  deriving Functor
-- | Change the container functor of an 'InContext' by post-composing its
-- context with a natural transformation. The focused value is unchanged.
-- The transformation must be linear because it receives the rebuilt
-- container, which holds the linearly-consumed hole value.
hmap :: (forall x. f x %1 -> g x) -> InContext f a -> InContext g a
hmap nat (InContext focus ctx) =
  InContext focus (D (\hole rest -> nat (unD ctx hole rest)))
-- | Functors whose elements can each be paired with their one-hole context.
--
-- 'contexts' replaces every element of the container with an 'InContext'
-- holding that element and the context obtained by punching a hole at its
-- position. The default goes through the linear 'Generic1' representation,
-- so an empty @instance Diff T@ works whenever @Diff (Rep1 T)@ holds.
class Functor f => Diff f where
  contexts :: f a -> f (InContext f a)
  default contexts :: (Generic1 f, Diff (Rep1 f)) => f a -> f (InContext f a)
  contexts xs = fmap (hmap to1) (to1 (contexts (from1 xs)))
-- | Functions, viewed as containers indexed by their argument.
--
-- NOTE(review): the context built here ignores both the other arguments and
-- the "other elements" function, so plugging the value back in yields a
-- constant function: @plugIn (contexts g r)@ is @const (g r)@, not @g@.
-- A faithful context would need @Eq r@ and would have to discard the hole
-- value at every other argument. That is presumably not expressible with the
-- linear hole argument of 'D', so this behaviour is kept as-is.
instance Diff ((->) r) where
  contexts g = \arg -> InContext (g arg) (D (\hole _ _ -> hole))
-- | A single element: its context is just the hole, and the function for
-- other elements is never needed.
instance Diff Par1 where
  contexts (Par1 x) = Par1 (InContext x (D (\hole _ignored -> Par1 hole)))
-- | Composition: an element's context is its context inside the inner
-- container, plugged into the hole of that inner container's context within
-- the outer container. All other inner containers are mapped with the
-- "other elements" function.
instance (Diff f, Diff g) => Diff (f :.: g) where
  contexts (Comp1 outer) = Comp1 (fmap expand (contexts outer))
    where
      -- Take one inner container together with its outer context and pair
      -- each of its elements with the combined two-level context.
      expand (InContext inner outerCtx) = fmap (nest outerCtx) (contexts inner)
      -- Combine an outer context with an inner one for a single element.
      nest (D plugOuter) (InContext a (D plugInner)) =
        InContext a (D (\hole rest -> Comp1 (plugOuter (plugInner hole rest) (fmap rest))))
-- | Products: an element in one component keeps its own context there and
-- maps the whole other component with the "other elements" function.
instance (Diff f, Diff g) => Diff (f :*: g) where
  contexts (left :*: right) =
    fmap inLeft (contexts left) :*: fmap inRight (contexts right)
    where
      -- Element from the left component: the right one is mapped as a whole.
      inLeft (InContext a (D plugLeft)) =
        InContext a (D (\hole rest -> plugLeft hole rest :*: fmap rest right))
      -- Element from the right component: the left one is mapped as a whole.
      inRight (InContext a (D plugRight)) =
        InContext a (D (\hole rest -> fmap rest left :*: plugRight hole rest))
-- | Sums: an element's context stays inside the injection it came from, so
-- each context is re-tagged with that same constructor.
instance (Diff f, Diff g) => Diff (f :+: g) where
  contexts = \case
    L1 l -> L1 (hmap L1 <$> contexts l)
    R1 r -> R1 (hmap R1 <$> contexts r)
-- | Metadata wrappers are transparent: recurse, then re-wrap both the
-- container and every context in 'M1'.
instance Diff f => Diff (M1 i c f) where
  contexts (M1 inner) = M1 (hmap M1 <$> contexts inner)
-- | Constant fields contain no elements of the parameter type, so the
-- stored constant is passed through unchanged.
instance Diff (K1 i c) where
  contexts (K1 k) = K1 k
-- | The empty representation has no values, so there is nothing to match.
instance Diff V1 where
  contexts v = case v of {}
-- | Nullary constructors contain no elements.
instance Diff U1 where
  contexts u = case u of
    U1 -> U1
-- | Sample type combining a bare element, a constant field and a list of
-- elements, used to exercise the generic default of 'contexts'.
data Example a = Example a Bool [a]
  deriving (Show, Functor)

deriveGeneric1 ''Example
-- | Lists rely on the 'Generic1'-based default for 'contexts'.
instance Diff []

-- | 'Example' also relies on the 'Generic1'-based default for 'contexts'.
instance Diff Example
-- | Put the stored value back into its own hole, passing 'id' as the
-- function for all other elements.
--
-- ghci> plugIn <$> contexts (Example 1 True [2, 3])
-- Example (Example 1 True [2,3]) True [Example 1 True [2,3],Example 1 True [2,3]]
plugIn :: InContext f a -> f a
plugIn (InContext focus ctx) = case ctx of
  D fill -> fill focus id
-- | First-order expansion of a container in an infinitesimal @d@, following
-- <http://blog.sigfpe.com/2006/09/infinitesimal-types.html>:
--
-- > F[x + d] = F[x] + d F'[x]
--
-- If no element is a 'Left', the container of plain values is returned in
-- 'Right'. Otherwise the first 'Left' in traversal order is returned with
-- its one-hole context; any other 'Left' elements of that context are
-- mapped through the supplied witness that @d^2 = 0@.
infinitesimal
  :: (Diff f, Traversable f)
  => (forall void. d -> d -> void) -- ^ witness that d^2 = 0
  -> f (Either d a)
  -> Either (d, D f a) (f a)
infinitesimal d2void = traverse classify . contexts
  where
    -- Plain values pass through; the first infinitesimal short-circuits the
    -- traversal, carrying its context.
    classify (InContext e (D fill)) = case e of
      Right a -> Right a
      Left d -> Left (d, D (\hole rest -> fill hole (either (d2void d) rest)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment