Created May 10, 2018 21:28
Optimized GRU Cell implementation in Arraymancer

import ../src/arraymancer
import macros
import math

func sigmoid*[T: SomeReal](x: T): T {.inline.} =
  1 / (1 + exp(-x))

proc gru_cell_forward[T: SomeReal](input, hidden,
                                   w_input, w_recur,
                                   b_input, b_recur: Tensor[T],
                                   result: var Tensor[T]) =
  ## Input:
  ##   - input tensor of shape [batch_size, features]
  ##   - hidden state of shape [batch_size, hidden_size]
  ##   - weight of input [3 * hidden_size, features]
  ##   - weight of hidden [3 * hidden_size, hidden_size]
  ##   - biases of input and hidden state [3 * hidden_size]
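  ## Output:
  ##   - result, the next hidden state h' of shape [batch_size, hidden_size]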

  # For compatibility with CuDNN, and to allow loading CPU/CUDA weights
  # interchangeably, we use the following equations:
  #
  #   - h is the hidden state at t-1, h' at t
  #   - input == x, hidden == h
  #   - n = h~ (the candidate hidden state)
  #   - r is the reset gate
  #   - z is the update gate
  #   - h' is a linear interpolation
  #   - w_input == W, the concatenation of [Wr, Wz, Wn]
  #   - w_recur == R, the concatenation of [Rr, Rz, Rn]
  #   - bW and bR are the corresponding biases
  #   - R is called U in the original paper
  #
  #   r  = σ(Wr * x + bWr + Rr * h + bRr)
  #   z  = σ(Wz * x + bWz + Rz * h + bRz)
  #   n  = tanh(Wn * x + bWn + r .* (Rn * h + bRn))
  #   h' = (1 - z) .* n + z .* h
  #
  # These differ from the original paper for n and h':
  #   - the pointwise multiplication by r happens after the matrix multiplication
  #   - the linear interpolation has its terms switched
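  #
  # For contrast, a sketch of the formulation with those two differences
  # undone (same notation as above; shown only as an illustration, the code
  # below implements the CuDNN-compatible variant):
  #
  #   n  = tanh(Wn * x + bWn + Rn * (r .* h))
  #   h' = (1 - z) .* h + z .* n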

  let
    N = input.shape[0]
    F = input.shape[1]
    H = hidden.shape[1]
    # Slices
    sr = (0 ..< H)|1
    sz = (H ..< 2*H)|1
    srz = (0 ..< 2*H)|1
    sn = (2*H ..< 3*H)|1

  var Wx, Rh: Tensor[T]
  linear(w_input, input, b_input, Wx)
  linear(w_recur, hidden, b_recur, Rh)

  # To reduce allocations, we compute the reset gate r
  # and the update gate z in the previous buffers.
  # We keep them concatenated to improve throughput.
  var reset_update = Wx[_, srz]
  apply2_inline(reset_update, Rh[_, srz]):
    sigmoid(x + y)

  # We also reuse the previous buffer for the candidate hidden state n
  var n = Wx[_, sn]
  apply3_inline(n, reset_update[_, sr], Rh[_, sn]):
    tanh(x + y * z)

  # Compute the next hidden state
  result = map3_inline(Wx[_, sz], n, hidden):
    (1 - x) * y + x * z

const
  BatchSize = 10
  Features = 4
  HiddenSize = 7

let x = randomTensor[float32](BatchSize, Features, 1.0'f32)
let hidden = randomTensor[float32](BatchSize, HiddenSize, 1.0'f32)

let w_input = randomTensor[float32](3 * HiddenSize, Features, 1.0'f32)
let w_recur = randomTensor[float32](3 * HiddenSize, HiddenSize, 1.0'f32)
let b_input = randomTensor[float32](3 * HiddenSize, 1.0'f32)
let b_recur = randomTensor[float32](3 * HiddenSize, 1.0'f32)

var h_prime = zeros[float32](BatchSize, HiddenSize)

gru_cell_forward(x, hidden,
                 w_input, w_recur,
                 b_input, b_recur,
                 h_prime)

echo h_prime
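
# A minimal usage sketch (not part of the original proc above): unrolling the
# cell over a short input sequence by feeding each output back in as the next
# hidden state. `SeqLen` and the per-step random inputs are illustrative
# stand-ins, not taken from the gist.
const SeqLen = 5

var h = hidden.clone()
var h_next = zeros[float32](BatchSize, HiddenSize)

for t in 0 ..< SeqLen:
  # Stand-in for the t-th element of an input sequence
  let x_t = randomTensor[float32](BatchSize, Features, 1.0'f32)
  gru_cell_forward(x_t, h,
                   w_input, w_recur,
                   b_input, b_recur,
                   h_next)
  h = h_next.clone()

echo h  # hidden state after the last timestep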