Last active
August 29, 2015 14:23
-
-
Save michaelt/eb738a5b6a7524471e61 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-#LANGUAGE TypeOperators, LambdaCase, BangPatterns, RankNTypes #-} | |
-- Needed for the MonadBase instance | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE MultiParamTypeClasses #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
{-#LANGUAGE ScopedTypeVariables #-} | |
module Loop where | |
import Control.Monad.Trans | |
import Control.Monad | |
import Control.Monad.Base | |
import Control.Monad.ST | |
import Data.STRef | |
-- | A Church-encoded loop body. Each step can 'continue' to the next
-- iteration, 'exit' the whole loop, or finish normally with a value.
--
-- Because @res@ is universally quantified, a 'LoopT' computation can only
-- produce its result by calling one of the three continuations.
newtype LoopT c e m a = LoopT
  { runLoopT
      :: forall res.
         (c -> m res)  -- continue: go on to the next iteration
      -> (e -> m res)  -- exit: leave the whole loop
      -> (a -> m res)  -- normal result of this step
      -> m res
  }
-- | Non-Church-encoded counterpart of 'LoopT': an action whose result is
-- @Left c@ when the body continued, @Right (Left e)@ when it exited, and
-- @Right (Right a)@ when it finished normally.
newtype Loop c e m a = Loop
  { runLoop :: m (Either c (Either e a))
  }
-- | Map over the normal result only; continue and exit values pass through.
instance Monad m => Functor (Loop c e m) where
  fmap f (Loop inner) = Loop (liftM (fmap (fmap f)) inner)
-- | 'pure' yields a normal result. '<*>' is derived from the monad so that the
-- two instances cannot disagree.
instance Monad m => Applicative (Loop c e m) where
  pure = Loop . return . Right . Right
  (<*>) = ap

-- | Binding short-circuits on @Left c@ (continue) and @Right (Left e)@ (exit).
-- Only a normal @Right (Right a)@ result is handed to the continuation.
--
-- 'return' is deliberately not defined here: it defaults to 'pure' from the
-- 'Applicative' instance. This is the canonical form that
-- -Wnoncanonical-monad-instances expects.
instance Monad m => Monad (Loop c e m) where
  Loop mee >>= f = Loop $ do
    ee <- mee
    case ee of
      Right (Right a) -> runLoop (f a)
      Right (Left e)  -> return (Right (Left e))
      Left c          -> return (Left c)
-- | A lifted action always finishes normally; it never continues or exits.
instance MonadTrans (Loop c e) where
  lift action = Loop (liftM (Right . Right) action)
-- | Run an IO action in the underlying monad and lift it into 'Loop'.
instance MonadIO m => MonadIO (Loop c e m) where
  liftIO io = lift (liftIO io)
-- | Post-compose the function onto the normal-result continuation.
-- Continue and exit are untouched.
instance Functor (LoopT c e m) where
  fmap f body = LoopT $ \next fin cont ->
    runLoopT body next fin (\a -> cont (f a))
-- | 'pure' calls the normal-result continuation directly. '<*>' runs the
-- function step and then the argument step, threading the same continue and
-- exit continuations through both.
instance Applicative (LoopT c e m) where
  pure x = LoopT (\_ _ k -> k x)
  mf <*> mx = LoopT $ \next fin k ->
    runLoopT mf next fin (\f -> runLoopT mx next fin (k . f))
-- | Sequencing passes the continue and exit continuations through unchanged.
-- Calling 'continueWith' or 'exitWith' anywhere in a chain therefore skips
-- the rest of it.
--
-- 'return' is left to its default ('pure' from the 'Applicative' instance)
-- rather than duplicated here; this is the canonical form GHC expects.
instance Monad (LoopT c e m) where
  m >>= k = LoopT $ \next fin cont ->
    runLoopT m next fin $ \a ->
      runLoopT (k a) next fin cont
-- | Run the underlying action and pass its result to the normal continuation.
instance MonadTrans (LoopT c e) where
  lift action = LoopT (\_ _ cont -> action >>= cont)
-- | Run an IO action in the underlying monad and lift it into 'LoopT'.
instance MonadIO m => MonadIO (LoopT c e m) where
  liftIO io = lift (liftIO io)
-- | Lift a base-monad action into the inner monad, then through 'LoopT'.
instance MonadBase b m => MonadBase b (LoopT c e m) where
  liftBase = lift . liftBase
-- | |
-- | Lift a base-monad action into the inner monad, then through 'Loop'.
instance MonadBase b m => MonadBase b (Loop c e m) where
  liftBase = lift . liftBase
-- | Skip the rest of the loop body and go to the next iteration.
continue :: LoopT () e m a
continue = LoopT $ \next _ _ -> next ()

-- | 'continue' for the non-Church-encoded 'Loop'.
continue_ :: Monad m => Loop () e m a
continue_ = Loop (return (Left ()))
-- | Break out of the loop entirely.
exit :: LoopT c () m a
exit = LoopT $ \_ fin _ -> fin ()

-- | 'exit' for the non-Church-encoded 'Loop'.
exit_ :: Monad m => Loop c () m a
exit_ = Loop (return (Right (Left ())))
-- | Like 'continue', but hand a value to the next iteration.
continueWith :: c -> LoopT c e m a
continueWith c = LoopT $ \onContinue _onExit _onReturn -> onContinue c

-- | 'continueWith' for the non-Church-encoded 'Loop'; the value is wrapped
-- in 'Left'.
continueWith_ :: Monad m => c -> Loop c e m a
continueWith_ = Loop . return . Left
-- | Like 'exit', but return a value from the loop as a whole.
-- See 'count' / 'iterateLoopT' for an example.
exitWith :: e -> LoopT c e m a
exitWith e = LoopT $ \_onContinue onExit _onReturn -> onExit e

-- | 'exitWith' for the non-Church-encoded 'Loop'; the value is wrapped in
-- @Right . Left@.
exitWith_ :: Monad m => e -> Loop c e m a
exitWith_ = Loop . return . Right . Left
------------------------------------------------------------------------ | |
-- Looping constructs | |
-- | Run the loop body once for each list element, in order.
--
-- Inside the body, 'continue' moves on to the next element and 'exit' stops
-- the whole loop. If you need neither, consider 'Control.Monad.forM_'
-- instead.
foreach :: Monad m => [a] -> (a -> LoopT c () m c) -> m ()
foreach list body = foldr step (return ()) list
  where
    -- The rest of the loop runs whether the element's body continued or
    -- returned normally; an exit returns () without running it.
    step x rest = stepLoopT (body x) (const rest)
-- | |
-- | 'foreach' for the non-Church-encoded 'Loop'. The body runs for each list
-- element in order; 'continue_' moves to the next element and 'exit_' stops
-- the whole loop.
foreach_ :: Monad m => [a] -> (a -> Loop c () m c) -> m ()
foreach_ list body = foldr (\x rest -> stepLoop (body x) (const rest)) (return ()) list
-- | Repeat the loop body while the predicate holds. As with a C @while@ loop,
-- the condition is checked before every iteration. 'exit' leaves the loop
-- immediately.
while :: Monad m => m Bool -> LoopT c () m c -> m ()
while cond body = go
  where
    go = cond >>= \ok -> when ok (stepLoopT body (const go))
-- | Like a C @do while@ loop: the body runs first and the condition is tested
-- afterwards.
--
-- 'doWhile' returns the value of the last iteration. This is possible
-- because, unlike 'foreach' and 'while', the body always runs at least once.
-- An 'exitWith' returns its value without re-checking the condition.
doWhile :: Monad m => LoopT a a m a -> m Bool -> m a
doWhile body cond = go
  where
    go = stepLoopT body afterBody
    afterBody a = do
      again <- cond
      if again then go else return a
-- | Run the loop body exactly once. This is a convenient way to add early
-- exit to a block of code.
--
-- Continuing, exiting and finishing normally all just return their value,
-- so 'continue' and 'exit' do the same thing inside 'once'.
once :: Monad m => LoopT a a m a -> m a
once body = runLoopT body finish finish finish
  where
    finish x = return x
-- | Run the loop body again and again. The only way out of 'repeatLoopT' is
-- 'exit' or 'exitWith', whose value becomes the result.
repeatLoopT :: Monad m => LoopT c e m a -> m e
repeatLoopT body = go
  where
    go = runLoopT body again return again
    again _ = go
-- | Run a 'Loop' body again and again; the non-Church-encoded counterpart of
-- 'repeatLoopT'. Only 'exit_' or 'exitWith_' (a @Right (Left e)@ result)
-- stops the loop, and @e@ is returned. Both continuing and finishing
-- normally start another iteration.
repeatLoop :: Monad m => Loop c e m a -> m e
repeatLoop (Loop mee) = loop
  where
    loop = do
      ee <- mee
      case ee of
        Right (Left e) -> return e
        _              -> loop
-- | Example for 'iterateLoopT'. 'iterateLoopT' calls the loop body again and
-- again, passing it the result of the previous iteration each time around;
-- the only way to exit it is 'exit' or 'exitWith'.
--
-- @count n@ prints @0@ .. @n-1@ and then exits with the counter, so it
-- returns @n@ (or @0@ when @n <= 0@).
count :: Int -> IO Int
count n = iterateLoopT 0 step
  where
    step i
      | i >= n    = exitWith i
      | otherwise = lift (print i) >> return (i + 1)
-- | |
-- | Same as 'count', but written with the non-Church-encoded 'Loop' via
-- 'iterateLoop' and 'exitWith_'.
count_ :: Int -> IO Int
count_ n = iterateLoop 0 step
  where
    step i
      | i >= n    = exitWith_ i
      | otherwise = lift (print i) >> return (i + 1)
-- | Call the loop body again and again, starting from @z@. Each iteration
-- receives the value the previous one produced, whether by finishing
-- normally or via 'continueWith'. The only way out is 'exit' or 'exitWith',
-- whose value is returned.
iterateLoopT :: Monad m => c -> (c -> LoopT c e m c) -> m e
iterateLoopT z body = go z
  where
    go acc = stepLoopT (body acc) go
-- | 'iterateLoopT' for the non-Church-encoded 'Loop'. Each iteration is fed
-- the previous iteration's value, starting from @z@, until the body exits
-- with 'exitWith_'.
iterateLoop :: Monad m => c -> (c -> Loop c e m c) -> m e
iterateLoop z body = go z
  where
    go acc = stepLoop (body acc) go
-- | Run one iteration of a loop body. A 'continueWith' value and a normal
-- result are both handed to @next@; 'exitWith' ends the loop with its value.
stepLoopT :: Monad m => LoopT c e m c -> (c -> m e) -> m e
stepLoopT body next = runLoopT body next stop next
  where
    stop e = return e
-- | Run one iteration of a 'Loop' body. A continue (@Left c@) and a normal
-- result (@Right (Right c)@) both go to @f@; an exit (@Right (Left e)@)
-- returns @e@.
stepLoop :: Monad m => Loop c e m c -> (c -> m e) -> m e
stepLoop (Loop mee) f = mee >>= either f (either return f)
------------------------------------------------------------------------ | |
-- Lifting other operations | |
-- | Lift a function like 'Control.Monad.Trans.Reader.local' or
-- 'Control.Exception.mask_' through 'LoopT'.
--
-- The body runs under @f@, but each continuation is replaced by one that
-- merely returns the follow-up action. That action is run only after @f@ has
-- finished, so the rest of the loop executes outside @f@'s scope.
liftLocalLoopT :: Monad m => (forall a. m a -> m a) -> LoopT c e m b -> LoopT c e m b
liftLocalLoopT f cb = LoopT $ \next fin cont ->
  join (f (runLoopT cb (return . next) (return . fin) (return . cont)))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-#LANGUAGE ScopedTypeVariables, BangPatterns #-} | |
module Main where | |
import Loop | |
import Control.Monad.ST | |
import Data.Array.ST | |
import Data.Array.Unboxed | |
import Data.STRef | |
import Control.Monad | |
import Control.Monad.Trans | |
import Control.Monad.Base | |
-- | Count the circular primes in @[2 .. e]@: primes all of whose decimal
-- digit rotations are also prime.
--
-- A rotation of a number @<= e@ may itself be larger than @e@ (e.g. 199
-- rotates to 991). The sieve and the visited bitmap therefore cover every
-- number with as many digits as @e@, and only rotations @<= e@ are added to
-- the total.
countCircularPrimes :: Int -> Int
countCircularPrimes e =
  runST $ do
    bmp <- newArray (1, limit) False :: ST s (STUArray s Int Bool)
    total <- newSTRef 0
    foreach (filter isPrime [2 .. e]) $ \i -> do
      -- If i is marked, we've already visited i and its rotations,
      -- so go on to the next prime.
      whenM (liftBase $ readArray bmp i)
        continue
      let rs = rotateDigits i
      -- Count the unique rotations that are <= e. We may end up marking a
      -- number with fewer digits, but that's okay because:
      --
      --   * We've already visited numbers with fewer digits.
      --
      --   * A circular prime will never contain the digit 0.
      --
      -- Thus, any counts affected by truncation will be discarded anyway.
      uniqueCount <- liftBase $ newSTRef 0
      foreach rs $ \j -> do
        whenM (liftBase $ readArray bmp j)
          exit
        -- Rotations above e are still marked, so they are not revisited.
        liftBase $ writeArray bmp j True
        when (j <= e) $
          liftBase $ modifySTRef' uniqueCount (+1)
      when (all isPrime rs) $ liftBase $
        total += uniqueCount
    readSTRef total
  where
    -- Smallest 10^k - 1 (k >= 1) that is >= e, i.e. the largest number with
    -- as many digits as e. Growth stops before x * 10 + 9 could overflow.
    limit = until (\x -> x >= e || x > maxBound `div` 10) (\x -> x * 10 + 9) 9
    sieve = mkSieve limit
    isPrime p = sieve ! p
-- | |
-- | 'countCircularPrimes' written with the non-Church-encoded 'Loop'
-- combinators. It counts the primes in @[2 .. e]@ all of whose decimal digit
-- rotations are prime.
--
-- A rotation of a number @<= e@ may be larger than @e@ (e.g. 199 -> 991).
-- The sieve and bitmap therefore cover every number with as many digits as
-- @e@, and only rotations @<= e@ are counted.
countCircularPrimes_ :: Int -> Int
countCircularPrimes_ e =
  runST $ do
    bmp <- newArray (1, limit) False :: ST s (STUArray s Int Bool)
    total <- newSTRef 0
    foreach_ (filter isPrime [2 .. e]) $ \i -> do
      -- If i is marked, we've already visited i and its rotations,
      -- so go on to the next prime.
      whenM (liftBase $ readArray bmp i)
        continue_
      let rs = rotateDigits i
      -- Count the unique rotations that are <= e. We may end up marking a
      -- number with fewer digits, but that's okay because:
      --
      --   * We've already visited numbers with fewer digits.
      --
      --   * A circular prime will never contain the digit 0.
      --
      -- Thus, any counts affected by truncation will be discarded anyway.
      uniqueCount <- liftBase $ newSTRef 0
      foreach_ rs $ \j -> do
        whenM (liftBase $ readArray bmp j)
          exit_
        -- Rotations above e are still marked, so they are not revisited.
        liftBase $ writeArray bmp j True
        when (j <= e) $
          liftBase $ modifySTRef' uniqueCount (+1)
      when (all isPrime rs) $ liftBase $
        total += uniqueCount
    readSTRef total
  where
    -- Smallest 10^k - 1 (k >= 1) that is >= e, i.e. the largest number with
    -- as many digits as e. Growth stops before x * 10 + 9 could overflow.
    limit = until (\x -> x >= e || x > maxBound `div` 10) (\x -> x * 10 + 9) 9
    sieve = mkSieve limit
    isPrime p = sieve ! p
-- main :: IO () | |
-- main = print $ countCircularPrimes_ 999999 | |
-- | Nested-loop demo: a bare 'exit' leaves the inner loop, while
-- 'lift continue' and 'lift exit' act on the outer loop.
main :: IO ()
main = do
  foreach [1 .. 10] $ \(row :: Int) -> do
    foreach [1 .. 10] $ \(col :: Int) -> do
      when (col > row) $ lift continue
      when (row == 2 && col == 2) exit
      when (row == 9 && col == 9) $ lift exit
      liftBase $ print (row, col)
    liftBase $ putStrLn "Inner loop finished"
  putStrLn "Outer loop finished"
-- | The nested-loop demo written with the non-Church-encoded 'Loop'
-- combinators. A bare 'exit_' leaves the inner loop, while 'lift continue_'
-- and 'lift exit_' act on the outer loop.
main_ :: IO ()
main_ = do
  foreach_ [1 .. 10] $ \(row :: Int) -> do
    foreach_ [1 .. 10] $ \(col :: Int) -> do
      when (col > row) $ lift continue_
      when (row == 2 && col == 2) exit_
      when (row == 9 && col == 9) $ lift exit_
      liftBase $ print (row, col)
    liftBase $ putStrLn "Inner loop finished"
  putStrLn "Outer loop finished"
------------------------------------------------------------------------ | |
-- Helper functions | |
-- | Sieve of Eratosthenes. @mkSieve e ! n@ is 'True' exactly when @n@ is a
-- prime in @[2 .. e]@. The array is indexed from 2 to @e@, so indexing
-- outside that range is out of bounds.
mkSieve :: Int -> UArray Int Bool
mkSieve e = runSTUArray $ do
  bmp <- newArray (2, e) True
  -- Every composite <= e has a prime factor p with p*p <= e, and every
  -- multiple of p below p*p was already crossed off by a smaller prime.
  -- So only such p need sieving, starting at p*p.
  forM_ (takeWhile (\i -> i * i <= e) [2 ..]) $ \i -> do
    isP <- readArray bmp i
    when isP $
      forM_ [i * i, i * i + i .. e] $ \j ->
        writeArray bmp j False
  return bmp
-- | |
-- | Return every rotation of the decimal digits of a number, starting with the
-- number itself. Each following element moves the last digit to the front.
-- A rotation that brings a 0 to the front loses it, e.g.
-- @rotateDigits 101 == [101, 110, 11]@.
rotateDigits :: Int -> [Int]
rotateDigits start
  | start < 1 = error "rotateDigits: n < 1"
  | otherwise = take digitCount (iterate rotate start)
  where
    -- factor is the power of ten of the leading digit; growth also stops
    -- when multiplying by ten would overflow.
    (factor, digitCount) = until done grow (1, 1)
    grow (p, k) = (p * 10, k + 1 :: Int)
    done (p, _) = p * 10 > start || p * 10 < p
    rotate x = let (rest, digit) = x `divMod` 10
               in digit * factor + rest
-- | Add the current value of the second reference to the first one,
-- strictly. The second reference is left unchanged.
(+=) :: STRef s Int -> STRef s Int -> ST s ()
total += n = do
  amount <- readSTRef n
  modifySTRef' total (amount +)
-- | Run the action only when the monadic condition yields 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM p m = p >>= flip when m
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- This solves Google Code Jam 2012 Qualification Problem C "Recycled Numbers" [1]. | |
-- The problem is: given a range of numbers with the same number of digits, | |
-- count how many pairs of them are the same modulo rotation of digits. | |
-- | |
-- [1]: http://code.google.com/codejam/contest/1460488/dashboard#s=p2 | |
{-# LANGUAGE ScopedTypeVariables #-} | |
module Main where | |
import Loop | |
import Control.Applicative ((<$>)) | |
import Control.Monad | |
import Control.Monad.ST | |
import Control.Monad.Trans.Class | |
import Data.Array.ST | |
import Data.STRef | |
-- | Solve one hard-coded case and print it in Code Jam output format.
main :: IO ()
main = do
  -- t <- readLn
  -- forM_ [1..t] $ \(x :: Int) -> do
  --   [a, b] <- map read . words <$> getLine
  let answer = recycledNumbers (10000000, 20000000)
  putStrLn $ concat ["Case #", show 1, ": ", show answer]
-- | Count the pairs @(n, m)@ with @lb <= n < m <= ub@ where @m@ is a decimal
-- digit rotation of @n@. Both bounds must be positive with the same number
-- of digits; otherwise this calls 'error'.
recycledNumbers :: (Int, Int) -> Int
recycledNumbers (lb, ub)
  | not validBounds = error "recycledNumbers: invalid bounds"
  | otherwise = runST $ do
      seen <- newArray (lb, ub) False :: ST s (STUArray s Int Bool)
      total <- newSTRef 0
      forM_ [lb .. ub] $ \i -> do
        groupSize <- newSTRef 0
        -- Walk i's rotation cycle, skipping rotations outside [i, ub].
        -- Meeting a number already seen means the cycle is complete, or that
        -- it was handled from a smaller start, so stop.
        foreach (iterate rotate i) $ \j -> do
          unless (j >= i && j <= ub)
            continue
          whenM (lift $ readArray seen j)
            exit
          lift $ writeArray seen j True
          lift $ modifySTRef' groupSize (+1)
        readSTRef groupSize >>= modifySTRef' total . (+) . numPairs
      readSTRef total
  where
    validBounds = 1 <= lb && lb <= ub && factor == rotateFactor ub
    factor = rotateFactor lb
    rotate x = let (rest, digit) = x `divMod` 10
               in digit * factor + rest
    -- A group of k mutually-rotated numbers contains k*(k-1)/2 pairs.
    numPairs n = (n - 1) * n `div` 2
-- | |
-- | 'recycledNumbers' written with the non-Church-encoded 'Loop'
-- combinators. It counts the pairs @(n, m)@ with @lb <= n < m <= ub@ where
-- @m@ is a digit rotation of @n@.
recycledNumbers_ :: (Int, Int) -> Int
recycledNumbers_ (lb, ub)
  | not validBounds = error "recycledNumbers: invalid bounds"
  | otherwise = runST $ do
      seen <- newArray (lb, ub) False :: ST s (STUArray s Int Bool)
      total <- newSTRef 0
      forM_ [lb .. ub] $ \i -> do
        groupSize <- newSTRef 0
        -- Walk i's rotation cycle, skipping rotations outside [i, ub], and
        -- stop at the first number already seen.
        foreach_ (iterate rotate i) $ \j -> do
          unless (j >= i && j <= ub)
            continue_
          whenM (lift $ readArray seen j)
            exit_
          lift $ writeArray seen j True
          lift $ modifySTRef' groupSize (+1)
        readSTRef groupSize >>= modifySTRef' total . (+) . numPairs
      readSTRef total
  where
    validBounds = 1 <= lb && lb <= ub && factor == rotateFactor ub
    factor = rotateFactor lb
    rotate x = let (rest, digit) = x `divMod` 10
               in digit * factor + rest
    -- A group of k mutually-rotated numbers contains k*(k-1)/2 pairs.
    numPairs n = (n - 1) * n `div` 2
------------------------------------------------------------------------ | |
-- Helper functions | |
-- | Return the power of 10 of the most significant digit of a positive
-- number, e.g. @rotateFactor 4321 == 1000@.
rotateFactor :: Int -> Int
rotateFactor n
  | n < 1 = error "rotateFactor: n < 1"
  | otherwise = until atTop (* 10) 1
  where
    -- Stop when one more factor of ten would exceed n or overflow Int.
    atTop p = p * 10 > n || p * 10 < p
-- | Strictly add the value held by the second reference into the first.
(+=) :: STRef s Int -> STRef s Int -> ST s ()
total += n = readSTRef n >>= \amount -> modifySTRef' total (amount +)
-- | Run the action only if the monadic condition produces 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM p m = do
  ok <- p
  if ok then m else return ()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Control.Monad.Loop,
with a non-Church-encoded variant.