Skip to content

Instantly share code, notes, and snippets.

@ear
Created November 15, 2017 15:46
Show Gist options
  • Save ear/1f459b108bd271d56597fd3fb0e580ca to your computer and use it in GitHub Desktop.
Save ear/1f459b108bd271d56597fd3fb0e580ca to your computer and use it in GitHub Desktop.
MNIST
{-# LANGUAGE TypeApplications #-}
module Main where
import Prelude hiding (readFile)
import Data.Binary
import Data.Binary.Get
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import System.IO.MMap
import Control.Monad
import Data.Maybe (listToMaybe)
import Data.List (unfoldr, intercalate)
-- | Decoded contents of a label file: the leading magic word, the declared
-- number of labels, and the labels themselves (one input byte each, widened
-- to 'Float').
data Labels = Labels Int Int [Float]
  deriving (Show)

-- | Decode-only instance. 'get' consumes two big-endian 32-bit header words
-- (magic, then count) followed by @count@ single-byte labels. 'put' is not
-- implemented and raises an error if it is ever evaluated.
instance Binary Labels where
  put = error "unimplemented"

  get = do
    magic <- headerWord
    total <- headerWord
    Labels magic total <$> replicateM total (fromIntegral <$> getWord8)
    where
      -- One big-endian 32-bit header field, widened to 'Int'.
      headerWord :: Get Int
      headerWord = fromIntegral <$> getWord32be
-- | Decoded contents of an image file: magic word, image count, rows and
-- columns per image, and one flat pixel list per image (each input byte
-- widened to 'Float'). All fields are strict.
data Images = Images !Int !Int !Int !Int ![[Float]]
  deriving (Show)

-- | Decode-only instance. 'get' consumes four big-endian 32-bit header words
-- (magic, count, rows, cols) and then @count@ images of @rows * cols@ bytes
-- each. 'put' is not implemented and raises an error if it is ever evaluated.
instance Binary Images where
  put = error "unimplemented"

  get = do
    magic <- headerWord
    total <- headerWord
    height <- headerWord
    width <- headerWord
    pictures <- replicateM total (readImage (height * width))
    pure (Images magic total height width pictures)
    where
      -- One big-endian 32-bit header field, widened to 'Int'.
      headerWord :: Get Int
      headerWord = fromIntegral <$> getWord32be

      -- Read one image's worth of bytes and convert every byte to 'Float'.
      readImage :: Int -> Get [Float]
      readImage byteCount =
        map fromIntegral . B.unpack <$> getByteString byteCount
-- | Entry point: opens the label file and the image file from the current
-- directory, then prints every decoded image using the 'Show' instance of
-- 'Image'. The decoded labels are bound but not used afterwards.
main :: IO ()
main = do
  _labels <- decode @Labels <$> BL.readFile "train-labels-idx1-ubyte"
  imageSet <- decode @Images <$> readFileViaMmap "train-images-idx3-ubyte"
  let Images _ _ rows cols images = imageSet
  -- Wrap each flat pixel list with its dimensions so it can be rendered.
  forM_ images (print . Image rows cols)
-- | Map a whole file read-only and expose it as a lazy 'BL.ByteString'.
--
-- This goes through 'mmapFileByteString' rather than a raw 'mmapFilePtr'
-- call: the previous raw-pointer version never unmapped the region, ignored
-- the offset that 'mmapFilePtr' returns alongside the pointer, and copied the
-- entire file onto the heap. The library wrapper applies the offset itself
-- and attaches the unmapping to the lifetime of the returned buffer.
--
-- Note that the result is backed by the mapping instead of a private copy.
readFileViaMmap :: FilePath -> IO BL.ByteString
readFileViaMmap path = BL.fromStrict <$> mmapFileByteString path Nothing
-- | A single image: row count, column count, and a flat list of pixel
-- intensities that is split into rows of @cols@ pixels when rendered.
data Image = Image Int Int [Float]

-- | Renders the image as block-character art, one text line per pixel row,
-- with rows separated by newlines.
--
-- Every row is rendered, including the final one and a trailing partial row
-- if the pixel count is not a multiple of @cols@. A non-positive column
-- count yields the empty string instead of looping forever.
instance Show Image where
  show (Image _rows cols pixels)
    | cols <= 0 = ""
    | otherwise = intercalate "\n" (unfoldr nextRow pixels)
    where
      -- Stop only when no pixels are left; stopping when the *remainder* is
      -- empty would silently drop the last row.
      nextRow [] = Nothing
      nextRow remaining =
        let (row, rest) = splitAt cols remaining
         in Just (map showPixel row, rest)

      -- Map an intensity to a shade; darker blocks mean larger values.
      -- The final guard is a catch-all so out-of-range values cannot crash.
      showPixel p
        | p < 256 / 10 = '░'
        | p < 256 / 4 = '▒'
        | p < 256 / 2 = '▓'
        | otherwise = '█'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment