@mratsim
Created May 10, 2018 21:28
Optimized GRU Cell implementation in Arraymancer
import ../src/arraymancer
import macros
import math

func sigmoid*[T: SomeReal](x: T): T {.inline.} =
  1 / (1 + exp(-x))
proc gru_cell_forward[T: SomeReal](input, hidden,
                                   w_input, w_recur,
                                   b_input, b_recur: Tensor[T],
                                   result: var Tensor[T]) =
  ## Input:
  ##   - input tensor of shape [batch_size, features]
  ##   - hidden state of shape [batch_size, hidden_size]
  ##   - weight of input [3 * hidden_size, features]
  ##   - weight of hidden [3 * hidden_size, hidden_size]
  ##   - biases of input and hidden state [3 * hidden_size]

  # For compatibility with CuDNN, and to allow loading CPU/Cuda weights interchangeably,
  # we use the following equations:
  #
  #   - h is the hidden state at t-1, h' at t
  #   - input == x, hidden == h
  #   - n = h~ (the candidate hidden state)
  #   - r is the reset gate
  #   - z is the update gate
  #   - h' is a linear interpolation (between n and h)
  #   - w_input == W, the concatenation of [Wr, Wz, Wn]
  #   - w_recur == R, the concatenation of [Rr, Rz, Rn]
  #   - bW and bR are the corresponding biases
  #   - R is called U in the original paper
  #
  #   r  = σ(Wr * x + bWr + Rr * h + bRr)
  #   z  = σ(Wz * x + bWz + Rz * h + bRz)
  #   n  = tanh(Wn * x + bWn + r .* (Rn * h + bRn))
  #   h' = (1 - z) .* n + z .* h
  #
  # These differ from the original paper for n and h':
  #   - the pointwise multiplication by r happens after the matrix multiplication Rn * h
  #   - the linear interpolation has its terms switched
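  # (Clarifying note, not in the original gist): because r multiplies Rn * h
  # only *after* the matrix product, the recurrent projections Rr, Rz and Rn do
  # not depend on r and can all be computed together in one fused matrix
  # multiplication, which is what the two `linear` calls below rely on.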

  let
    N = input.shape[0]
    F = input.shape[1]
    H = hidden.shape[1]
    # Slices
    sr = (0 ..< H)|1
    sz = (H ..< 2*H)|1
    srz = (0 ..< 2*H)|1
    sn = (2*H ..< 3*H)|1

  var Wx, Rh: Tensor[T]
  linear(w_input, input, b_input, Wx)
  linear(w_recur, hidden, b_recur, Rh)
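  # Note (added for clarity, not in the original gist): given the shapes above,
  # Wx and Rh are expected to be [batch_size, 3 * hidden_size], with the gate
  # pre-activations laid out column-wise as [r | z | n], matching the slices
  # sr, sz and sn defined earlier.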

  # To reduce allocations, we compute the reset gate r
  # and the update gate z in-place, reusing the Wx buffer.
  # We keep them concatenated to improve throughput.
  var reset_update = Wx[_, srz]
  apply2_inline(reset_update, Rh[_, srz]):
    sigmoid(x + y)

  # We also reuse the Wx buffer for the candidate hidden state n
  var n = Wx[_, sn]
  apply3_inline(n, reset_update[_, sr], Rh[_, sn]):
    tanh(x + y * z)

  # Compute the next hidden state
  result = map3_inline(Wx[_, sz], n, hidden):
    (1 - x) * y + x * z
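
# The following reference implementation is an addition for illustration, not part
# of the original gist. It follows the documented equations literally, one fresh
# tensor per gate, instead of reusing the Wx buffer. The name
# `gru_cell_forward_naive` and the use of `map2_inline` are assumptions on my part;
# this is a sketch for comparison, not the optimized path.
proc gru_cell_forward_naive[T: SomeReal](input, hidden,
                                         w_input, w_recur,
                                         b_input, b_recur: Tensor[T],
                                         result: var Tensor[T]) =
  ## Same inputs and weight layout as `gru_cell_forward`.
  let
    H = hidden.shape[1]
    sr = (0 ..< H)|1
    sz = (H ..< 2*H)|1
    sn = (2*H ..< 3*H)|1

  # Same fused input/recurrent projections as in the optimized version
  var Wx, Rh: Tensor[T]
  linear(w_input, input, b_input, Wx)
  linear(w_recur, hidden, b_recur, Rh)

  # One new tensor per gate, following the documented equations directly
  var r_gate, z_gate, n_cand: Tensor[T]
  r_gate = map2_inline(Wx[_, sr], Rh[_, sr]):
    sigmoid(x + y)
  z_gate = map2_inline(Wx[_, sz], Rh[_, sz]):
    sigmoid(x + y)
  n_cand = map3_inline(Wx[_, sn], r_gate, Rh[_, sn]):
    tanh(x + y * z)
  result = map3_inline(z_gate, n_cand, hidden):
    (1 - x) * y + x * z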

const
  BatchSize = 10
  Features = 4
  HiddenSize = 7

let x = randomTensor[float32](BatchSize, Features, 1.0f)
let hidden = randomTensor[float32](BatchSize, HiddenSize, 1.0f)
let w_input = randomTensor[float32](3 * HiddenSize, Features, 1.0f)
let w_recur = randomTensor[float32](3 * HiddenSize, HiddenSize, 1.0f)
let b_input = randomTensor[float32](3 * HiddenSize, 1.0f)
let b_recur = randomTensor[float32](3 * HiddenSize, 1.0f)

var h_prime = zeros[float32](BatchSize, HiddenSize)

gru_cell_forward(x, hidden,
                 w_input, w_recur,
                 b_input, b_recur,
                 h_prime)

echo h_prime
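
# Added for illustration, not in the original gist: a quick sanity check that the
# buffer-reusing implementation and the naive reference sketch above agree.
# The name `h_naive` is mine; the comparison just prints the largest element-wise
# difference, which should be (numerically close to) zero.
var h_naive = zeros[float32](BatchSize, HiddenSize)
gru_cell_forward_naive(x, hidden,
                       w_input, w_recur,
                       b_input, b_recur,
                       h_naive)
echo "max difference vs naive: ", abs(h_prime - h_naive).max()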