Last active
February 2, 2022 04:03
-
-
Save qnkhuat/997d8da99e88f7967b8719982802af73 to your computer and use it in GitHub Desktop.
MNIST reader in clojure
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
;; Data downloaded from: http://yann.lecun.com/exdb/mnist/
(ns cnn.core | |
(:import java.io.File | |
java.io.FileInputStream)) | |
(defn u4->int
  "Combine the first four elements of `arr` into one unsigned 32-bit value,
  most significant byte first.

  Each element is masked to its low 8 bits, so both raw signed JVM bytes
  (e.g. taken straight from a `byte-array`) and already-unsigned values in
  0-255 are accepted. Elements past the fourth are ignored.

  Returns a long in the range 0..4294967295.
  Throws ex-info when `arr` holds fewer than four elements."
  [arr]
  (let [octets (take 4 arr)
        got    (count octets)]
    (when-not (= 4 got)
      (throw (ex-info (format "u4->int needs 4 bytes, got: %d" got) {:got got})))
    ;; Shift the accumulator left by one byte per step, so the first element
    ;; ends up in bits 24-31 and the fourth in bits 0-7.
    (reduce (fn [acc b]
              (bit-or (bit-shift-left acc 8) (bit-and b 0xff)))
            0
            octets)))
(defn to-unsinged
  "Reinterpret the signed byte `b` as its unsigned value (0-255) by keeping
  only the low 8 bits, e.g. -1 -> 255 and -128 -> 128.

  The misspelled name is kept for backward compatibility with existing callers."
  [b]
  (bit-and b 0xff))
(defn read-n-bytes
  "Read up to `n` bytes from the input stream `is` and return them as a vector
  of unsigned values (0-255).

  A single `InputStream.read` call may deliver fewer bytes than requested, so
  this keeps reading until the buffer is full or the stream is exhausted. When
  the stream ends early the vector holds only the bytes actually read, which
  lets callers detect truncation by comparing its count with `n`."
  [^java.io.InputStream is n]
  (let [buf (byte-array n)]
    (loop [off 0]
      ;; -1 doubles as the \"stop\" signal: either the buffer is already full
      ;; or the stream reported end-of-file.
      (let [got (long (if (< off n)
                        (.read is buf (int off) (int (- n off)))
                        -1))]
        (if (neg? got)
          ;; JVM bytes are signed whereas MNIST stores unsigned bytes, so mask
          ;; each one to its low 8 bits. mapv runs directly over the array;
          ;; the buffer is only copied when the stream ended early.
          (mapv #(bit-and % 0xff)
                (if (= off n)
                  buf
                  (java.util.Arrays/copyOf ^bytes buf (int off))))
          (recur (+ off got)))))))
(defn read-images
  "Read the MNIST image file at `path`.

  The file starts with four 32-bit header fields, each decoded with
  `u4->int`: the magic number (must be 2051), the number of images, the
  number of rows and the number of columns. The remaining
  images * rows * cols bytes are the pixel values.

  Returns a lazy seq of images, each a seq of rows * cols pixel values as
  produced by `read-n-bytes`.
  Throws ex-info when the magic number is wrong or when the number of pixel
  values read differs from the length announced in the header."
  [path]
  ;; with-open closes the stream on every exit path, including the throws
  ;; below. All bytes are read eagerly inside its scope, so the lazy
  ;; partition returned at the end never touches the closed stream.
  (with-open [is (FileInputStream. (File. path))]
    (let [magic-number (u4->int (read-n-bytes is 4))
          _            (when-not (= magic-number 2051)
                         (throw (ex-info "Magic number should be 2051 for image file" {})))
          n            (u4->int (read-n-bytes is 4))
          rows         (u4->int (read-n-bytes is 4))
          cols         (u4->int (read-n-bytes is 4))
          data-len     (* n rows cols)
          data-arr     (read-n-bytes is data-len)]
      (when-not (= data-len (count data-arr))
        (throw (ex-info (format "Incorrect data length, Should be: %d, got: %d" data-len (count data-arr)) {})))
      ;; One group of rows * cols values per image.
      (partition (* rows cols) data-arr))))
(defn read-labels
  "Read the MNIST label file at `path`.

  The file starts with two 32-bit header fields, each decoded with
  `u4->int`: the magic number (must be 2049) and the number of labels. The
  remaining bytes hold one label each.

  Returns the vector of label values produced by `read-n-bytes`.
  Throws ex-info when the magic number is wrong or when the number of labels
  read differs from the count announced in the header."
  [path]
  ;; with-open closes the stream on every exit path, including the throws
  ;; below.
  (with-open [is (FileInputStream. (File. path))]
    (let [magic-number (u4->int (read-n-bytes is 4))
          _            (when-not (= magic-number 2049)
                         (throw (ex-info "Magic number should be 2049 for label file" {})))
          n            (u4->int (read-n-bytes is 4))
          data-arr     (read-n-bytes is n)]
      (when-not (= n (count data-arr))
        (throw (ex-info (format "Incorrect data length, Should be: %d, got: %d" n (count data-arr)) {})))
      data-arr)))
;; Benchmark, measured on a Mac M1 Max with 32 GB RAM.
;; NOTE: this form runs every time the namespace is loaded and expects the
;; training image file at the relative path below.
(time (read-images "../mnist/train-images-idx3-ubyte"))
;; Reads an array of 47,040,000 elements. Timings by mapping strategy
;; (presumably the byte conversion step in read-n-bytes -- confirm):
;;   "Elapsed time: 1564.851583 msecs"  (r/map without type hint)
;;   "Elapsed time: 1475.194334 msecs"  (r/map with type hint)
;;   "Elapsed time: 1143.02050 msecs"   (mapv instead of r/map)
;;   "Elapsed time: 6397.154542 msecs"  (map instead of mapv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment