Created
April 13, 2017 14:48
-
-
Save andrewthad/06437b8083009598f09f3391b711ef95 to your computer and use it in GitHub Desktop.
Indexing into a sorted vector
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE BangPatterns #-} | |
{-# LANGUAGE TemplateHaskell #-} | |
import Data.Bits
import qualified Data.List as L
import System.Exit (exitFailure)

import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV
import Test.QuickCheck.All (quickCheckAll)
-- | A strict, unpacked pair of 'Int's.
--
-- 'lookupSorted' uses it in two roles: internally as an inclusive
-- @(first index, last index)@ range, and as its result in the form
-- @(start index, run length)@.
--
-- 'Eq' and 'Show' are derived so that results can be compared and
-- printed, e.g. in tests or in GHCi.
data IntPair = IntPair
  {-# UNPACK #-} !Int
  {-# UNPACK #-} !Int
  deriving (Eq, Show)
-- | Binary-search a sorted, indexable container for every occurrence
-- of @needle@.
--
-- The container is described by an indexing function (valid for
-- indices @0 .. len - 1@) and its length. Elements must be in
-- ascending order with respect to their 'Ord' instance.
--
-- The result is @IntPair start runLength@: the index of the first
-- element equal to @needle@ and the number of consecutive matches.
-- When nothing matches, the run length is 0 and the start index
-- carries no meaning (it currently happens to be 0).
--
-- Every step halves the search interval, so the indexing function is
-- called O(log n) times.
{-# INLINE lookupSorted #-}
lookupSorted :: Ord a => (Int -> a) -> Int -> a -> IntPair
lookupSorted lookupIx !len !needle =
  case locate 0 (len - 1) of
    IntPair firstIx lastIx -> IntPair firstIx (lastIx - firstIx + 1)
  where
    -- Midpoint of the inclusive interval [l, h], written as
    -- l + (h - l) / 2 so the sum l + h is never formed.
    midpoint :: Int -> Int -> Int
    midpoint l h = l + unsafeShiftR (h - l) 1

    -- Find any element equal to the needle, then widen the hit to the
    -- inclusive (first, last) index range of the whole run. An empty
    -- interval means "no match" and is encoded as (0, -1) so that the
    -- caller's length computation yields 0.
    locate :: Int -> Int -> IntPair
    locate !l !h
      | l > h = IntPair 0 (-1)
      | otherwise =
          let !m = midpoint l h
              !x = lookupIx m
           in case compare x needle of
                LT -> locate (m + 1) h
                GT -> locate l (m - 1)
                EQ -> IntPair (leftmost l (m - 1)) (rightmost (m + 1) h)

    -- Index of the first match. The interval lies to the left of a
    -- known match, so every element in it is <= needle: an element
    -- that is not equal must be smaller and the answer is to its right.
    leftmost :: Int -> Int -> Int
    leftmost !l !h
      | l > h = l
      | otherwise =
          let !m = midpoint l h
              !x = lookupIx m
           in if x == needle
                then leftmost l (m - 1)
                else leftmost (m + 1) h

    -- Index of the last match. The interval lies to the right of a
    -- known match, so every element in it is >= needle: an element
    -- that is not equal must be larger and the answer is to its left.
    rightmost :: Int -> Int -> Int
    rightmost !l !h
      | l > h = h
      | otherwise =
          let !m = midpoint l h
              !x = lookupIx m
           in if x == needle
                then rightmost (m + 1) h
                else rightmost l (m - 1)
-- | Return the slice of the vector @v@ that holds every element equal
-- to @w@; the result is empty when @w@ does not occur.
--
-- The run is located with 'lookupSorted', which requires @v@ to be
-- sorted in ascending order.
wordVectorMatchingSlice :: Word -> Vector Word -> Vector Word
wordVectorMatchingSlice w v = V.slice runStart runLength v
  where
    IntPair runStart runLength =
      lookupSorted (\i -> v V.! i) (V.length v) w
-- | Property: slicing the matches out of a sorted vector gives the
-- same elements as filtering the sorted list for equality with @w@.
prop_sameAsList :: [Word] -> Word -> Bool
prop_sameAsList ws w = viaFilter == viaSearch
  where
    -- The input is sorted first because the search requires it.
    sortedWs = L.sort ws
    -- Reference answer: a plain linear filter over the sorted list.
    viaFilter = V.fromList (L.filter (== w) sortedWs)
    -- Answer under test: binary search plus slice on the vector.
    viaSearch = wordVectorMatchingSlice w (V.fromList sortedWs)
-- Empty Template Haskell declaration splice. It ends the current
-- declaration group so that the 'quickCheckAll' splice below can see
-- the @prop_*@ definitions above it.
return []

-- | Run every @prop_*@ property defined above in this module and
-- report whether all of them passed.
runTests :: IO Bool
runTests = $quickCheckAll

-- | Program entry point: run the properties and exit with a failure
-- status when any of them fails, so that scripts and CI notice.
main :: IO ()
main = do
  allPassed <- runTests
  if allPassed then pure () else exitFailure
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment