Created
July 9, 2018 12:35
branchless ReLU implementation (with bitwise operations)
// copyright Toru Niina 2018. distributed under the Boost Software License v1.0.
// it provides an implementation of ReLU without branching.
// the core idea of branchless-ReLU is the following.
// 1. first, bitwise-and with all-0 bits always returns all-0 bits, but bitwise-and
//    with all-1 bits returns the original argument intact.
//    0000 & 1010 == 0000, 1111 & 1010 == 1010
// 2. second, we can make all-0 bits or all-1 bits depending on the sign bit by
//    applying an arithmetic right shift by 31 bits.
//    1000 >> 31 == 1111, 0110 >> 31 == 0000
// 3. at last, by combining these two tricks, we can obtain a ReLU function that
//    returns 0 if the argument was negative, and the original value otherwise.
//
// the above idea can be extended to `double` and other floating point numbers
// because the only assumption used here is that the floating point format has
// a sign bit (use an integer type of the same width and shift by width - 1).
// the `ftoi` and `itof` functions and the related union convert
// float to int32_t and vice versa, preserving the bitwise representation.
#include <stdint.h> // int32_t
typedef union {float f; int32_t i;} fandi;
int32_t ftoi(float f){fandi f2i; f2i.f = f; return f2i.i;}
float itof(int32_t i){fandi i2f; i2f.i = i; return i2f.f;}
float ReLU(float x)
{
    // s, e, f mean `sign bit`, `exponent`, and `fractional`, respectively.
    //     seeeffff |     seeeffff
    //     00010100 |     10010100 --+
    //        >> 31 |        >> 31   |
    // -------------+-------------   |
    // not 00000000 | not 11111111   |
    // -------------+-------------   |
    //     11111111 |     00000000   |
    // and 00010100 | and 10010100 <-+
    // -------------+-------------
    //     00010100 |     00000000
    //         == x |          0.0
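    // note: `>>` on a negative int32_t is implementation-defined in C;
    // this code relies on it being an arithmetic (sign-extending) shift.
    const int32_t xi = ftoi(x);
    return itof(xi & ~(xi >> 31));
}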
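
The header comment says the idea extends to `double`. A minimal sketch of that extension follows, assuming IEEE 754 binary64 and an arithmetic right shift on signed integers. The names `dtoi`, `itod`, and `ReLU_double` are hypothetical and not part of the original file. `memcpy` is used for the bit-preserving conversion here instead of a union; either should work in C.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical helpers (not in the original gist): copy the bits of a
   double into an int64_t and back without changing them. */
static int64_t dtoi(double d) { int64_t i; memcpy(&i, &d, sizeof i); return i; }
static double  itod(int64_t i) { double d; memcpy(&d, &i, sizeof d); return d; }

/* Same trick as ReLU(float), with a 64-bit integer and a shift by 63.
   Assumes `>>` on a negative int64_t sign-extends (implementation-defined). */
double ReLU_double(double x)
{
    const int64_t xi = dtoi(x);
    /* xi >> 63 is all-1 bits when the sign bit is set, all-0 bits otherwise;
       the inverted mask therefore keeps x only when the sign bit is clear. */
    return itod(xi & ~(xi >> 63));
}
```

Because the mask depends only on the sign bit, `-0.0` is mapped to `+0.0`, and a NaN whose sign bit is set is mapped to `0.0` rather than propagated.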