Skip to content

Instantly share code, notes, and snippets.

@vyraun
Forked from codekansas/binarized_nn_inference.cpp
Created November 9, 2017 03:49
Show Gist options
  • Save vyraun/35481e4dbddd9a07f202e127d2ab7ab3 to your computer and use it in GitHub Desktop.
Save vyraun/35481e4dbddd9a07f202e127d2ab7ab3 to your computer and use it in GitHub Desktop.
Efficient binarized neural network inference
/* Binarized neural network inference example.
This shows a simple C++ program for doing inference on
binarized neural networks. To do this efficiently, the code
below makes use of the "bitset" class, whose count() method
typically compiles down to the "popcnt" instruction, to count
the number of 1's that show up when two rows are XOR-ed together.
When a row fits in a single machine word this takes constant time,
so a matrix multiplication between an (A, B) and a (B, C) matrix
takes O(A * C) time; in other words, each value in the output
matrix is computed in constant time. For larger B, the cost per
output value grows with the number of machine words in a row.
*/
#include <stdlib.h>
#include <bitset>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;
/**
 * Fixed-size binary matrix with W columns and H rows.
 *
 * Every row is stored as a std::bitset<W>, so two rows can be XOR-ed and the
 * differing positions counted with a single bitset operation.
 * Coordinates are always passed as (x, y) = (column, row).
 */
template <size_t W, size_t H>
class Matrix {
    // All other instantiations are friends: dot_() has to read the rows of
    // `rhs` and write the rows of `out`, which have different template
    // arguments than *this.
    template <size_t Wf, size_t Hf>
    friend class Matrix;

private:
    std::bitset<W> data[H];  // data[y] is row y; bit x of that row is column x.

    // Throws std::runtime_error unless (x, y) addresses a cell of the matrix.
    void check(size_t x, size_t y) const {
        if (x >= W || y >= H) {
            throw std::runtime_error(
                "Invalid indices requested: (" +
                std::to_string(x) + ", " + std::to_string(y) +
                ") for matrix with shape (" +
                std::to_string(W) + ", " + std::to_string(H) + ")");
        }
    }

public:
    // Creates an all-zero matrix (std::bitset default-constructs to all zeros).
    Matrix() = default;

    // Returns the bit at column x, row y.
    // Throws std::runtime_error if the position is out of range.
    bool get(size_t x, size_t y) const {
        check(x, y);
        return data[y][x];
    }

    // Sets the bit at column x, row y to v.
    // Throws std::runtime_error if the position is out of range.
    void set(size_t x, size_t y, bool v) {
        check(x, y);
        data[y][x] = v;
    }

    // Draws the matrix inside an ASCII frame on `out`: 'X' for a set bit and
    // ' ' for a cleared one. When `transpose` is true, rows and columns are
    // swapped in the drawing; the stored data is not modified.
    void print(std::ostream &out, bool const transpose) const {
        const char ul = '.',
                   ur = '.',
                   ll = '*',
                   lr = '*';
        const size_t cols = transpose ? H : W;
        const size_t rows = transpose ? W : H;
        const std::string bar(cols, '-');

        // Upper border.
        out << ul << bar << ur << '\n';

        // Matrix contents. Everything goes to `out`, including the left
        // border, so the drawing stays intact on streams other than cout.
        for (size_t y = 0; y < rows; y++) {
            out << '|';
            for (size_t x = 0; x < cols; x++) {
                // (x, y) are drawing coordinates; swap them for storage
                // access when drawing transposed.
                const bool is_set = transpose ? get(y, x) : get(x, y);
                out << (is_set ? 'X' : ' ');
            }
            out << '|' << '\n';
        }

        // Lower border, then a single flush for the whole drawing.
        out << ll << bar << lr << '\n' << std::flush;
    }

    // Draws the matrix on `out` without transposing.
    void print(std::ostream &out) const { print(out, false); }

    // Number of columns.
    size_t width() const { return W; }

    // Number of rows.
    size_t height() const { return H; }

    // Binarized product, written into `out`.
    //
    // `rhs` holds the right-hand operand transposed: each of its N rows is
    // matched against each of the H rows of *this. out(x, y) is set when row y
    // of *this and row x of `rhs` agree in a strict majority of their W
    // positions, i.e. when fewer than half of the bits differ.
    template <size_t N>
    void dot_(const Matrix<W, N> &rhs, Matrix<N, H> &out) const {
        for (size_t y = 0; y < H; y++) {
            for (size_t x = 0; x < N; x++) {
                const size_t differing = (data[y] ^ rhs.data[x]).count();
                // Compare 2 * differing with W instead of differing with
                // W / 2: the integer division rounds down and would reject
                // the smallest strict majority whenever W is odd.
                out.data[y][x] = (2 * differing < W);
            }
        }
    }

    // Returns the binarized product of *this with `rhs` (see dot_()).
    template <size_t N>
    Matrix<N, H> dot(const Matrix<W, N> &rhs) const {
        // Built as a local and returned by value: no heap allocation, so
        // nothing can leak.
        Matrix<N, H> result;
        dot_(rhs, result);
        return result;
    }
};
// Demonstration with random arrays.
int main() {
srand(1337);
const size_t A = 3, B = 5, C = 8;
Matrix<A, B> m;
Matrix<A, C> n;
// Fills each matrix with random values.
for (int i = 0; i < A; i++) {
for (int j = 0; j < B; j++) {
m.set(i, j, rand() % 2);
}
for (int j = 0; j < C; j++) {
n.set(i, j, rand() % 2);
}
}
// Prints the "M" matrix.
cout << "M:" << endl;
m.print(cout);
// Prints the "N" matrix. Since we're dotting onto this
// matrix, we transpose it when printing.
cout << endl << "N:" << endl;
n.print(cout, true);
// Does the dot product.
Matrix<C, B> o = m.dot(n);
cout << endl << "O:" << endl;
o.print(cout);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment