Last active
January 31, 2021 08:44
-
-
Save googleson78/1aa592662180b64b1e798030b4cf43ca to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE RankNTypes #-}

import Control.Monad (replicateM_, when)
import Data.Foldable (for_)
import Data.Functor (void)
import Data.Traversable (for)

import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.Trans.State.Strict (StateT, execStateT, get, gets, put, state)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as M
import System.Random (Random, RandomGen, StdGen, newStdGen, randomR)
-- | Shuffles a vector with the Fisher–Yates (Durstenfeld) algorithm.
--
-- Pure (!): 'V.modify' makes a mutable copy, modifies it in place, freezes it
-- and returns the result, so the input vector is left unchanged and the same
-- generator always yields the same permutation.
shuffle :: StdGen -> V.Vector a -> V.Vector a
shuffle gen vector = V.modify (loop gen (V.length vector)) vector
  where
    -- Invariant: the first @n@ elements remain to be shuffled.
    -- 'V.modify' runs in the ST monad, so we can't use randomIO; instead the
    -- generator is threaded through a StateT layer by 'useGen'.
    -- 'for_' (not 'for') is used because the per-step results are all ()
    -- and would only be collected into a list that is thrown away.
    loop :: PrimMonad m => StdGen -> Int -> M.MVector (PrimState m) a -> m ()
    loop g n v = useGen g $ for_ [n, n - 1 .. 2] $ \i -> do
      -- Pick j from [0, i - 1] and swap that element into the last
      -- still-unshuffled slot, i - 1.
      j <- rand () i
      M.swap v (i - 1) j
-- | Runs the Fisher–Yates shuffle @k@ times on the vector @[1 .. n]@, taking a
-- fresh generator from 'newStdGen' for each run and printing every result.
--
-- For @n > 20@ there are more permutations (21! > 2^64) than 64 bits of
-- generator precision can distinguish, so a warning is printed first.
testShuffle :: Int -> Int -> IO ()
testShuffle k n = do
  when (n > 20) $
    putStrLn "Warning: the 64-bit precision of StdGen is not enough to generate all permutations!"
  replicateM_ k $
    newStdGen >>= \g -> print (shuffle g ordered)
  where
    -- Elements are Integer, which is what the literal defaults to anyway.
    ordered = V.enumFromN (1 :: Integer) n
-- | Prints the current contents of a mutable IO vector.
--
-- 'M.IOVector' cannot be an instance of 'Show', since reading it (including
-- freezing it) happens in the IO monad; so take a frozen snapshot and print
-- that instead.
print' :: Show a => M.IOVector a -> IO ()
print' mvec = V.freeze mvec >>= print
-- | Cycles through all permutations of @[1 .. n]@, printing each one.
--
-- Starts from the sorted vector and repeatedly advances it in place with
-- 'M.nextPermutation', stopping once that reports there is no next one.
testPerm :: Int -> IO ()
testPerm n = do
  perm <- V.thaw (V.enumFromTo 1 n)
  let go = do
        print' perm
        more <- M.nextPermutation perm
        when more go
  go
-- | Demo entry point: a few random shuffles, then an exhaustive listing of the
-- permutations of a small vector, each preceded by a heading line.
main :: IO ()
main = do
  banner "Testing shuffle (k=5,n=10):" (testShuffle 5 10)
  banner "Testing permutations (n=3):" (testPerm 3)
  where
    -- Print a heading, then run the corresponding test.
    banner title action = putStrLn title >> action
-- | Draws a random value in @[0, m - 1]@ via 'randomR' from the generator held
-- in the 'StateT' state, storing the advanced generator back.
--
-- The leading @()@ argument carries no information; it only lets call sites
-- read like C's @rand()@. NOTE(review): assumes @m >= 1@ (the shuffle only
-- calls it with @m >= 2@); behaviour for smaller @m@ is whatever 'randomR'
-- does with an inverted range — confirm if other callers appear.
rand :: (Monad m, RandomGen g, Random a, Num a) => () -> a -> StateT g m a
rand () m = state (randomR (0, m - 1))
-- | Runs a 'StateT' action starting from the initial state @gen@, keeping only
-- the effects in the underlying monad: both the action's result and the final
-- state are discarded.
useGen :: Monad m => s -> StateT s m a -> m ()
useGen gen action = void (execStateT action gen)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment