Created
April 18, 2017 17:14
-
-
Save cartazio/042c776335141507bd3a2051e3840e59 to your computer and use it in GitHub Desktop.
how mwc samples
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE BangPatterns, CPP, GADTs, FlexibleContexts, ScopedTypeVariables #-} | |
-- | | |
-- Module : System.Random.MWC.Distributions | |
-- Copyright : (c) 2012 Bryan O'Sullivan | |
-- License : BSD3 | |
-- | |
-- Maintainer : [email protected] | |
-- Stability : experimental | |
-- Portability : portable | |
-- | |
-- Pseudo-random number generation for non-uniform distributions. | |
module System.Random.MWC.Distributions | |
( | |
-- * Variates: non-uniformly distributed values | |
-- ** Continuous distributions | |
normal | |
, standard | |
, exponential | |
, truncatedExp | |
, gamma | |
, chiSquare | |
, beta | |
-- ** Discrete distribution | |
, categorical | |
, logCategorical | |
, geometric0 | |
, geometric1 | |
, bernoulli | |
-- ** Multivariate | |
, dirichlet | |
-- * Permutations | |
, uniformPermutation | |
, uniformShuffle | |
, uniformShuffleM | |
-- * References | |
-- $references | |
) where | |
import Prelude hiding (mapM) | |
import Control.Monad (liftM) | |
import Control.Monad.Primitive (PrimMonad, PrimState) | |
import Data.Bits ((.&.)) | |
import Data.Foldable (foldl') | |
#if !MIN_VERSION_base(4,8,0) | |
import Data.Traversable (Traversable) | |
#endif | |
import Data.Traversable (mapM) | |
import Data.Word (Word32) | |
import System.Random.MWC (Gen, uniform, uniformR) | |
import qualified Data.Vector.Unboxed as I | |
import qualified Data.Vector.Generic as G | |
import qualified Data.Vector.Generic.Mutable as M | |
-- | Unboxed 2-tuple: a strict pair of unpacked Doubles.
--
-- Used as the unfold state (edge, density at edge) when building
-- 'blocks', and as the (x, v) result of the inner loop of 'gamma'.
data T = T {-# UNPACK #-} !Double {-# UNPACK #-} !Double
-- | Generate a normally distributed random variate with the given mean
-- and standard deviation. A 'standard' variate @z@ is drawn and
-- mapped to @mean + stddev * z@.
normal :: PrimMonad m
       => Double               -- ^ Mean
       -> Double               -- ^ Standard deviation
       -> Gen (PrimState m)
       -> m Double
{-# INLINE normal #-}
normal m s gen = standard gen >>= \z -> return $! m + s * z
-- | Generate a normally distributed random variate with zero mean and
-- unit variance.
--
-- The implementation uses Doornik's modified ziggurat algorithm.
-- Compared to the ziggurat algorithm usually used, this is slower,
-- but generates more independent variates that pass stringent tests
-- of randomness.
standard :: PrimMonad m => Gen (PrimState m) -> m Double
{-# INLINE standard #-}
standard gen = loop
  where
    loop = do
      -- Uniform draw mapped by x -> 2*x - 1 to a signed value.
      u  <- (subtract 1 . (*2)) `liftM` uniform gen
      -- Separate draw; its low 7 bits select one of the 128 layers.
      ri <- uniform gen
      let i  = fromIntegral ((ri :: Word32) .&. 127)
          -- Edge of layer i (bi) and of the next layer up (bj).
          bi = I.unsafeIndex blocks i
          bj = I.unsafeIndex blocks (i+1)
      case () of
        -- Fast path: |u * bi| < b(i+1), i.e. the point lies inside the
        -- rectangle of layer i, so it is accepted immediately.
        _| abs u < I.unsafeIndex ratios i -> return $! u * bi
        -- Layer 0 failing the fast test: sample from the tail beyond
        -- rNorm, on the side given by the sign of u.
         | i == 0 -> normalTail (u < 0)
        -- Wedge region: accept x by comparing a uniform draw against the
        -- density interpolated between the two layer edges; else retry.
         | otherwise -> do
             let x  = u * bi
                 xx = x * x
                 d  = exp (-0.5 * (bi * bi - xx))
                 e  = exp (-0.5 * (bj * bj - xx))
             c <- uniform gen
             if e + c * (d - e) < 1
               then return x
               else loop
    -- Tail sampler (classic Marsaglia tail method): x = log u1 / rNorm
    -- (non-positive), retried until -2 * log u2 >= x^2; the result is then
    -- pushed beyond rNorm on the requested side.
    normalTail neg = tailing
      where tailing = do
              x <- ((/rNorm) . log) `liftM` uniform gen
              y <- log `liftM` uniform gen
              if y * (-2) < x * x
                then tailing
                else return $! if neg then x - rNorm else rNorm - x
-- Constants used by standard/normal. They are floated to the top
-- level to avoid a performance regression (Bug #16) in which
-- blocks/ratios were recalculated on each call to standard/normal.
-- That regression was also somewhat difficult to trigger reliably.

-- | Layer edges of the ziggurat used by 'standard'. There are 129
-- entries: @v/f@, 'rNorm', 126 computed edges, and a final 0.
blocks :: I.Vector Double
blocks = (`I.snoc` 0) . I.cons (v/f) . I.cons rNorm . I.unfoldrN 126 go $! T rNorm f
  where
    -- From edge b with density value g = exp(-b^2/2), the next edge h
    -- satisfies exp(-h^2/2) = v/b + g. That makes the rectangle between
    -- consecutive edges have area v.
    go (T b g) = let !u = T h (exp (-0.5 * h * h))
                     h  = sqrt (-2 * log (v / b + g))
                 in Just (h, u)
    -- Common layer area (constant from the ziggurat construction).
    v = 9.91256303526217e-3
    -- Unnormalised density at rNorm. The first entry v/f is presumably
    -- the base strip (rectangle plus tail) expressed as an equivalent
    -- width; confirm against Doornik (2005).
    f = exp (-0.5 * rNorm * rNorm)
{-# NOINLINE blocks #-}

-- | Start of the right tail of the ziggurat. It is the second entry of
-- 'blocks' and the offset used by the tail sampler in 'standard'.
rNorm :: Double
rNorm = 3.442619855899

-- | @ratios ! i == blocks ! (i+1) / blocks ! i@ (128 entries). It is
-- the bound 'standard' compares @abs u@ against for its fast-path
-- acceptance.
ratios :: I.Vector Double
ratios = I.zipWith (/) (I.tail blocks) blocks
{-# NOINLINE ratios #-}
-- | Generate an exponentially distributed random variate.
--
-- The parameter is the /rate/ (the inverse of the scale). The sample
-- is @-log u / rate@ for a uniform draw @u@ from the unit interval, so
-- its mean is @1 / rate@. The previous documentation called it a
-- scale parameter, which is the reciprocal quantity.
exponential :: PrimMonad m
            => Double            -- ^ Rate parameter (inverse of the scale)
            -> Gen (PrimState m) -- ^ Generator
            -> m Double
{-# INLINE exponential #-}
exponential b gen = do
  x <- uniform gen
  return $! - log x / b
-- | Generate a variate from an exponential distribution truncated to
-- the range @(a,b)@, by inverting the truncated CDF.
--
-- As in 'exponential', the first argument is a /rate/ (inverse of the
-- scale), not a scale: the untruncated density is proportional to
-- @exp (-rate * x)@.
truncatedExp :: PrimMonad m
             => Double            -- ^ Rate parameter (inverse of the scale)
             -> (Double,Double)   -- ^ Range to which distribution is
                                  --   truncated. Values may be negative.
             -> Gen (PrimState m) -- ^ Generator.
             -> m Double
{-# INLINE truncatedExp #-}
truncatedExp scale (a,b) gen = do
  -- Sample an exponential truncated to [0, b-a] by inverse CDF, which
  -- is simpler, then shift the result by a.
  let delta = b - a
  p <- uniform gen
  return $! a - log ( (1 - p) + p*exp(-scale*delta)) / scale
-- | Random variate generator for the gamma distribution, using the
-- Marsaglia-Tsang method (with the @U^(1/a)@ boost for shape < 1).
--
-- Raises an error unless the shape parameter is strictly positive.
-- A NaN shape is rejected too, rather than silently producing NaN
-- samples.
gamma :: PrimMonad m
      => Double                 -- ^ Shape parameter (must be > 0)
      -> Double                 -- ^ Scale parameter
      -> Gen (PrimState m)      -- ^ Generator
      -> m Double
{-# INLINE gamma #-}
gamma a b gen
  -- 'not (a > 0)' instead of 'a <= 0' so that NaN is rejected as well.
  | not (a > 0) = pkgError "gamma" "alpha parameter must be positive"
  | otherwise   = mainloop
  where
    mainloop = do
      T x v <- innerloop
      u     <- uniform gen
      -- Cheap squeeze test first; the logarithmic test is only evaluated
      -- when the squeeze fails, thanks to the lazy '&&'.
      let cont =  u > 1 - 0.331 * sqr (sqr x)
               && log u > 0.5 * sqr x + a1 * (1 - v + log v)
      case () of
        _ | cont      -> mainloop
          | a >= 1    -> return $! a1 * v * b
          -- Shape < 1: a sample for shape a+1 is scaled by y^(1/a).
          | otherwise -> do y <- uniform gen
                            return $! y ** (1 / a) * a1 * v * b
    -- Draw a standard normal x such that v = (1 + a2*x)^3 is positive.
    innerloop = do
      x <- standard gen
      case 1 + a2*x of
        v | v <= 0    -> innerloop
          | otherwise -> return $! T x (v*v*v)
    -- Constants; for a < 1 the method runs with shape a + 1.
    a' = if a < 1 then a + 1 else a
    a1 = a' - 1/3
    a2 = 1 / sqrt(9 * a1)
-- | Random variate generator for the chi-square distribution with the
-- given number of degrees of freedom. It is sampled as twice a gamma
-- variate with shape @n/2@ and unit scale.
chiSquare :: PrimMonad m
          => Int                -- ^ Number of degrees of freedom
          -> Gen (PrimState m)  -- ^ Generator
          -> m Double
{-# INLINE chiSquare #-}
chiSquare n gen
  | n > 0     = gamma (0.5 * fromIntegral n) 1 gen >>= \g -> return $! 2 * g
  | otherwise = pkgError "chiSquare" "number of degrees of freedom must be positive"
-- | Random variate generator for the geometric distribution, computing
-- the number of failures before the first success. The distribution's
-- support is [0..].
--
-- Raises an error unless the success probability lies in (0,1].
geometric0 :: PrimMonad m
           => Double               -- ^ /p/ success probability lies in (0,1]
           -> Gen (PrimState m)    -- ^ Generator
           -> m Int
{-# INLINE geometric0 #-}
geometric0 p gen
  | p == 1         = return 0
  | p > 0 && p < 1 = do q <- uniform gen
                        return $! floor $ log q / log1mp
  | otherwise      = pkgError "geometric0" "probability out of (0,1] range"
  where
    -- log (1 - p) computed accurately (Goldberg's log1p trick), standing
    -- in for 'log1p' without adding a dependency on math-functions.
    -- A naive 'log (1 - p)' rounds to exactly 0 when p is below machine
    -- epsilon. That made the quotient above -Infinity or NaN.
    log1mp
      | w == 1    = negate p
      | otherwise = log w * (p / (1 - w))
      where w = 1 - p
-- | Random variate generator for the geometric distribution counting
-- the number of trials up to and including the first success. Its
-- support is [1..]: this is 'geometric0' shifted up by one.
geometric1 :: PrimMonad m
           => Double               -- ^ /p/ success probability lies in (0,1]
           -> Gen (PrimState m)    -- ^ Generator
           -> m Int
{-# INLINE geometric1 #-}
geometric1 p gen = geometric0 p gen >>= \failures -> return $! failures + 1
-- | Random variate generator for the Beta distribution. Two unit-scale
-- gamma variates X (shape alpha) and Y (shape beta) are drawn, in that
-- order, and @X / (X + Y)@ is returned.
beta :: PrimMonad m
     => Double            -- ^ alpha (>0)
     -> Double            -- ^ beta (>0)
     -> Gen (PrimState m) -- ^ Generator
     -> m Double
{-# INLINE beta #-}
beta a b gen = do
  ga <- gamma a 1 gen
  gb <- gamma b 1 gen
  let s = ga + gb
  return $! ga / s
-- | Random variate generator for the Dirichlet distribution. One
-- unit-scale gamma variate is drawn per parameter, in traversal order.
-- The draws are then divided by their sum.
dirichlet :: (PrimMonad m, Traversable t)
          => t Double          -- ^ container of parameters
          -> Gen (PrimState m) -- ^ Generator
          -> m (t Double)
{-# INLINE dirichlet #-}
dirichlet t gen = normalise `liftM` mapM (\alpha -> gamma alpha 1 gen) t
  where
    normalise draws = let s = foldl' (+) 0 draws
                      in fmap (/ s) draws
-- | Random variate generator for the Bernoulli distribution. It yields
-- 'True' when a uniform draw falls below the success probability.
bernoulli :: PrimMonad m
          => Double            -- ^ Probability of success (returning True)
          -> Gen (PrimState m) -- ^ Generator
          -> m Bool
{-# INLINE bernoulli #-}
bernoulli p gen = do
  u <- uniform gen
  return (u < p)
-- | Random variate generator for the categorical distribution. It
-- returns an index into the weight vector.
--
-- If you need to generate many variates, the functions in
-- "System.Random.MWC.CondensedTable" offer better performance. If only
-- a few are needed, this function is faster, since it avoids the cost
-- of setting up a table.
categorical :: (PrimMonad m, G.Vector v Double)
            => v Double          -- ^ List of weights [>0]
            -> Gen (PrimState m) -- ^ Generator
            -> m Int
{-# INLINE categorical #-}
categorical v gen
  | G.null v  = pkgError "categorical" "empty weights!"
  | otherwise = do
      -- Cumulative weights; pick the first index whose running total
      -- reaches a uniformly scaled target.
      let cumulative = G.scanl1' (+) v
          total      = G.last cumulative
      u <- uniform gen
      let target = total * u
      return $! maybe (pkgError "categorical" "bad weights!") id
                      (G.findIndex (>= target) cumulative)
-- | Random variate generator for the categorical distribution when the
-- weights are given as logarithms. The log-weights are shifted by their
-- maximum before exponentiating, then passed to 'categorical'.
logCategorical :: (PrimMonad m, G.Vector v Double)
               => v Double          -- ^ List of logarithms of weights
               -> Gen (PrimState m) -- ^ Generator
               -> m Int
{-# INLINE logCategorical #-}
logCategorical v gen
  | G.null v  = pkgError "logCategorical" "empty weights!"
  | otherwise = categorical weights gen
  where
    peak    = G.maximum v
    weights = G.map (\lw -> exp (lw - peak)) v
-- | Random variate generator for uniformly distributed permutations of
-- the vector @[0 .. n-1]@. The identity vector is shuffled with
-- 'uniformShuffle' (Fisher-Yates).
uniformPermutation :: forall m v. (PrimMonad m, G.Vector v Int)
                   => Int
                   -> Gen (PrimState m)
                   -> m (v Int)
{-# INLINE uniformPermutation #-}
uniformPermutation n gen
  | n >= 0    = uniformShuffle identity gen
  | otherwise = pkgError "uniformPermutation" "size must be >=0"
  where
    identity :: v Int
    identity = G.generate n id
-- | Random variate generator for a uniformly distributed shuffle of a
-- vector (every ordering is equally likely). The input is copied into
-- a mutable vector, shuffled in place with 'uniformShuffleM', and
-- frozen. Vectors of length 0 or 1 are returned unchanged.
uniformShuffle :: (PrimMonad m, G.Vector v a)
               => v a
               -> Gen (PrimState m)
               -> m (v a)
{-# INLINE uniformShuffle #-}
uniformShuffle vec gen
  | G.length vec > 1 = do
      scratch <- G.thaw vec
      uniformShuffleM scratch gen
      G.unsafeFreeze scratch
  | otherwise = return vec
-- | In-place, uniformly distributed shuffle of a mutable vector (all
-- shuffles are equiprobable), using the Fisher-Yates algorithm.
uniformShuffleM :: (PrimMonad m, M.MVector v a)
                => v (PrimState m) a
                -> Gen (PrimState m)
                -> m ()
{-# INLINE uniformShuffleM #-}
uniformShuffleM vec gen
  | size <= 1 = return ()
  | otherwise = mapM_ swapWithLater [0 .. size - 2]
  where
    size = M.length vec
    -- Swap position i with a uniformly chosen position in [i, size-1].
    swapWithLater i = uniformR (i, size - 1) gen >>= M.unsafeSwap vec i
-- | Square of a Double. @x ^ 2@ reduces to @x * x@.
sqr :: Double -> Double
sqr v = v ^ (2 :: Int)
{-# INLINE sqr #-}
-- | Raise an 'error' whose message is prefixed with this module's name
-- and the name of the offending function.
pkgError :: String -> String -> a
pkgError func msg =
  error (concat ["System.Random.MWC.Distributions.", func, ": ", msg])
-- $references | |
-- | |
-- * Doornik, J.A. (2005) An improved ziggurat method to generate | |
-- normal random samples. Mimeo, Nuffield College, University of | |
-- Oxford. <http://www.doornik.com/research/ziggurat.pdf> | |
-- | |
-- * Thomas, D.B.; Leong, P.G.W.; Luk, W.; Villasenor, J.D. | |
-- (2007). Gaussian random number generators. | |
-- /ACM Computing Surveys/ 39(4). | |
-- <http://www.cse.cuhk.edu.hk/~phwl/mt/public/archives/papers/grng_acmcs07.pdf> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment