Created
November 2, 2011 10:10
-
-
Save bjin/1333321 to your computer and use it in GitHub Desktop.
Demo: modular arithmetic with the modulus parameterized in the type system.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE RankNTypes, MultiParamTypeClasses, FlexibleInstances, FlexibleContexts, ScopedTypeVariables #-} | |
module ModP | |
( modP | |
) where | |
import Data.Ratio | |
-- | A value of type @b@ tagged with a phantom type @a@ that carries no
-- runtime data. It lets a class method ('p2num') mention the type-level
-- number @a@ in its type while only returning a plain @b@.
newtype Dep a b = Dep
  { unDep :: b
  }
-- | Type-level binary numerals.
--
-- 'One' is the leading 1 bit. @D0 p@ and @D1 p@ append a 0 or 1 bit as the
-- new least significant digit, so @D0 (D1 One)@ encodes binary 110, i.e. 6.
-- Values of these types are only used as proxies; their fields are never
-- inspected.
data One = One

data D0 a = D0 a

data D1 a = D1 a
-- | Positive integers encoded at the type level by @p@, whose value can be
-- materialised in any 'Integral' carrier type @b@.
class Integral b => PositiveN p b where
  -- | The number encoded by @p@, wrapped in 'Dep' so that the method's
  -- type mentions @p@ and instance selection is unambiguous.
  p2num :: Dep p b
-- | 'One' encodes the number 1.
instance Integral b => PositiveN One b where
  p2num = Dep 1

-- | @D0 p@ encodes @2 * p@: appending a 0 bit doubles the value.
instance PositiveN p b => PositiveN (D0 p) b where
  p2num = Dep (2 * prefix)
    where
      prefix = unDep (p2num :: Dep p b)

-- | @D1 p@ encodes @2 * p + 1@: appending a 1 bit doubles and adds one.
instance PositiveN p b => PositiveN (D1 p) b where
  p2num = Dep (1 + 2 * prefix)
    where
      prefix = unDep (p2num :: Dep p b)
-- | A residue modulo the positive integer encoded by the phantom type @p@,
-- stored in the carrier type @b@.
--
-- The @PositiveN p b@ requirement lives on the instances rather than on the
-- type itself. GHC only accepts a context on a @newtype@/@data@ declaration
-- behind the deprecated @DatatypeContexts@ extension, which this module does
-- not enable. Such a context never provided the constraint to users anyway.
newtype ModP p b = ModP { unModP :: b } deriving Eq
-- | Renders a residue @r@ modulo @n@ as the coset notation @r+nZ@.
--
-- @Show b@ is required explicitly. Since GHC 7.4, 'Num' (and hence
-- 'Integral') no longer has 'Show' as a superclass, so @PositiveN p b@
-- alone does not entail @show r@.
instance (PositiveN p b, Show b) => Show (ModP p b) where
  show (ModP r) = show r ++ "+" ++ show (unDep (p2num :: Dep p b)) ++ "Z"
-- | Ring operations on residues. Every result is reduced into @[0, n)@, where
-- @n@ is the modulus encoded by @p@.
--
-- Intermediate results are computed in 'Integer' (see 'liftModP2'), so a
-- bounded carrier such as 'Int64' cannot overflow before the reduction.
-- 'fromInteger' likewise reduces its argument /before/ narrowing it into @b@.
-- Narrowing first would wrap large integers and yield the wrong residue.
-- 'negate' uses the class default (@0 - x@).
instance PositiveN p b => Num (ModP p b) where
  (+) = liftModP2 (+)
  (-) = liftModP2 (-)
  (*) = liftModP2 (*)
  fromInteger k = ModP (fromInteger (k `mod` toInteger n))
    where
      n = unDep (p2num :: Dep p b)
  -- Residues are always non-negative, so 'abs' is the identity. 'signum' maps
  -- the zero residue to 0 and everything else to 1, which satisfies the 'Num'
  -- law @abs x * signum x == x@.
  abs = id
  signum (ModP r)
    | r == 0 = 0
    | otherwise = 1

-- | Combine two residues with an 'Integer' operation, then reduce the result
-- modulo the type-level modulus and narrow it back into the carrier @b@.
liftModP2
  :: forall p b. PositiveN p b
  => (Integer -> Integer -> Integer) -> ModP p b -> ModP p b -> ModP p b
liftModP2 op (ModP x) (ModP y) =
  ModP (fromInteger ((toInteger x `op` toInteger y) `mod` toInteger n))
  where
    n = unDep (p2num :: Dep p b)
-- | Extended Euclidean algorithm.
--
-- @extgcd a b@ returns @(g, x, y)@ with @a * x + b * y == g@. Negative inputs
-- are reduced to the non-negative case by flipping the sign of the matching
-- coefficient. @g@ is therefore the non-negative gcd, provided negating the
-- inputs does not overflow a bounded type.
extgcd :: Integral a => a -> a -> (a, a, a)
extgcd a b
  | a < 0 = let (g, s, t) = extgcd (negate a) b in (g, negate s, t)
  | b < 0 = let (g, s, t) = extgcd a (negate b) in (g, s, negate t)
  | b == 0 = (a, 1, 0)
  | otherwise =
      let (q, r) = a `divMod` b
          -- b * s' + r * t' == g, and r == a - q * b
          (g, s', t') = extgcd b r
      in (g, t', s' - q * t')
-- | Division in the residue ring. 'recip' finds the multiplicative inverse
-- with the extended Euclidean algorithm. It exists only when the residue is
-- coprime to the modulus; otherwise 'recip' calls 'error'.
instance PositiveN p b => Fractional (ModP p b) where
  recip (ModP r) =
    case extgcd r modulus of
      -- r * coeff + modulus * _ == 1, so coeff is r's inverse mod modulus
      (1, coeff, _) -> ModP (coeff `mod` modulus)
      _ -> error "ModP: division invalid"
    where
      modulus = unDep (p2num :: Dep p b)
  -- A rational literal n/d is interpreted as n * d^-1 in the ring.
  fromRational q = fromInteger (numerator q) / fromInteger (denominator q)
-- | Thread a modulus-polymorphic function through the binary digits of a
-- positive integer @n@, least significant digit first. The returned function
-- expects the leading 'One'. For example, @num2p' 6 f@ is @f . D0 . D1@, so
-- applying it to 'One' runs @f@ at type @D0 (D1 One)@, which encodes 6.
-- Calls 'error' for @n <= 0@ (callers are expected to reject that first).
num2p' :: (Integral b, Integral i) => i -> (forall p. PositiveN p b => p -> b) -> (forall p. PositiveN p b => p -> b)
num2p' n k
  | n <= 0 = error "num2p: internal error"
  | n == 1 = k
  | even n = num2p' half (k . D0)
  | otherwise = num2p' half (k . D1)
  where
    half = n `div` 2
-- | Run a modulus-polymorphic computation at the type-level encoding of the
-- positive integer @n@.
num2p :: (Integral b, Integral i) => i -> (forall p. PositiveN p b => p -> b) -> b
num2p n k = num2p' n k One
-- | Evaluate a 'Fractional'-polymorphic expression in the ring of integers
-- modulo @n@ and return its residue in @[0, n)@.
--
-- For example, @modP (1 / 3) 7 == 5@, because @3 * 5 == 15 == 1 (mod 7)@.
--
-- Calls 'error' when @n <= 0@. Dividing by a residue that is not coprime to
-- @n@ also ends in 'error', raised by the 'Fractional' instance.
modP :: Integral b => (forall a. Fractional a => a) -> b -> b
modP val n
  | n > 0 = num2p n reduce
  | otherwise = error "modP: modulus must be positive"
  where
    -- Instantiate the polymorphic expression at the residue type for the
    -- type-level modulus q; the proxy argument itself is ignored.
    reduce :: forall q r. PositiveN q r => q -> r
    reduce _ = unModP (val :: ModP q r)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE ScopedTypeVariables, RankNTypes #-} | |
import ModP | |
import Test.QuickCheck | |
import Test.Framework | |
import Test.Framework.Providers.QuickCheck2 | |
import Data.Int (Int64) | |
-- | Run every test group with test-framework's command-line driver.
main :: IO ()
main = defaultMain tests
-- | QuickCheck settings for the suite: at most 10000 generated cases per
-- property and at most 100000 discarded (unsuitable) cases.
--
-- Built by updating 'mempty' instead of writing out a full record literal.
-- In a record construction, any field left out is undefined; GHC only warns
-- (-Wmissing-fields), and reading the field fails at runtime. Later
-- test-framework versions have more fields than the original literal named
-- (presumably e.g. @topt_maximum_test_size@; verify against the installed
-- version). This assumes the 'TestOptions' 'Monoid' instance's 'mempty'
-- sets every option to 'Nothing', matching the original's explicit
-- 'Nothing's for the seed and timeout.
testOptions :: TestOptions
testOptions =
  mempty
    { topt_maximum_generated_tests = Just 10000
    , topt_maximum_unsuitable_generated_tests = Just 100000
    }
-- | All test groups. The "basics" group runs with 'testOptions' applied.
tests :: [Test]
tests = [plusTestOptions testOptions (testGroup "basics" basics)]
  where
    basics =
      [ testProperty "add" prop_add
      , testProperty "sub" prop_sub
      , testProperty "mul" prop_mul
      , testProperty "pow2" prop_pow2
      , testProperty "inv" prop_inv
      , testProperty "int64" prop_int64
      ]
-- | Addition in the residue ring agrees with reducing the integer sum.
prop_add :: Integer -> Integer -> Positive Integer -> Bool
prop_add x y (Positive n) = inRing == expected
  where
    inRing = modP (fromInteger x + fromInteger y) n
    expected = (x + y) `mod` n
-- | Subtraction in the residue ring agrees with reducing the integer
-- difference, including when the difference is negative.
prop_sub :: Integer -> Integer -> Positive Integer -> Bool
prop_sub x y (Positive n) = inRing == expected
  where
    inRing = modP (fromInteger x - fromInteger y) n
    expected = (x - y) `mod` n
-- | Multiplication in the residue ring agrees with reducing the integer
-- product.
prop_mul :: Integer -> Integer -> Positive Integer -> Bool
prop_mul x y (Positive n) = inRing == expected
  where
    inRing = modP (fromInteger x * fromInteger y) n
    expected = (x * y) `mod` n
-- | Raising to a non-negative power in the residue ring agrees with reducing
-- the integer power.
prop_pow2 :: Integer -> NonNegative Integer -> Positive Integer -> Bool
prop_pow2 x (NonNegative y) (Positive n) = inRing == expected
  where
    inRing = modP (fromInteger x ^ y) n
    expected = (x ^ y) `mod` n
-- | When @x@ is invertible modulo @n > 1@, @1 / x@ computed by 'modP' really
-- is its inverse. @n == 1@ is excluded because every residue is then 0.
prop_inv :: Integer -> Positive Integer -> Property
prop_inv x (Positive n) = invertible ==> (inverse * x) `mod` n == 1
  where
    invertible = gcd (abs x) n == 1 && n > 1
    inverse = modP (1 / fromInteger x) n
-- | With an 'Int64' carrier, @2 ^ n@ modulo 10^9+7 agrees with the same value
-- computed in 'Integer'. Non-positive exponents pass trivially.
prop_int64 :: Integer -> Bool
prop_int64 n
  | n <= 0 = True
  | otherwise = modP (2 ^ n) modulus == fromInteger ((2 ^ n) `mod` toInteger modulus)
  where
    modulus = 10 ^ 9 + 7 :: Int64
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment