Last active
April 30, 2024 18:49
-
-
Save AndrasKovacs/e156ae66b8c28b1b84abe6b483ea20ec to your computer and use it in GitHub Desktop.
1brc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# language | |
BlockArguments | |
, CPP | |
, LambdaCase | |
, MagicHash | |
, PatternSynonyms | |
, Strict | |
, TypeApplications | |
, UnboxedTuples | |
, ViewPatterns | |
#-} | |
{-# options_ghc | |
-Wall | |
-Wno-missing-signatures | |
-Wno-name-shadowing | |
#-} | |
{- cabal: | |
build-depends: base >= 4.19, bytestring, mmap, async | |
default-language: GHC2021 | |
ghc-options: -Wall -O2 -fllvm -rtsopts -threaded -split-sections | |
-} | |
-- more debugging: | |
-- ghc -O2 -fllvm -rtsopts -threaded -split-sections -ddump-simpl -dsuppress-all | |
-- -dno-suppress-type-signatures -ddump-to-file -fforce-recomp | |
-- CONFIGURATION | |
-------------------------------------------------------------------------------- | |
-------------------------------------------------------------------------------- | |
-- display output | |
#define DISPLAY_OUTPUT | |
-- should be power of 2, minimum 16384 | |
-- #define TABLE_SIZE 131072 | |
-- #define TABLE_SIZE 65536 | |
#define TABLE_SIZE 32768 | |
-- #define TABLE_SIZE 16384 | |
-------------------------------------------------------------------------------- | |
-------------------------------------------------------------------------------- | |
import Control.Concurrent
import Control.Exception (SomeException, try)
import Control.Monad
import Data.Bits
import Foreign.Marshal.Alloc
import GHC.Exts
import GHC.IO
import GHC.Word
import System.IO.MMap
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy.Char8 as LC8
#ifdef DISPLAY_OUTPUT
import Data.List
import Text.Printf
import System.IO hiding (withFile)
#endif
-- Random common functions | |
-------------------------------------------------------------------------------- | |
-- | Run one action per list element, each on a thread pinned (forkOn) to the
-- capability with the same index, and return the results in input order.
--
-- The number of elements must equal the number of capabilities, otherwise
-- this raises the error "wrong number of capabilities".
--
-- An exception raised by a worker (including one raised while forcing its
-- result) is captured and rethrown in the calling thread. Without this the
-- result MVar would never be filled and the caller would block on takeMVar.
mapConcurrently :: (a -> IO b) -> [a] -> IO [b]
mapConcurrently f xs = do
  caps <- getNumCapabilities
  unless (caps == length xs) $ error "wrong number of capabilities"
  vs <- forM (zip [0..] xs) \(i, x) -> do
    v <- newEmptyMVar
    v <$ forkOn i do
      yield
      -- evaluate keeps the result forced inside the worker, as before.
      r <- capture (f x >>= evaluate)
      putMVar v r
  forM vs \v -> takeMVar v >>= either throwIO pure
  where
    -- try, pinned to SomeException so every failure reaches the caller.
    capture :: IO c -> IO (Either SomeException c)
    capture = try
-- | Shorthand for fromIntegral.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# inline fi #-}

-- | Left shift via unsafeShiftL (no range check on the shift amount).
sl :: Bits a => a -> Int -> a
sl value bits = value `unsafeShiftL` bits

-- | Right shift via unsafeShiftR (no range check on the shift amount).
sr :: Bits a => a -> Int -> a
sr value bits = value `unsafeShiftR` bits

-- | Logical (zero-filling) right shift of an Int via uncheckedIShiftRL#.
isrl :: Int -> Int -> Int
isrl (I# value) (I# bits) = I# (value `uncheckedIShiftRL#` bits)
-- | Branch-free maximum of two Ints.
--
-- Shifting the difference right by 63 yields an all-ones mask when the
-- difference is negative and zero otherwise, so the difference is subtracted
-- from the first argument only when the first argument is the smaller one.
-- Only meaningful while the subtraction does not overflow.
max' :: Int -> Int -> Int
max' a b = a - (delta .&. signMask)
  where
    delta = a - b
    signMask = delta `unsafeShiftR` 63

-- | Branch-free minimum of two Ints, using the same sign-mask trick: the
-- difference is added to the second argument only when it is negative.
min' :: Int -> Int -> Int
min' a b = b + (delta .&. signMask)
  where
    delta = a - b
    signMask = delta `unsafeShiftR` 63
-- | Advance a raw address by a boxed byte offset.
plusAddr :: Addr# -> Int -> Addr#
plusAddr base (I# offset) = plusAddr# base offset

-- | Reinterpret a boxed Int as a raw address.
int2Addr :: Int -> Addr#
int2Addr (I# n) = int2Addr# n

-- | Reinterpret a raw address as a boxed Int.
addr2Int :: Addr# -> Int
addr2Int addr = I# (addr2Int# addr)

-- | Equality test that yields an Int (1 when equal, 0 otherwise) instead of
-- a Bool, so results can be combined with bitwise operators.
eqI :: Int -> Int -> Int
eqI (I# lhs) (I# rhs) = I# (lhs ==# rhs)
-- | Read one machine-word Int stored at the given address.
readI :: Addr# -> IO Int
readI addr = IO (\st ->
  case readIntOffAddr# addr 0# st of
    (# st', n #) -> (# st', I# n #))

-- | Store one machine-word Int at the given address.
writeI :: Addr# -> Int -> IO ()
writeI addr (I# n) = IO (\st ->
  case writeIntOffAddr# addr 0# n st of
    st' -> (# st', () #))
-- Generic buffers | |
-------------------------------------------------------------------------------- | |
-- | A raw byte range: a start address together with a length in bytes.
data Buffer = Buffer
  { _ptr :: Addr# -- ^ address of the first byte
  , len :: Int    -- ^ number of bytes
  }

-- | Drop the given number of bytes from the front: the address moves forward
-- and the length shrinks by the same amount. No bounds check is performed.
plus :: Buffer -> Int -> Buffer
plus (Buffer start size) (I# n) = Buffer (plusAddr# start n) (size - I# n)
-- | Fill every byte of the buffer with the given byte value.
memset :: Buffer -> Word8 -> IO ()
memset (Buffer start (I# size)) (W8# byte) = IO (\st ->
  case setAddrRange# start size (word2Int# (word8ToWord# byte)) st of
    st' -> (# st', () #))
-- | Memory-map a file read-only and run the continuation on it as a Buffer.
--
-- The range argument is Nothing, which presumably maps the whole file
-- (confirm against the mmap package documentation).
withFile :: FilePath -> (Buffer -> IO a) -> IO a
withFile path k =
  mmapWithFilePtr path ReadOnly Nothing $ \(Ptr start, size) ->
    k (Buffer start size)
{-# inline withFile #-}
-- | Read the byte at the given index (in bytes). No bounds check.
indexW8 :: Buffer -> Int -> Word8
indexW8 (Buffer base _) (I# i) = W8# (indexWord8OffAddr# base i)

-- | Read the 32-bit word at the given index (in 32-bit units). No bounds check.
indexW32 :: Buffer -> Int -> Word32
indexW32 (Buffer base _) (I# i) = W32# (indexWord32OffAddr# base i)

-- | Read the machine word at the given index (in word units). No bounds check.
indexW :: Buffer -> Int -> Word
indexW (Buffer base _) (I# i) = W# (indexWordOffAddr# base i)

-- | Read the machine-word Int at the given index (in word units). No bounds check.
indexI :: Buffer -> Int -> Int
indexI (Buffer base _) (I# i) = I# (indexIntOffAddr# base i)

-- | First byte of the buffer.
getW8 :: Buffer -> Word8
getW8 buf = indexW8 buf 0

-- | First 32-bit word of the buffer.
getW32 :: Buffer -> Word32
getW32 buf = indexW32 buf 0

-- | First machine word of the buffer.
getW :: Buffer -> Word
getW buf = indexW buf 0
-- | Two buffers are equal when they have the same length and the same bytes.
-- The contents are compared a machine word at a time, then at most one
-- 32-bit word, then byte by byte for the remaining tail.
instance Eq Buffer where
  lhs == rhs = len lhs == len rhs && sameBytes lhs rhs
    where
      -- Both buffers shrink in lockstep, so the remaining length of the
      -- first one drives the loop.
      sameBytes x y
        | left >= 8 = getW x == getW y && sameBytes (plus x 8) (plus y 8)
        | left >= 4 = getW32 x == getW32 y && sameBytes (plus x 4) (plus y 4)
        | left == 0 = True
        | otherwise = getW8 x == getW8 y && sameBytes (plus x 1) (plus y 1)
        where
          left = len x
  {-# inline (==) #-}
-- | Full-width multiplication folded back into one word: the high and low
-- halves of the double-word product are xor-ed together.
foldedMul :: Word -> Word -> Word
foldedMul (W# x) (W# y) = W# (xor# high low)
  where
    (# high, low #) = timesWord2# x y

-- | Initial accumulator for the hash functions.
salt :: Word
salt = 3032525626373534813

-- | Mix two words: xor them, then folded-multiply by a fixed odd constant.
combine :: Word -> Word -> Word
combine x y = (x `xor` y) `foldedMul` 11400714819323198549
-- | Hash the bytes of a buffer, starting from salt. Input is absorbed a
-- machine word at a time, then at most one 32-bit word, then single bytes.
hashBuffer :: Buffer -> Word
hashBuffer buf = absorb salt buf
  where
    absorb acc rest
      | left >= 8 = absorb (combine (getW rest) acc) (plus rest 8)
      | left >= 4 = absorb (combine (fromIntegral (getW32 rest)) acc) (plus rest 4)
      | left == 0 = acc
      | otherwise = absorb (combine (fromIntegral (getW8 rest)) acc) (plus rest 1)
      where
        left = len rest
-- | Emit the bytes of a buffer, in order, as a ByteString builder.
buildBuffer :: Buffer -> BB.Builder
buildBuffer b = foldMap (BB.word8 . indexW8 b) [0 .. len b - 1]
-- printBuffer :: Buffer -> IO () | |
-- printBuffer = BB.hPutBuilder stdout . buildBuffer | |
-- | Shows the raw bytes, each byte decoded as one Char (Char8 unpacking).
instance Show Buffer where
  show = LC8.unpack . BB.toLazyByteString . buildBuffer

-- | Orders buffers by their shown form, i.e. byte-wise lexicographically.
instance Ord Buffer where
  compare lhs rhs = show lhs `compare` show rhs
-- Short buffer | |
-------------------------------------------------------------------------------- | |
-- | Unboxed buffer containing at most 23 bytes. The first field is the
-- length, the remaining three fields are the payload words. packBuffer fills
-- them so that the last field holds the first eight bytes of the key and
-- unused payload bytes are zero; the 24th payload byte is always zero.
data ShortBuffer = ShortBuffer# Int Int Int Int

-- | Branch-free equality on the three payload words. The length field is not
-- compared.
instance Eq ShortBuffer where
  ShortBuffer# _ x2 x1 x0 == ShortBuffer# _ y2 y1 y0 = matches == 1
    where
      matches = eqI x2 y2 .&. eqI x1 y1 .&. eqI x0 y0
  {-# inline (==) #-}
-- | Hash a short buffer from its three payload words; the length field is
-- not part of the hash.
hashShortBuffer :: ShortBuffer -> Word
hashShortBuffer (ShortBuffer# _ w2 w1 w0) =
  combine (combine salt (fi w2)) (combine (fi w1) (fi w0))
-- | Emit the bytes of a short buffer: the payload words are serialised
-- little-endian, last field first, and the result is cut to the stored length.
buildShortBuffer :: ShortBuffer -> BB.Builder
buildShortBuffer (ShortBuffer# n w2 w1 w0) =
  BB.lazyByteString . LC8.take (fi n) . BB.toLazyByteString $
    foldMap (BB.int64LE . fi) [w0, w1, w2]

-- | Shows the raw bytes, each byte decoded as one Char (Char8 unpacking).
instance Show ShortBuffer where
  show short = LC8.unpack (BB.toLazyByteString (buildShortBuffer short))
-- | Lexicographic byte order. Each payload word is byte-swapped so that an
-- unsigned comparison looks at the earliest byte first, and the words are
-- compared starting with the one that holds the first eight bytes.
instance Ord ShortBuffer where
  compare (ShortBuffer# _ x2 x1 x0) (ShortBuffer# _ y2 y1 y0) =
    compare (swapped x0, swapped x1, swapped x2) (swapped y0, swapped y1, swapped y2)
    where
      swapped (I# w) = W# (byteSwap# (int2Word# w))
-- Unboxed sum of short and standard buffers. | |
-------------------------------------------------------------------------------- | |
-- | Three-word packed representation of either a ShortBuffer or a Buffer.
-- The encoding is defined by the ShortBuffer and LongBuffer pattern synonyms.
data SLBuffer = SLB# Int Int Int

-- | True when the first word is zero. Used by the hash table to recognise
-- zero-initialised (unused) slots.
isEmptySLB :: SLBuffer -> Bool
isEmptySLB (SLB# tag _ _) = case tag of
  0 -> True
  _ -> False
-- | Decode the packed form into an unboxed sum of the two buffer kinds.
--
-- The low byte of the first word is the discriminator: a value up to 23 is
-- the length of a short buffer whose first payload word sits in the upper
-- seven bytes. Otherwise the first word is the length of a long buffer and
-- the second word is its address.
--
-- NOTE(review): a long buffer whose length has a low byte of 23 or less
-- (for example a length of 256) would be decoded as a short buffer, so this
-- encoding relies on long lengths staying below 256.
unpackSLB# :: SLBuffer -> (# ShortBuffer | Buffer #)
unpackSLB# (SLB# w0 w1 w2)
  | shortLen <= 23 = (# ShortBuffer# shortLen (isrl w0 8) w1 w2 | #)
  | otherwise = (# | Buffer (int2Addr w1) w0 #)
  where
    shortLen = w0 .&. 255
-- | View or build an SLBuffer as a short buffer. Building stores the length
-- in the low byte of the first word and the first payload word above it.
pattern ShortBuffer :: ShortBuffer -> SLBuffer
pattern ShortBuffer short <- (unpackSLB# -> (# short | #))
  where
    ShortBuffer (ShortBuffer# n w2 w1 w0) = SLB# (sl w2 8 .|. n) w1 w0

-- | View or build an SLBuffer as a long buffer. Building stores the length
-- in the first word, the address in the second and zero in the third.
pattern LongBuffer :: Buffer -> SLBuffer
pattern LongBuffer long <- (unpackSLB# -> (# | long #))
  where
    LongBuffer (Buffer start size) = SLB# size (addr2Int start) 0

{-# complete ShortBuffer, LongBuffer #-}
-- | Buffers of the same kind are compared with the equality of that kind;
-- a short buffer never equals a long one.
instance Eq SLBuffer where
  lhs == rhs = case (lhs, rhs) of
    (ShortBuffer x, ShortBuffer y) -> x == y
    (LongBuffer x, LongBuffer y) -> x == y
    _ -> False
  {-# inline (==) #-}
-- | Try to pack a Buffer into a short one.
--
-- Buffers of up to 23 bytes are copied into the payload words of a
-- ShortBuffer, with the bytes beyond the length masked to zero; anything
-- longer is kept as a LongBuffer pointing at the original memory.
--
-- Whole words are read from the buffer even when fewer bytes belong to it,
-- so the memory just past a short key must be readable.
--
-- NOTE(review): for a length of 0 the mask is computed with a shift by 64,
-- which is outside the defined range of the unchecked shift; callers are
-- assumed to pass non-empty buffers.
packBuffer :: Buffer -> SLBuffer
packBuffer b
  | n <= 8 = ShortBuffer (ShortBuffer# n 0 0 (word 0 .&. keepLow n))
  | n <= 16 = ShortBuffer (ShortBuffer# n 0 (word 1 .&. keepLow (n - 8)) (word 0))
  | n <= 23 = ShortBuffer (ShortBuffer# n (word 2 .&. keepLow (n - 16)) (word 1) (word 0))
  | otherwise = LongBuffer b
  where
    n = len b
    word = indexI b
    -- Mask that keeps the lowest k bytes of a word.
    keepLow k = isrl (-1) (64 - sl k 3)
-- | Hash a packed buffer with the hash function of its kind.
hashSLB :: SLBuffer -> Word
hashSLB packed = case packed of
  ShortBuffer short -> hashShortBuffer short
  LongBuffer long -> hashBuffer long
-- | Emit the bytes of a packed buffer as a ByteString builder.
buildSLB :: SLBuffer -> BB.Builder
buildSLB packed = case packed of
  ShortBuffer short -> buildShortBuffer short
  LongBuffer long -> buildBuffer long

-- | Shows the raw bytes, each byte decoded as one Char (Char8 unpacking).
instance Show SLBuffer where
  show packed = LC8.unpack (BB.toLazyByteString (buildSLB packed))
-- | Two short buffers are compared word-wise; every other combination is
-- compared through the shown byte strings.
instance Ord SLBuffer where
  compare lhs rhs = case (lhs, rhs) of
    (ShortBuffer x, ShortBuffer y) -> compare x y
    _ -> compare (show lhs) (show rhs)
-- Branchless scanning for bytes in words. | |
-------------------------------------------------------------------------------- | |
-- Expands to the hexadecimal word literal in which each of the eight bytes
-- equals the given two-digit hex byte. The empty C comments only serve to
-- glue the pieces into a single token.
#define SCAN_MASK(hex) 0x/**/hex/**/hex/**/hex/**/hex/**/hex/**/hex/**/hex/**/hex
-- Given a hexadecimal byte, generate the (Word -> Int) function which returns the
-- index of the rightmost (least significant) occurrence of the byte, or returns 8
-- if the byte does not occur.
--
-- How it works: xor with the scan mask turns every matching byte into zero. The
-- subtract / complement / and step then sets the top bit of the lowest zero byte,
-- and counting trailing zero bits divided by eight gives the index of that byte.
-- With no match the word is zero, so the count is 64 and the result is 8
-- (this assumes a 64-bit Word).
#define BYTE_INDEX(hex) (\(x :: Word) -> case xor x SCAN_MASK(hex) of \
  x -> case (x - 0x0101010101010101) .&. complement x .&. 0x8080808080808080 of \
    x -> countTrailingZeros x `sr` 3)
-- Hash table of measurements | |
-------------------------------------------------------------------------------- | |
-- | Running statistics for one key. Temperatures are stored as integers in
-- tenths of a degree, as produced by parse.
data Val = Val
  { _min :: Int   -- ^ smallest measurement seen
  , _max :: Int   -- ^ largest measurement seen
  , _cnt :: Int   -- ^ number of measurements
  , _total :: Int -- ^ sum of all measurements
  }

-- | One hash table slot: a packed key and its statistics, both unpacked
-- into the constructor.
data Entry = Entry
  { _key :: {-# unpack #-} SLBuffer
  , _val :: {-# unpack #-} Val
  }
-- | Size of one table slot in bytes. An entry occupies seven words; the slot
-- is padded to 64 bytes.
entrySize :: Int
entrySize = 64

-- | Bit mask that reduces a hash to a slot index (table size minus one).
tableMask :: Int
tableMask = TABLE_SIZE - 1

-- | Total size of one table in bytes.
tableBytes :: Int
tableBytes = entrySize * TABLE_SIZE

-- | A table is referred to by the raw address of its first slot.
type Table = Addr#
-- | Allocate one zero-filled table per input buffer and run the continuation
-- on the list of (buffer, table) pairs.
--
-- The tables come from nested allocaBytesAligned calls, so they are only
-- valid while the continuation runs. The pairs are handed over in reverse
-- input order.
initTables :: [Buffer] -> ([(Buffer, Ptr Word8)] -> IO a) -> IO a
initTables bs f = allocate bs []
  where
    allocate [] acc = f acc
    allocate (chunk : rest) acc =
      allocaBytesAligned tableBytes entrySize $ \tbl@(Ptr raw) -> do
        memset (Buffer raw tableBytes) 0
        allocate rest ((chunk, tbl) : acc)
-- | Read the entry stored at a *byte* offset into the table: three key words
-- followed by the four statistics words.
readEntry :: Table -> Int -> IO Entry
readEntry p i = do
  key <- SLB# <$> word 0 <*> word 8 <*> word 16
  val <- Val <$> word 24 <*> word 32 <*> word 40 <*> word 48
  pure (Entry key val)
  where
    slot = plusAddr p i
    word k = readI (plusAddr slot k)
-- | Write an entry at a *byte* offset into the table, using the same word
-- layout that readEntry expects.
writeEntry :: Table -> Int -> Entry -> IO ()
writeEntry p i (Entry (SLB# k0 k1 k2) (Val lo hi n total)) = do
  put 0 k0
  put 8 k1
  put 16 k2
  put 24 lo
  put 32 hi
  put 40 n
  put 48 total
  where
    slot = plusAddr p i
    put k v = writeI (plusAddr slot k) v
-- | Statistics for a key that has been seen exactly once.
newVal :: Int -> Val
newVal temp = Val {_min = temp, _max = temp, _cnt = 1, _total = temp}

-- | Merge additional statistics into an entry, keeping its key.
updateEntry :: Entry -> Val -> Entry
updateEntry (Entry k old) new = Entry k merged
  where
    merged =
      Val
        { _min = min' (_min old) (_min new)
        , _max = max' (_max old) (_max new)
        , _cnt = _cnt old + _cnt new
        , _total = _total old + _total new
        }
-- | Run an action on every occupied slot of the table, in slot order.
forTable :: Table -> (Entry -> IO ()) -> IO ()
forTable t f = scan 0
  where
    scan off
      | off == tableBytes = pure ()
      | otherwise = do
          entry <- readEntry t off
          unless (isEmptySLB (_key entry)) (f entry)
          scan (off + entrySize)
{-# inline forTable #-}
-- | Insert an entry, or merge its statistics into the existing entry with
-- the same key.
--
-- Open addressing with linear probing: start at the slot selected by the
-- key hash and walk forward (wrapping at the end of the table) until an
-- empty slot or a slot with an equal key is found. The loop does not
-- terminate if the table is full and the key is absent.
updateTable :: Table -> Entry -> IO ()
updateTable tbl new@(Entry key val) = probe home
  where
    home = (fi (hashSLB key) .&. tableMask) * entrySize
    probe off
      | off == tableBytes = probe 0
      | otherwise = do
          cur <- readEntry tbl off
          let curKey = _key cur
          case () of
            _ | isEmptySLB curKey -> writeEntry tbl off new
              | key == curKey -> writeEntry tbl off (updateEntry cur val)
              | otherwise -> probe (off + entrySize)
-- | Parse every record in the buffer and fold it into the table.
--
-- Each record has the shape key, semicolon, temperature, and one trailing
-- byte. The temperature has an optional minus sign, one or two integer
-- digits, a dot and one fractional digit, and is stored in tenths of a
-- degree. The input is not validated, and the semicolon scan reads whole
-- words, so it may look at bytes beyond the current record.
parse :: Table -> Buffer -> IO ()
parse _ buf | len buf == 0 = pure ()
parse tbl buf = do
  -- Offset of the first semicolon, scanning one word at a time.
  let semiAt :: Int -> Buffer -> Int
      semiAt seen cur = case BYTE_INDEX(3B) (getW cur) of
        8 -> semiAt (seen + 8) (plus cur 8)
        hit -> seen + hit
  let keyLen = semiAt 0 buf
  let key = packBuffer buf {len = keyLen}
  -- The number starts right after the semicolon.
  let num = plus buf (keyLen + 1)
  let val :: Word8 -> Int
      val w = fi w - 48
  let at :: Int -> Int
      at k = val (indexW8 num k)
  -- Record the measurement, then continue after the consumed bytes.
  let emit :: Int -> Int -> IO ()
      emit used temp = do
        updateTable tbl (Entry key (newVal temp))
        parse tbl (plus num used)
  case indexW8 num 0 of
    -- minus sign
    45 -> case indexW8 num 2 of
      -- dot at index 2: one integer digit
      46 -> emit 5 (negate (10 * at 1) - at 3)
      -- otherwise two integer digits, dot at index 3
      tens -> emit 6 (negate (100 * at 1) - 10 * val tens - at 4)
    -- no sign: the first byte is already a digit
    hi -> case indexW8 num 1 of
      -- dot at index 1: one integer digit
      46 -> emit 4 (10 * val hi + at 2)
      -- otherwise two integer digits, dot at index 2
      lo -> emit 5 (100 * val hi + 10 * val lo + at 3)
-- Split the file into one buffer per thread
-------------------------------------------------------------------------------- | |
-- | Split a buffer into exactly num_threads chunks that each end right
-- after a newline byte.
--
-- Every chunk except the last one is the nominal chunk size (total length
-- divided by num_threads) extended up to and including the next newline.
-- Because of that extension a small input can be used up by fewer than
-- num_threads chunks; the result is then padded with empty buffers, so the
-- caller, which requires one buffer per capability, still gets the expected
-- count. parse treats an empty buffer as a no-op.
--
-- The newline scan reads whole words and is not bounds checked; the input
-- is assumed to end with a newline.
splitBuffer :: Int -> Buffer -> [Buffer]
splitBuffer num_threads b =
  chunks ++ replicate (num_threads - length chunks) (Buffer (_ptr b) 0)
  where
    chunkSize = div (len b) num_threads
    chunks = go b
    go cur
      | len cur <= chunkSize = [cur]
      | otherwise =
          let cut = chunkSize + newlineFrom 0 (plus cur chunkSize) + 1
           in Buffer (_ptr cur) cut : go (plus cur cut)
    -- Offset of the first newline, scanning one word at a time.
    newlineFrom seen p = case BYTE_INDEX(0A) (getW p) of
      8 -> newlineFrom (seen + 8) (plus p 8)
      hit -> seen + hit
#ifdef DISPLAY_OUTPUT | |
-- | Collect every occupied slot of the table into a list, in slot order.
tableToList :: Table -> IO [Entry]
tableToList tbl = collect 0
  where
    collect off
      | off == tableBytes = pure []
      | otherwise = do
          entry <- readEntry tbl off
          later <- collect (off + entrySize)
          pure (if isEmptySLB (_key entry) then later else entry : later)
-- | Render the entries as a brace-enclosed, comma-separated list in which
-- each item is the key followed by min, mean and max with one decimal.
-- Stored values are tenths of a degree, hence the division by ten.
displayEntries :: [Entry] -> BB.Builder
displayEntries es =
  BB.char8 '{' <> mconcat (intersperse (BB.string8 ", ") (map render es)) <> BB.char8 '}'
  where
    render (Entry key (Val lo hi n total)) =
      let lowest = fi lo / 10 :: Double
          mean = fi total / (fi n * 10) :: Double
          highest = fi hi / 10 :: Double
       in buildSLB key <> BB.string8 (printf "=%.1f/%.1f/%.1f" lowest mean highest)
#endif | |
-- | Map the input file, parse one chunk per capability into its own table,
-- merge all tables into the first one returned and, when output is enabled,
-- print the entries sorted by key followed by a newline.
main :: IO ()
main =
  withFile "data/measurements500M.txt" $ \input -> do
    caps <- getNumCapabilities
    initTables (splitBuffer caps input) $ \jobs -> do
      Ptr merged : others <-
        mapConcurrently (\(chunk, Ptr tbl) -> Ptr tbl <$ parse tbl chunk) jobs
      forM_ others $ \(Ptr other) ->
        forTable other (updateTable merged)
#ifdef DISPLAY_OUTPUT
      entries <- sortOn _key <$> tableToList merged
      BB.hPutBuilder stdout (displayEntries entries)
      putChar '\n'
#endif
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment