Created
November 15, 2017 15:46
-
-
Save ear/1f459b108bd271d56597fd3fb0e580ca to your computer and use it in GitHub Desktop.
MNIST
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TypeApplications #-} | |
module Main where | |
import Prelude hiding (readFile) | |
import Data.Binary | |
import Data.Binary.Get | |
import qualified Data.ByteString as B | |
import qualified Data.ByteString.Lazy as BL | |
import System.IO.MMap | |
import Control.Monad | |
import Data.Maybe (listToMaybe) | |
import Data.List (unfoldr, intercalate) | |
-- | Decoded contents of an IDX label file: the magic number, the declared
-- label count, and the labels themselves (each stored byte widened to 'Float').
data Labels = Labels Int Int [Float]
  deriving (Show)

-- | Decode-only 'Binary' instance for 'Labels'.
--
-- Layout read by 'get': a big-endian 32-bit magic number, a big-endian
-- 32-bit count, then @count@ single-byte labels.
instance Binary Labels where
  -- Encoding is intentionally not provided; calling 'put' raises an error.
  put = error "unimplemented"

  get = do
    magic <- fromIntegral <$> getWord32be
    -- Reject anything that is not an unsigned-byte, one-dimensional IDX file
    -- up front; otherwise arbitrary bytes would be decoded as "labels".
    when (magic /= labelsMagic) $
      fail ("Labels: unexpected magic number " ++ show magic
            ++ ", expected " ++ show labelsMagic)
    count <- fromIntegral <$> getWord32be
    labels <- replicateM count (fromIntegral <$> getWord8)
    return (Labels magic count labels)
    where
      -- IDX magic: two zero bytes, 0x08 (unsigned byte), 0x01 (one dimension).
      labelsMagic :: Int
      labelsMagic = 0x00000801
-- | Decoded contents of an IDX image file: magic number, image count, rows
-- per image, columns per image, and one flat pixel list per image (each
-- stored byte widened to 'Float').
data Images = Images !Int !Int !Int !Int ![[Float]]
  deriving (Show)

-- | Decode-only 'Binary' instance for 'Images'.
--
-- Layout read by 'get': four big-endian 32-bit words (magic, count, rows,
-- cols) followed by @count@ images of @rows * cols@ bytes each.
instance Binary Images where
  -- Encoding is intentionally not provided; calling 'put' raises an error.
  put = error "unimplemented"

  get = do
    magic <- fromIntegral <$> getWord32be
    -- Reject anything that is not an unsigned-byte, three-dimensional IDX
    -- file up front; otherwise a wrong file would decode into garbage.
    when (magic /= imagesMagic) $
      fail ("Images: unexpected magic number " ++ show magic
            ++ ", expected " ++ show imagesMagic)
    count <- fromIntegral <$> getWord32be
    rows <- fromIntegral <$> getWord32be
    cols <- fromIntegral <$> getWord32be
    images <- replicateM count (getImage (rows * cols))
    return (Images magic count rows cols images)
    where
      -- IDX magic: two zero bytes, 0x08 (unsigned byte), 0x03 (three dimensions).
      imagesMagic :: Int
      imagesMagic = 0x00000803

      -- Read one image's worth of bytes and widen every byte to a 'Float'.
      getImage :: Int -> Get [Float]
      getImage pixelCount =
        map fromIntegral . B.unpack <$> getByteString pixelCount
-- | Entry point: decodes the training label and image files (looked up by
-- relative path) and prints every image using the 'Show' instance of 'Image'.
main :: IO ()
main = do
  -- The labels are decoded here but not used anywhere below.
  _labels <- (decode <$> BL.readFile "train-labels-idx1-ubyte") :: IO Labels
  imageSet <- (decode <$> readFileViaMmap "train-images-idx3-ubyte") :: IO Images
  let Images _ _ height width pictures = imageSet
  forM_ pictures (print . Image height width)
-- | Read the whole file at @path@ through a read-only memory mapping and
-- return its contents as a lazy 'BL.ByteString'.
--
-- 'mmapFileByteString' is used instead of the raw 'mmapFilePtr' so that the
-- mmap library owns the lifetime of the mapping; mapping by hand requires a
-- matching 'munmapFilePtr' call, which is easy to forget and leaks the
-- mapping. Passing 'Nothing' as the range maps the entire file.
readFileViaMmap :: FilePath -> IO BL.ByteString
readFileViaMmap path = BL.fromStrict <$> mmapFileByteString path Nothing
-- | A single image: row count, column count, and its pixel intensities laid
-- out one row of @cols@ values after another.
data Image = Image Int Int [Float]

-- | Renders an image as block-character art, one text line per pixel row,
-- with lines separated by newlines (no trailing newline).
instance Show Image where
  show (Image _rows cols pixels)
    -- Without a positive width there is nothing to wrap on (and chunking
    -- would never consume any input), so render everything on one line.
    | cols <= 0 = map showPixel pixels
    | otherwise = intercalate "\n" (unfoldr nextRow pixels)
    where
      -- Peel off one row at a time. Stop only once the input is exhausted,
      -- so the final (possibly short) row is still rendered.
      nextRow :: [Float] -> Maybe (String, [Float])
      nextRow [] = Nothing
      nextRow remaining =
        let (row, rest) = splitAt cols remaining
         in Just (map showPixel row, rest)

      -- Map an intensity to a shade: the darker the block, the larger the
      -- value. The final guard is total, so out-of-range values (>= 256 or
      -- NaN) render as the darkest block instead of crashing.
      showPixel :: Float -> Char
      showPixel p
        | p < 256 / 10 = '░'
        | p < 256 / 4 = '▒'
        | p < 256 / 2 = '▓'
        | otherwise = '█'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment